From 837fdc7abbdfa56a3fd4d1e8ddddaee9ee3286dd Mon Sep 17 00:00:00 2001 From: Kylesoda <249518290+kylesoda@users.noreply.github.com> Date: Tue, 12 May 2026 11:01:50 -0500 Subject: [PATCH] refactor: add validation feature to compare row counts between source and target databases --- cmd/go_migrate/main.go | 7 ++ cmd/go_migrate/validate.go | 192 +++++++++++++++++++++++++++++++++++++ 2 files changed, 199 insertions(+) create mode 100644 cmd/go_migrate/validate.go diff --git a/cmd/go_migrate/main.go b/cmd/go_migrate/main.go index 84695d6..fcd7679 100644 --- a/cmd/go_migrate/main.go +++ b/cmd/go_migrate/main.go @@ -22,6 +22,7 @@ func main() { checkExpiry() configPath := flag.String("config", "", "path to migration config file") + validate := flag.Bool("validate", false, "count rows in source and target per job and compare") flag.Parse() if flag.NArg() > 1 { @@ -78,6 +79,12 @@ func main() { defer sourceDb.Close() defer targetDb.Close() + if *validate { + validationResults := validateJobs(ctx, sourceDb, targetDb, migrationConfig.Jobs, migrationConfig.MaxParallelWorkers) + printValidationReport(validationResults) + return + } + results := processMigrationJobs(ctx, sourceDb, targetDb, migrationConfig.Jobs, migrationConfig.MaxParallelWorkers) log.Info("=== RESUMEN DE MIGRACIÓN ===") diff --git a/cmd/go_migrate/validate.go b/cmd/go_migrate/validate.go new file mode 100644 index 0000000..655ff15 --- /dev/null +++ b/cmd/go_migrate/validate.go @@ -0,0 +1,192 @@ +package main + +import ( + "context" + "database/sql" + "fmt" + "sync" + + "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config" + dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper" + log "github.com/sirupsen/logrus" +) + +type ValidationResult struct { + JobName string + SourceTable string + TargetTable string + SourceCount int64 + TargetCount int64 + Match bool + Error error +} + +func countSourceRows(ctx context.Context, db dbwrapper.DbWrapper, job config.Job) (int64, error) { + schema := job.SourceTable.Schema + table := job.SourceTable.Table + + hasRange := job.Range.Min != nil || job.Range.Max != nil + + var ( + query string + args []any + ) + + if hasRange && job.SourceTable.PrimaryKey != "" { + query = fmt.Sprintf("SELECT COUNT_BIG(*) FROM [%s].[%s] WHERE 1=1", schema, table) + if job.Range.Min != nil { + op := ">" + if job.Range.IsMinInclusive { + op = ">=" + } + query += fmt.Sprintf(" AND [%s] %s @min", job.SourceTable.PrimaryKey, op) + args = append(args, sql.Named("min", *job.Range.Min)) + } + if job.Range.Max != nil { + op := "<" + if job.Range.IsMaxInclusive { + op = "<=" + } + query += fmt.Sprintf(" AND [%s] %s @max", job.SourceTable.PrimaryKey, op) + args = append(args, sql.Named("max", *job.Range.Max)) + } + } else { + query = fmt.Sprintf("SELECT COUNT_BIG(*) FROM [%s].[%s]", schema, table) + } + + var count int64 + if err := db.QueryRow(ctx, query, args...).Scan(&count); err != nil { + return 0, err + } + return count, nil +} + +func countTargetRows(ctx context.Context, db dbwrapper.DbWrapper, job config.Job) (int64, error) { + schema := job.TargetTable.Schema + table := job.TargetTable.Table + query := fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."%s"`, schema, table) + + var count int64 + if err := db.QueryRow(ctx, query).Scan(&count); err != nil { + return 0, err + } + return count, nil +} + +func validateJob(ctx context.Context, sourceDb, targetDb dbwrapper.DbWrapper, job config.Job) ValidationResult { + result := ValidationResult{ + JobName: job.Name, + SourceTable: fmt.Sprintf("[%s].[%s]", job.SourceTable.Schema, job.SourceTable.Table), + TargetTable: fmt.Sprintf(`"%s"."%s"`, job.TargetTable.Schema, job.TargetTable.Table), + } + + var ( + sourceErr, targetErr error + wg sync.WaitGroup + ) + + wg.Add(2) + go func() { + defer wg.Done() + result.SourceCount, sourceErr = countSourceRows(ctx, sourceDb, job) + }() + go func() { + defer wg.Done() + result.TargetCount, targetErr = countTargetRows(ctx, targetDb, job) + }() + wg.Wait() + + if sourceErr != nil { + result.Error = fmt.Errorf("source count failed: %w", sourceErr) + return result + } + if targetErr != nil { + result.Error = fmt.Errorf("target count failed: %w", targetErr) + return result + } + + result.Match = result.SourceCount == result.TargetCount + return result +} + +func validateJobs( + ctx context.Context, + sourceDb dbwrapper.DbWrapper, + targetDb dbwrapper.DbWrapper, + jobs []config.Job, + maxParallelWorkers int, +) []ValidationResult { + if len(jobs) == 0 { + return nil + } + if maxParallelWorkers <= 0 { + maxParallelWorkers = 1 + } + if maxParallelWorkers > len(jobs) { + maxParallelWorkers = len(jobs) + } + + chJobs := make(chan config.Job, len(jobs)) + var mu sync.Mutex + var results []ValidationResult + var wg sync.WaitGroup + + for range maxParallelWorkers { + wg.Go(func() { + for job := range chJobs { + res := validateJob(ctx, sourceDb, targetDb, job) + mu.Lock() + results = append(results, res) + mu.Unlock() + } + }) + } + + for _, job := range jobs { + chJobs <- job + } + close(chJobs) + wg.Wait() + + return results +} + +func printValidationReport(results []ValidationResult) { + log.Info("=== VALIDATION REPORT ===") + + var totalMatch, totalMismatch, totalErrors int + + for _, r := range results { + if r.Error != nil { + log.Errorf("[%s] ERROR: %v", r.JobName, r.Error) + totalErrors++ + continue + } + + if !r.Match { + totalMismatch++ + diff := r.TargetCount - r.SourceCount + var diffStr string + if diff > 0 { + diffStr = fmt.Sprintf(" (target has %d extra rows)", diff) + } else { + diffStr = fmt.Sprintf(" (target is missing %d rows)", -diff) + } + log.Warnf("[%s] MISMATCH | Source %s: %d | Target %s: %d%s", + r.JobName, + r.SourceTable, r.SourceCount, + r.TargetTable, r.TargetCount, + diffStr, + ) + } else { + totalMatch++ + log.Infof("[%s] OK | Source %s: %d | Target %s: %d", + r.JobName, + r.SourceTable, r.SourceCount, + r.TargetTable, r.TargetCount, + ) + } + } + + log.Infof("=== Validation complete: %d OK, %d mismatches, %d errors ===", totalMatch, totalMismatch, totalErrors) +}