From d3a3b26bb3212a317da58eca42fbfc5bafc38cf8 Mon Sep 17 00:00:00 2001
From: Kylesoda <249518290+kylesoda@users.noreply.github.com>
Date: Wed, 8 Apr 2026 22:39:07 -0500
Subject: [PATCH] feat: enhance error handling and context management in MSSQL
 extraction process

---
 cmd/go_migrate/extractor.go | 248 +++++++++++++++++++++---------------
 1 file changed, 144 insertions(+), 104 deletions(-)

diff --git a/cmd/go_migrate/extractor.go b/cmd/go_migrate/extractor.go
index 250e5d0..bdeb3ab 100644
--- a/cmd/go_migrate/extractor.go
+++ b/cmd/go_migrate/extractor.go
@@ -3,6 +3,7 @@ package main
 import (
     "context"
     "database/sql"
+    "errors"
     "slices"
     "strings"
     "time"
@@ -42,126 +43,165 @@ func extractFromMssql(
             ShouldCancelJob: true,
             Msg:             "Primary key not found in provided columns",
         }
+
         select {
-        case chJobErrorsOut <- jobError:
         case <-ctx.Done():
             return
+        case chJobErrorsOut <- jobError:
         }
+
         return
     }
 
-    for batch := range chBatchesIn {
-        func() {
-            query := buildExtractQueryMssql(job, columns, batch.ShouldUseRange, batch.IsLowerLimitInclusive)
-            log.Debug("Query used to extract data from mssql: ", query)
+    for {
+        if ctx.Err() != nil {
+            return
+        }
 
-            var queryArgs []any
-            if batch.ShouldUseRange {
-                queryArgs = append(queryArgs,
-                    sql.Named("min", batch.LowerLimit),
-                    sql.Named("max", batch.UpperLimit),
-                )
-            }
-
-            queryStartTime := time.Now()
-            rows, err := db.QueryContext(ctx, query, queryArgs...)
-            if err != nil {
-                exError := ExtractorError{
-                    Batch:     batch,
-                    HasLastId: false,
-                    Msg:       err.Error(),
-                }
-                chErrorsOut <- exError
+        select {
+        case <-ctx.Done():
+            return
+        case batch, ok := <-chBatchesIn:
+            if !ok {
                 return
             }
-            defer rows.Close()
-            log.Debugf("Query executed in %v", time.Since(queryStartTime))
-            rowsChunk := make([]UnknownRowValues, 0, chunkSize)
-            totalRowsExtracted := 0
-            chunkStartTime := time.Now()
-
-            for rows.Next() {
-                values := make([]any, len(columns))
-                scanArgs := make([]any, len(columns))
-
-                for i := range values {
-                    scanArgs[i] = &values[i]
-                }
-
-                if err := rows.Scan(scanArgs...); err != nil {
-                    if len(rowsChunk) == 0 {
-                        exError := ExtractorError{
-                            Batch:     batch,
-                            HasLastId: false,
-                            Msg:       err.Error(),
-                        }
-                        chErrorsOut <- exError
-                        return
-                    }
-
-                    lastRow := rowsChunk[len(rowsChunk)-1]
-                    chErrorsOut <- ExtractorErrorFromLastRowMssql(lastRow, indexPrimaryKey, &batch, err)
-                    chChunksOut <- Chunk{
-                        Id:           uuid.New(),
-                        BatchId:      batch.Id,
-                        Data:         rowsChunk,
-                        RetryCounter: 0,
-                    }
-                    return
-                }
-
-                rowsChunk = append(rowsChunk, values)
-                totalRowsExtracted++
-
-                if len(rowsChunk) >= chunkSize {
-                    chunkDuration := time.Since(chunkStartTime)
-                    rowsPerSec := float64(chunkSize) / chunkDuration.Seconds()
-                    log.Infof("Extracted chunk: %d rows in %v (%.0f rows/sec) - Total: %d rows",
-                        len(rowsChunk), chunkDuration, rowsPerSec, totalRowsExtracted)
-                    chChunksOut <- Chunk{
-                        Id:           uuid.New(),
-                        BatchId:      batch.Id,
-                        Data:         rowsChunk,
-                        RetryCounter: 0,
-                    }
-                    rowsChunk = make([]UnknownRowValues, 0, chunkSize)
-                    chunkStartTime = time.Now()
-                }
-            }
-
-            if len(rowsChunk) > 0 {
-                chunkDuration := time.Since(chunkStartTime)
-                rowsPerSec := float64(len(rowsChunk)) / chunkDuration.Seconds()
-                log.Infof("Extracted final chunk: %d rows in %v (%.0f rows/sec) - Total: %d rows",
-                    len(rowsChunk), chunkDuration, rowsPerSec, totalRowsExtracted)
-                chChunksOut <- Chunk{
-                    Id:           uuid.New(),
-                    BatchId:      batch.Id,
-                    Data:         rowsChunk,
-                    RetryCounter: 0,
-                }
-            }
-
-            if err := rows.Err(); err != nil {
-                if len(rowsChunk) == 0 {
-                    exError := ExtractorError{
-                        Batch:     batch,
-                        HasLastId: false,
-                        Msg:       err.Error(),
-                    }
-                    chErrorsOut <- exError
-                    return
-                }
-
-                lastRow := rowsChunk[len(rowsChunk)-1]
-                chErrorsOut <- ExtractorErrorFromLastRowMssql(lastRow, indexPrimaryKey, &batch, err)
+
+            if abort := processBatch(ctx, db, job, columns, chunkSize, batch, indexPrimaryKey, chChunksOut, chErrorsOut); abort {
                 return
             }
-        }()
+        }
     }
 }
 
+func processBatch(
+    ctx context.Context,
+    db *sql.DB,
+    job MigrationJob,
+    columns []ColumnType,
+    chunkSize int,
+    batch Batch,
+    indexPrimaryKey int,
+    chChunksOut chan<- Chunk,
+    chErrorsOut chan<- ExtractorError,
+) (abort bool) {
+    query := buildExtractQueryMssql(job, columns, batch.ShouldUseRange, batch.IsLowerLimitInclusive)
+    log.Debug("Query used to extract data from mssql: ", query)
+
+    var queryArgs []any
+    if batch.ShouldUseRange {
+        queryArgs = append(queryArgs,
+            sql.Named("min", batch.LowerLimit),
+            sql.Named("max", batch.UpperLimit),
+        )
+    }
+
+    queryStartTime := time.Now()
+    rows, err := db.QueryContext(ctx, query, queryArgs...)
+    if err != nil {
+        select {
+        case chErrorsOut <- ExtractorError{Batch: batch, HasLastId: false, Msg: err.Error()}:
+        case <-ctx.Done():
+            return true
+        }
+        return false
+    }
+    defer rows.Close()
+    log.Debugf("Query executed in %v", time.Since(queryStartTime))
+
+    rowsChunk := make([]UnknownRowValues, 0, chunkSize)
+    totalRowsExtracted := 0
+    chunkStartTime := time.Now()
+
+    for rows.Next() {
+        values := make([]any, len(columns))
+        scanArgs := make([]any, len(columns))
+
+        for i := range values {
+            scanArgs[i] = &values[i]
+        }
+
+        if err := rows.Scan(scanArgs...); err != nil {
+            if len(rowsChunk) == 0 {
+                select {
+                case chErrorsOut <- ExtractorError{Batch: batch, HasLastId: false, Msg: err.Error()}:
+                case <-ctx.Done():
+                    return true
+                }
+                return false
+            }
+
+            lastRow := rowsChunk[len(rowsChunk)-1]
+            select {
+            case chErrorsOut <- ExtractorErrorFromLastRowMssql(lastRow, indexPrimaryKey, &batch, err):
+            case <-ctx.Done():
+                return true
+            }
+
+            select {
+            case chChunksOut <- Chunk{Id: uuid.New(), BatchId: batch.Id, Data: rowsChunk, RetryCounter: 0}:
+            case <-ctx.Done():
+                return true
+            }
+
+            return false
+        }
+
+        rowsChunk = append(rowsChunk, values)
+        totalRowsExtracted++
+
+        if len(rowsChunk) >= chunkSize {
+            chunkDuration := time.Since(chunkStartTime)
+            rowsPerSec := float64(chunkSize) / chunkDuration.Seconds()
+            log.Infof("Extracted chunk: %d rows in %v (%.0f rows/sec) - Total: %d rows", len(rowsChunk), chunkDuration, rowsPerSec, totalRowsExtracted)
+
+            select {
+            case chChunksOut <- Chunk{Id: uuid.New(), BatchId: batch.Id, Data: rowsChunk, RetryCounter: 0}:
+            case <-ctx.Done():
+                return true
+            }
+
+            rowsChunk = make([]UnknownRowValues, 0, chunkSize)
+            chunkStartTime = time.Now()
+        }
+    }
+
+    if err := rows.Err(); err != nil {
+        if errors.Is(err, ctx.Err()) {
+            return true
+        }
+
+        if len(rowsChunk) == 0 {
+            select {
+            case chErrorsOut <- ExtractorError{Batch: batch, HasLastId: false, Msg: err.Error()}:
+            case <-ctx.Done():
+                return true
+            }
+            return false
+        }
+
+        lastRow := rowsChunk[len(rowsChunk)-1]
+        select {
+        case chErrorsOut <- ExtractorErrorFromLastRowMssql(lastRow, indexPrimaryKey, &batch, err):
+        case <-ctx.Done():
+            return true
+        }
+        return false
+    }
+
+    if len(rowsChunk) > 0 {
+        chunkDuration := time.Since(chunkStartTime)
+        rowsPerSec := float64(len(rowsChunk)) / chunkDuration.Seconds()
+        log.Infof("Extracted final chunk: %d rows in %v (%.0f rows/sec) - Total: %d rows", len(rowsChunk), chunkDuration, rowsPerSec, totalRowsExtracted)
+        select {
+        case chChunksOut <- Chunk{Id: uuid.New(), BatchId: batch.Id, Data: rowsChunk, RetryCounter: 0}:
+        case <-ctx.Done():
+            return true
+        }
+    }
+
+    return false
+}
+
 func extractFromPostgres(ctx context.Context, job MigrationJob, columns []ColumnType, chunkSize int, db *pgxpool.Pool, out chan<- []UnknownRowValues) error {
     query := buildExtractQueryPostgres(job, columns)
     log.Debug("Query used to extract data from postgres: ", query)
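
Note on the recurring pattern in this patch: every send on chChunksOut, chErrorsOut,
and chJobErrorsOut is now wrapped in a select that also waits on ctx.Done(), so a
cancelled job can never leave the extractor blocked on a full or unread channel.
Below is a minimal, self-contained sketch of that idea, not code from this
repository; the helper name trySend and the demo values are illustrative only.

    package main

    import (
        "context"
        "fmt"
        "time"
    )

    // trySend blocks until the value is delivered or the context is cancelled,
    // and reports false on cancellation so the caller can unwind.
    func trySend[T any](ctx context.Context, ch chan<- T, v T) bool {
        select {
        case ch <- v:
            return true
        case <-ctx.Done():
            return false
        }
    }

    func main() {
        ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
        defer cancel()

        // Unbuffered channel with no reader: a bare send would block forever,
        // but the context-aware send unblocks once the deadline fires.
        ch := make(chan int)
        if !trySend(ctx, ch, 42) {
            fmt.Println("send abandoned:", ctx.Err())
        }
    }

processBatch inlines the same shape at each send site and turns the ctx.Done()
arm into `return true`, which extractFromMssql treats as an abort signal.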
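
A second detail worth calling out: in the rows.Err() branch, errors.Is(err, ctx.Err())
separates "the caller cancelled us" from a genuine driver or network failure, since
database/sql surfaces the context's error through rows.Err() when iteration is
interrupted. A small sketch of that check in isolation; the classify helper is
hypothetical and only mirrors the triage the patch performs.

    package main

    import (
        "context"
        "errors"
        "fmt"
    )

    // classify assumes err is non-nil, as in the rows.Err() branch:
    // cancellation aborts quietly, anything else is a real error to report.
    func classify(ctx context.Context, err error) string {
        // ctx.Err() is context.Canceled or context.DeadlineExceeded once the
        // context is done; errors.Is also matches wrapped driver errors.
        if errors.Is(err, ctx.Err()) {
            return "aborted by caller"
        }
        return "real extraction error"
    }

    func main() {
        ctx, cancel := context.WithCancel(context.Background())
        cancel()

        fmt.Println(classify(ctx, context.Canceled))                   // aborted by caller
        fmt.Println(classify(ctx, errors.New("connection reset")))     // real extraction error
        fmt.Println(classify(ctx, fmt.Errorf("query: %w", ctx.Err()))) // aborted by caller
    }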