diff --git a/cmd/go_migrate/extractor-error-handler.go b/cmd/go_migrate/extractor-error-handler.go
index e48af32..a223631 100644
--- a/cmd/go_migrate/extractor-error-handler.go
+++ b/cmd/go_migrate/extractor-error-handler.go
@@ -81,20 +81,18 @@ func ExtractorErrorFromLastRowMssql(lastRow UnknownRowValues, indexPrimaryKey in
 	if !ok {
 		currentBatch := *batch
 		currentBatch.RetryCounter = maxRetryAttempts
-		exError := ExtractorError{
+		return ExtractorError{
 			Batch:     currentBatch,
 			HasLastId: true,
 			Msg:       fmt.Sprintf("Couldn't cast last id value as int: %s", previousError.Error()),
 		}
-		return exError
 	}
 
-	exError := ExtractorError{
+	return ExtractorError{
 		Batch:     *batch,
 		HasLastId: true,
 		LastId:    lastId,
 		Msg:       previousError.Error(),
 	}
-	return exError
 }
 
diff --git a/cmd/go_migrate/loader-error-handler.go b/cmd/go_migrate/loader-error-handler.go
new file mode 100644
index 0000000..c1ac8c1
--- /dev/null
+++ b/cmd/go_migrate/loader-error-handler.go
@@ -0,0 +1,61 @@
+package main
+
+import (
+	"context"
+	"fmt"
+)
+
+type LoaderError struct {
+	Chunk
+	Msg string
+}
+
+func (e *LoaderError) Error() string {
+	return e.Msg
+}
+
+func loaderErrorHandler(
+	ctx context.Context,
+	chErrorsIn <-chan LoaderError,
+	chChunksOut chan<- Chunk,
+	chJobErrorsOut chan<- JobError,
+) {
+	for {
+		if ctx.Err() != nil {
+			return
+		}
+
+		select {
+		case <-ctx.Done():
+			return
+
+		case err, ok := <-chErrorsIn:
+			if !ok {
+				return
+			}
+
+			if err.RetryCounter >= maxRetryAttempts {
+				jobError := JobError{
+					ShouldCancelJob: false,
+					Msg:             fmt.Sprintf("chunk %v reached max retries (%d)", err.Id, maxRetryAttempts),
+					Prev:            &err,
+				}
+
+				select {
+				case chJobErrorsOut <- jobError:
+				case <-ctx.Done():
+					return
+				}
+				continue
+			}
+
+			err.RetryCounter++
+
+			select {
+			case chChunksOut <- err.Chunk:
+			case <-ctx.Done():
+				return
+			}
+		}
+	}
+}
diff --git a/cmd/go_migrate/loader.go b/cmd/go_migrate/loader.go
index 58a919d..a677bc1 100644
--- a/cmd/go_migrate/loader.go
+++ b/cmd/go_migrate/loader.go
@@ -12,39 +12,70 @@ import (
 	log "github.com/sirupsen/logrus"
 )
 
-func loadRowsPostgres(ctx context.Context, job MigrationJob, columns []ColumnType, db *pgxpool.Pool, chChunksIn <-chan Chunk) error {
-	chunkCount := 0
-	totalRowsLoaded := 0
+func loadRowsPostgres(
+	ctx context.Context,
+	db *pgxpool.Pool,
+	job MigrationJob,
+	columns []ColumnType,
+	chChunksIn <-chan Chunk,
+	chErrorsOut chan<- LoaderError,
+) {
+	tableId := pgx.Identifier{job.Schema, job.Table}
+	colNames := Map(columns, func(col ColumnType) string {
+		return col.name
+	})
 
-	for chunk := range chChunksIn {
-		chunkStartTime := time.Now()
-		identifier := pgx.Identifier{job.Schema, job.Table}
-		colNames := Map(columns, func(col ColumnType) string {
-			return col.name
-		})
-
-		copyStartTime := time.Now()
-		_, err := db.CopyFrom(
-			ctx,
-			identifier,
-			colNames,
-			pgx.CopyFromRows(chunk.Data),
-		)
-
-		if err != nil {
-			return err
+	for {
+		if ctx.Err() != nil {
+			return
 		}
 
-		chunkCount++
-		totalRowsLoaded += len(chunk.Data)
-		copyDuration := time.Since(copyStartTime)
-		chunkDuration := time.Since(chunkStartTime)
-		rowsPerSec := float64(len(chunk.Data)) / chunkDuration.Seconds()
+		select {
+		case <-ctx.Done():
+			return
+		case chunk, ok := <-chChunksIn:
+			if !ok {
+				return
+			}
 
-		log.Infof("Loaded chunk #%d: %d rows in %v (copy: %v, %.0f rows/sec) - Total: %d rows", chunkCount, len(chunk.Data), chunkDuration, copyDuration, rowsPerSec, totalRowsLoaded)
+			if abort := loadChunkPostgres(ctx, db, tableId, colNames, chunk, chErrorsOut); abort {
+				return
+			}
+		}
+	}
+}
+
+func loadChunkPostgres(
+	ctx context.Context,
+	db *pgxpool.Pool,
+	identifier pgx.Identifier,
+	colNames []string,
+	chunk Chunk,
+	chErrorsOut chan<- LoaderError,
+) (abort bool) {
+	chunkStartTime := time.Now()
+	_, err := db.CopyFrom(
+		ctx,
+		identifier,
+		colNames,
+		pgx.CopyFromRows(chunk.Data),
+	)
+
+	if err != nil {
+		select {
+		case chErrorsOut <- LoaderError{Chunk: chunk, Msg: err.Error()}:
+		case <-ctx.Done():
+			return true
+		}
+		return false
 	}
 
-	return nil
+	chunkDuration := time.Since(chunkStartTime)
+	rowsPerSec := float64(len(chunk.Data)) / chunkDuration.Seconds()
+
+	log.Infof("Loaded chunk: %d rows in %v (%.0f rows/sec)", len(chunk.Data), chunkDuration, rowsPerSec)
+
+	return false
 }
 
 func loadRowsMssql(ctx context.Context, job MigrationJob, columns []ColumnType, db *sql.DB, in <-chan []UnknownRowValues) error {
diff --git a/cmd/go_migrate/process.go b/cmd/go_migrate/process.go
index db25289..cc6b457 100644
--- a/cmd/go_migrate/process.go
+++ b/cmd/go_migrate/process.go
@@ -50,7 +50,7 @@ func processMigrationJob(sourceDb *sql.DB, targetDb *pgxpool.Pool, job Migration
 		extractorErrorHandler(ctx, chExtractorErrors, chBatches, chJobErrors)
 	}()
 
-	chChunks := make(chan Chunk, QueueSize)
+	chChunksRaw := make(chan Chunk, QueueSize)
 
 	maxExtractors := min(NumExtractors, len(batches))
 	var wgMssqlExtractors sync.WaitGroup
@@ -58,7 +58,7 @@ func processMigrationJob(sourceDb *sql.DB, targetDb *pgxpool.Pool, job Migration
 	extractStartTime := time.Now()
 	for range maxExtractors {
 		wgMssqlExtractors.Go(func() {
-			extractFromMssql(ctx, sourceDb, job, sourceColTypes, ChunkSize, chBatches, chChunks, chExtractorErrors, chJobErrors)
+			extractFromMssql(ctx, sourceDb, job, sourceColTypes, ChunkSize, chBatches, chChunksRaw, chExtractorErrors, chJobErrors)
 		})
 	}
 
@@ -72,41 +72,45 @@ func processMigrationJob(sourceDb *sql.DB, targetDb *pgxpool.Pool, job Migration
 
 	go func() {
 		wgMssqlExtractors.Wait()
-		close(chChunks)
+		close(chChunksRaw)
 		log.Infof("Extraction completed in %v", time.Since(extractStartTime))
 	}()
 
-	chChunksTransform := make(chan Chunk, QueueSize)
+	chChunksTransformed := make(chan Chunk, QueueSize)
 	var wgMssqlTransformers sync.WaitGroup
 
 	log.Infof("Starting %d MSSQL transformers...", maxExtractors)
 	transformStartTime := time.Now()
 	for range maxExtractors {
 		wgMssqlTransformers.Go(func() {
-			transformRowsMssql(ctx, sourceColTypes, chChunks, chChunksTransform, chJobErrors)
+			transformRowsMssql(ctx, sourceColTypes, chChunksRaw, chChunksTransformed, chJobErrors)
 		})
 	}
 
 	go func() {
 		wgMssqlTransformers.Wait()
-		close(chChunksTransform)
+		close(chChunksTransformed)
 		log.Infof("Transformation completed in %v", time.Since(transformStartTime))
 	}()
 
 	var wgPostgresLoaders sync.WaitGroup
+	chLoadersErrors := make(chan LoaderError)
+
+	go func() {
+		loaderErrorHandler(ctx, chLoadersErrors, chChunksTransformed, chJobErrors)
+	}()
 
 	log.Infof("Starting %d PostgreSQL loader(s)...", NumLoaders)
 	loaderStartTime := time.Now()
 	for range NumLoaders {
 		wgPostgresLoaders.Go(func() {
-			if err := loadRowsPostgres(ctx, job, targetColTypes, targetDb, chChunksTransform); err != nil {
-				log.Error("Unexpected error loading data into postgres: ", err)
-			}
+			loadRowsPostgres(ctx, targetDb, job, targetColTypes, chChunksTransformed, chLoadersErrors)
 		})
 	}
 
 	wgPostgresLoaders.Wait()
+	close(chLoadersErrors)
 	log.Infof("Loading completed in %v", time.Since(loaderStartTime))
 
 	totalDuration := time.Since(jobStartTime)