go_migrate: plumb a cancellable context from main into processMigrationJob

main() now owns a root context (cancelled on exit via defer); each job derives
its own child context (jobCtx) from it instead of starting from
context.Background(), so cancelling the root cancels all in-flight jobs.
NOTE(review): blank context lines below were reconstructed from the hunk
headers after the patch was whitespace-mangled — verify with `git apply --check`.

diff --git a/cmd/go_migrate/main.go b/cmd/go_migrate/main.go
index a33b1dd..f40b436 100644
--- a/cmd/go_migrate/main.go
+++ b/cmd/go_migrate/main.go
@@ -1,6 +1,7 @@
 package main
 
 import (
+	"context"
 	"time"
 
 	log "github.com/sirupsen/logrus"
@@ -37,6 +38,10 @@ const (
 func main() {
 	configureLog()
 	startTime := time.Now()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	log.Info("=== Starting migration ===")
 	log.Infof("Number of loaders: %d, Chunk size: %d", NumLoaders, ChunkSize)
 
@@ -50,7 +55,7 @@ func main() {
 
 	for _, job := range migrationJobs {
 		log.Infof(">>> Processing job: %s.%s <<<", job.Schema, job.Table)
-		processMigrationJob(sourceDb, targetDb, job)
+		processMigrationJob(ctx, sourceDb, targetDb, job)
 	}
 
 	totalDuration := time.Since(startTime)
diff --git a/cmd/go_migrate/process.go b/cmd/go_migrate/process.go
index 636a86e..aeb267f 100644
--- a/cmd/go_migrate/process.go
+++ b/cmd/go_migrate/process.go
@@ -12,7 +12,12 @@ import (
 	log "github.com/sirupsen/logrus"
 )
 
-func processMigrationJob(sourceDb *sql.DB, targetDb *pgxpool.Pool, job MigrationJob) {
+func processMigrationJob(
+	ctx context.Context,
+	sourceDb *sql.DB,
+	targetDb *pgxpool.Pool,
+	job MigrationJob,
+) {
 	jobStartTime := time.Now()
 	log.Infof("Starting migration job: %s.%s [PK: %s]", job.Schema, job.Table, job.PrimaryKey)
 
@@ -24,10 +29,10 @@ func processMigrationJob(sourceDb *sql.DB, targetDb *pgxpool.Pool, job Migration
 	logColumnTypes(sourceColTypes, "Source col types")
 	logColumnTypes(targetColTypes, "Target col types")
 
-	ctx, cancel := context.WithCancel(context.Background())
+	jobCtx, cancel := context.WithCancel(ctx)
 	defer cancel()
 
-	batches, err := batchGeneratorMssql(ctx, sourceDb, job)
+	batches, err := batchGeneratorMssql(jobCtx, sourceDb, job)
 	if err != nil {
 		log.Error("Unexpected error calculating batch ranges: ", err)
 	}
@@ -46,13 +51,13 @@ func processMigrationJob(sourceDb *sql.DB, targetDb *pgxpool.Pool, job Migration
 	var wgLoaders sync.WaitGroup
 
 	go func() {
-		if err := jobErrorHandler(ctx, chJobErrors); err != nil {
+		if err := jobErrorHandler(jobCtx, chJobErrors); err != nil {
 			cancel()
 		}
 	}()
 
-	go extractorErrorHandler(ctx, chExtractorErrors, chBatches, chJobErrors, &wgActiveBatches)
-	go loaderErrorHandler(ctx, chLoadersErrors, chChunksTransformed, chJobErrors, &wgActiveChunks)
+	go extractorErrorHandler(jobCtx, chExtractorErrors, chBatches, chJobErrors, &wgActiveBatches)
+	go loaderErrorHandler(jobCtx, chLoadersErrors, chChunksTransformed, chJobErrors, &wgActiveChunks)
 
 	maxExtractors := min(NumExtractors, len(batches))
 	log.Infof("Starting %d extractors...", maxExtractors)
@@ -60,7 +65,7 @@ func processMigrationJob(sourceDb *sql.DB, targetDb *pgxpool.Pool, job Migration
 
 	for range maxExtractors {
 		wgExtractors.Go(func() {
-			extractFromMssql(ctx, sourceDb, job, sourceColTypes, ChunkSize, chBatches, chChunksRaw, chExtractorErrors, chJobErrors, &wgActiveBatches)
+			extractFromMssql(jobCtx, sourceDb, job, sourceColTypes, ChunkSize, chBatches, chChunksRaw, chExtractorErrors, chJobErrors, &wgActiveBatches)
 		})
 	}
 
@@ -76,7 +81,7 @@ func processMigrationJob(sourceDb *sql.DB, targetDb *pgxpool.Pool, job Migration
 
 	for range maxExtractors {
 		wgTransformers.Go(func() {
-			transformRowsMssql(ctx, sourceColTypes, chChunksRaw, chChunksTransformed, chJobErrors, &wgActiveChunks)
+			transformRowsMssql(jobCtx, sourceColTypes, chChunksRaw, chChunksTransformed, chJobErrors, &wgActiveChunks)
 		})
 	}
 
@@ -85,7 +90,7 @@ func processMigrationJob(sourceDb *sql.DB, targetDb *pgxpool.Pool, job Migration
 
 	for range NumLoaders {
 		wgLoaders.Go(func() {
-			loadRowsPostgres(ctx, targetDb, job, targetColTypes, chChunksTransformed, chLoadersErrors, chJobErrors, &wgActiveChunks)
+			loadRowsPostgres(jobCtx, targetDb, job, targetColTypes, chChunksTransformed, chLoadersErrors, chJobErrors, &wgActiveChunks)
 		})
 	}
 
@@ -111,7 +116,7 @@ func processMigrationJob(sourceDb *sql.DB, targetDb *pgxpool.Pool, job Migration
 		cancel()
 	}()
 
-	<-ctx.Done()
+	<-jobCtx.Done()
 
 	log.Infof("Migration job completed. Total time: %v", time.Since(jobStartTime))
 }