diff --git a/cmd/go_migrate/batch-generator.go b/cmd/go_migrate/batch-generator.go new file mode 100644 index 0000000..78311a9 --- /dev/null +++ b/cmd/go_migrate/batch-generator.go @@ -0,0 +1,110 @@ +package main + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/google/uuid" +) + +type Batch struct { + Id uuid.UUID + ParentId uuid.UUID + LowerLimit int64 + UpperLimit int64 + IsLowerLimitInclusive bool + ShouldUseRange bool + RetryCounter int +} + +func estimateTotalRowsMssql(ctx context.Context, db *sql.DB, job MigrationJob) (int64, error) { + query := ` +SELECT + SUM(p.rows) AS count +FROM sys.tables t +JOIN sys.schemas s ON t.schema_id = s.schema_id +JOIN sys.partitions p ON t.object_id = p.object_id +WHERE s.name = @schema AND t.name = @table AND p.index_id IN (0, 1) +GROUP BY t.name` + + ctxTimeout, cancel := context.WithTimeout(ctx, time.Second*20) + defer cancel() + + var rowsCount int64 + err := db.QueryRowContext(ctxTimeout, query, sql.Named("schema", job.Schema), sql.Named("table", job.Table)).Scan(&rowsCount) + if err != nil { + return 0, err + } + + return rowsCount, nil +} + +func calculateBatchesMssql(ctx context.Context, db *sql.DB, job MigrationJob, batchCount int64) ([]Batch, error) { + query := fmt.Sprintf(` +SELECT + MIN([%s]) AS lower_limit, + MAX([%s]) AS upper_limit +FROM + (SELECT [%s], NTILE(@batchCount) OVER (ORDER BY [%s]) AS batch_id FROM [%s].[%s]) AS T +GROUP BY batch_id +ORDER BY batch_id`, job.PrimaryKey, job.PrimaryKey, job.PrimaryKey, job.PrimaryKey, job.Schema, job.Table) + + ctxTimeout, cancel := context.WithTimeout(ctx, time.Second*20) + defer cancel() + + rows, err := db.QueryContext(ctxTimeout, query, sql.Named("batchCount", batchCount)) + if err != nil { + return nil, err + } + defer rows.Close() + + batches := make([]Batch, 0, batchCount) + + for rows.Next() { + batch := Batch{ + Id: uuid.New(), + ShouldUseRange: true, + RetryCounter: 0, + IsLowerLimitInclusive: true, + } + + if err := 
rows.Scan(&batch.LowerLimit, &batch.UpperLimit); err != nil { + return nil, err + } + + batches = append(batches, batch) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return batches, nil +} + +func batchGeneratorMssql(ctx context.Context, db *sql.DB, job MigrationJob) ([]Batch, error) { + rowsCount, err := estimateTotalRowsMssql(ctx, db, job) + if err != nil { + return nil, err + } + + var batchCount int64 = 1 + if rowsCount > RowsPerBatch { + batchCount = rowsCount / RowsPerBatch + } else { + return []Batch{{ + Id: uuid.New(), + ShouldUseRange: false, + RetryCounter: 0, + }}, nil + } + + batches, err := calculateBatchesMssql(ctx, db, job, batchCount) + if err != nil { + return nil, err + } + + return batches, nil +} diff --git a/cmd/go_migrate/build-extract-query.go b/cmd/go_migrate/build-extract-query.go index ff45665..a43f09e 100644 --- a/cmd/go_migrate/build-extract-query.go +++ b/cmd/go_migrate/build-extract-query.go @@ -5,7 +5,7 @@ import ( "strings" ) -func buildExtractQueryMssql(job MigrationJob, columns []ColumnType, includeRange bool) string { +func buildExtractQueryMssql(job MigrationJob, columns []ColumnType, includeRange bool, isMinInclusive bool) string { var sbQuery strings.Builder sbQuery.WriteString("SELECT ") @@ -29,7 +29,14 @@ func buildExtractQueryMssql(job MigrationJob, columns []ColumnType, includeRange fmt.Fprintf(&sbQuery, " FROM [%s].[%s]", job.Schema, job.Table) if includeRange { - fmt.Fprintf(&sbQuery, " WHERE [%s] BETWEEN @minRange AND @maxRange", job.PrimaryKey) + fmt.Fprintf(&sbQuery, " WHERE [%s]", job.PrimaryKey) + if isMinInclusive { + sbQuery.WriteString(" >=") + } else { + sbQuery.WriteString(" >") + } + + fmt.Fprintf(&sbQuery, " @min AND [%s] <= @max", job.PrimaryKey) } fmt.Fprintf(&sbQuery, " ORDER BY [%s] ASC", job.PrimaryKey) diff --git a/cmd/go_migrate/chunk-planner.go b/cmd/go_migrate/chunk-planner.go deleted file mode 100644 index 673092b..0000000 --- a/cmd/go_migrate/chunk-planner.go +++ /dev/null @@ 
-1,91 +0,0 @@ -package main - -import ( - "context" - "database/sql" - "fmt" -) - -type BatchRange struct { - LowerLimit int - UpperLimit int - validRange bool -} - -func estimateTotalRowsMssql(ctx context.Context, db *sql.DB, job MigrationJob) (int, error) { - query := ` -SELECT - SUM(p.rows) AS count -FROM sys.tables t -JOIN sys.schemas s ON t.schema_id = s.schema_id -JOIN sys.partitions p ON t.object_id = p.object_id -WHERE s.name = @schema AND t.name = @table AND p.index_id IN (0, 1) -GROUP BY t.name` - - var rowsCount int - err := db.QueryRowContext(ctx, query, sql.Named("schema", job.Schema), sql.Named("table", job.Table)).Scan(&rowsCount) - if err != nil { - return 0, err - } - - return rowsCount, nil -} - -func calculateChunkRangesMssql(ctx context.Context, db *sql.DB, job MigrationJob, batchCount int) ([]BatchRange, error) { - query := fmt.Sprintf(` -SELECT - MIN([%s]) AS lower_limit, - MAX([%s]) AS upper_limit -FROM - (SELECT [%s], NTILE(@batchCount) OVER (ORDER BY [%s]) AS chunk_id FROM [%s].[%s]) AS T -GROUP BY chunk_id -ORDER BY chunk_id`, job.PrimaryKey, job.PrimaryKey, job.PrimaryKey, job.PrimaryKey, job.Schema, job.Table) - - rows, err := db.QueryContext(ctx, query, sql.Named("batchCount", batchCount)) - if err != nil { - return nil, err - } - defer rows.Close() - - batchRanges := make([]BatchRange, 0, batchCount) - - for rows.Next() { - var br BatchRange - br.validRange = true - - if err := rows.Scan(&br.LowerLimit, &br.UpperLimit); err != nil { - return nil, err - } - - batchRanges = append(batchRanges, br) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return batchRanges, nil -} - -const estimatedRowsPerBatch = 100_000 - -func calculateBatchMetrics(ctx context.Context, db *sql.DB, job MigrationJob) ([]BatchRange, error) { - rowsCount, err := estimateTotalRowsMssql(ctx, db, job) - if err != nil { - return nil, err - } - - batchCount := 1 - if rowsCount > estimatedRowsPerBatch { - batchCount = rowsCount / estimatedRowsPerBatch 
- } else { - return []BatchRange{{validRange: false}}, nil - } - - chunksRange, err := calculateChunkRangesMssql(ctx, db, job, batchCount) - if err != nil { - return nil, err - } - - return chunksRange, nil -} diff --git a/cmd/go_migrate/extractor-error-handler.go b/cmd/go_migrate/extractor-error-handler.go new file mode 100644 index 0000000..5577cab --- /dev/null +++ b/cmd/go_migrate/extractor-error-handler.go @@ -0,0 +1,66 @@ +package main + +import ( + "fmt" + + "github.com/google/uuid" +) + +type ExtractorError struct { + Batch + LastId int64 + HasLastId bool + Msg string +} + +func (e *ExtractorError) Error() string { + return e.Msg +} + +const maxRetryAttempts = 3 + +func extractorErrorHandler(chErrorsIn <-chan ExtractorError, chBatchesOut chan<- Batch, chGlobalErrorsOut chan<- error) { + for err := range chErrorsIn { + if err.RetryCounter >= maxRetryAttempts { + chGlobalErrorsOut <- fmt.Errorf("batch %v reached max retries (%d): %s", err.Id, maxRetryAttempts, err.Msg) + continue + } + + newBatch := err.Batch + newBatch.RetryCounter++ + + if err.HasLastId { + newBatch.ParentId = err.Id + newBatch.Id = uuid.New() + newBatch.LowerLimit = err.LastId + newBatch.IsLowerLimitInclusive = false + } + + chBatchesOut <- newBatch + } +} + +func ExtractorErrorFromLastRowMssql(lastRow UnknownRowValues, indexPrimaryKey int, batch *Batch, previousError error) ExtractorError { + lastIdRawValue := lastRow[indexPrimaryKey] + + lastId, ok := ToInt64(lastIdRawValue) + if !ok { + currentBatch := *batch + currentBatch.RetryCounter = maxRetryAttempts + exError := ExtractorError{ + Batch: currentBatch, + HasLastId: false, + Msg: fmt.Sprintf("Couldn't cast last id value as int: %s", previousError.Error()), + } + return exError + + } + + exError := ExtractorError{ + Batch: *batch, + HasLastId: true, + LastId: lastId, + Msg: previousError.Error(), + } + return exError +} diff --git a/cmd/go_migrate/extractor.go b/cmd/go_migrate/extractor.go index aef13e9..f971db4 100644 ---
a/cmd/go_migrate/extractor.go +++ b/cmd/go_migrate/extractor.go @@ -3,6 +3,8 @@ package main import ( "context" "database/sql" + "slices" + "strings" "time" "github.com/jackc/pgx/v5/pgxpool" @@ -12,67 +14,126 @@ import ( type UnknownRowValues = []any -func extractFromMssql(ctx context.Context, db *sql.DB, job MigrationJob, columns []ColumnType, chunkSize int, batchRange BatchRange, out chan<- []UnknownRowValues) error { - query := buildExtractQueryMssql(job, columns, batchRange.validRange) - log.Debug("Query used to extract data from mssql: ", query) +func extractFromMssql( + ctx context.Context, + db *sql.DB, + job MigrationJob, + columns []ColumnType, + chunkSize int, + chBatchesIn <-chan Batch, + chChunksOut chan<- []UnknownRowValues, + chErrorsOut chan<- ExtractorError, +) { + indexPrimaryKey := slices.IndexFunc(columns, func(col ColumnType) bool { + return strings.EqualFold(col.name, job.PrimaryKey) + }) - var queryArgs []any - if batchRange.validRange { - queryArgs = append(queryArgs, - sql.Named("minRange", batchRange.LowerLimit), - sql.Named("maxRange", batchRange.UpperLimit), - ) - } - - queryStartTime := time.Now() - rows, err := db.QueryContext(ctx, query, queryArgs...) 
- if err != nil { - return err - } - defer rows.Close() - log.Debugf("Query executed in %v", time.Since(queryStartTime)) - - rowsChunk := make([]UnknownRowValues, 0, chunkSize) - totalRowsExtracted := 0 - chunkCount := 0 - chunkStartTime := time.Now() - - for rows.Next() { - values := make([]any, len(columns)) - scanArgs := make([]any, len(columns)) - - for i := range values { - scanArgs[i] = &values[i] - } - - if err := rows.Scan(scanArgs...); err != nil { - return err - } - - rowsChunk = append(rowsChunk, values) - totalRowsExtracted++ - - if len(rowsChunk) >= chunkSize { - chunkCount++ - chunkDuration := time.Since(chunkStartTime) - rowsPerSec := float64(chunkSize) / chunkDuration.Seconds() - log.Infof("Extracted chunk #%d: %d rows in %v (%.0f rows/sec) - Total: %d rows", chunkCount, len(rowsChunk), chunkDuration, rowsPerSec, totalRowsExtracted) - out <- rowsChunk - rowsChunk = make([]UnknownRowValues, 0, chunkSize) - chunkStartTime = time.Now() + if indexPrimaryKey == -1 { + exError := ExtractorError{ + Batch: Batch{ + RetryCounter: maxRetryAttempts, + }, + HasLastId: false, + Msg: "Primary key not found in columns provided", } + chErrorsOut <- exError + return } - if len(rowsChunk) > 0 { - chunkCount++ - chunkDuration := time.Since(chunkStartTime) - rowsPerSec := float64(len(rowsChunk)) / chunkDuration.Seconds() - log.Infof("Extracted final chunk #%d: %d rows in %v (%.0f rows/sec) - Total: %d rows", - chunkCount, len(rowsChunk), chunkDuration, rowsPerSec, totalRowsExtracted) - out <- rowsChunk - } + for batch := range chBatchesIn { + func() { + query := buildExtractQueryMssql(job, columns, batch.ShouldUseRange, batch.IsLowerLimitInclusive) + log.Debug("Query used to extract data from mssql: ", query) - return rows.Err() + var queryArgs []any + if batch.ShouldUseRange { + queryArgs = append(queryArgs, + sql.Named("min", batch.LowerLimit), + sql.Named("max", batch.UpperLimit), + ) + } + + queryStartTime := time.Now() + rows, err := db.QueryContext(ctx, query, 
queryArgs...) + if err != nil { + exError := ExtractorError{ + Batch: batch, + HasLastId: false, + Msg: err.Error(), + } + chErrorsOut <- exError + return + } + defer rows.Close() + log.Debugf("Query executed in %v", time.Since(queryStartTime)) + + rowsChunk := make([]UnknownRowValues, 0, chunkSize) + totalRowsExtracted := 0 + chunkStartTime := time.Now() + + for rows.Next() { + values := make([]any, len(columns)) + scanArgs := make([]any, len(columns)) + + for i := range values { + scanArgs[i] = &values[i] + } + + if err := rows.Scan(scanArgs...); err != nil { + if len(rowsChunk) == 0 { + exError := ExtractorError{ + Batch: batch, + HasLastId: false, + Msg: err.Error(), + } + chErrorsOut <- exError + return + } + + lastRow := rowsChunk[len(rowsChunk)-1] + chErrorsOut <- ExtractorErrorFromLastRowMssql(lastRow, indexPrimaryKey, &batch, err) + return + } + + rowsChunk = append(rowsChunk, values) + totalRowsExtracted++ + + if len(rowsChunk) >= chunkSize { + chunkDuration := time.Since(chunkStartTime) + rowsPerSec := float64(chunkSize) / chunkDuration.Seconds() + log.Infof("Extracted chunk: %d rows in %v (%.0f rows/sec) - Total: %d rows", + len(rowsChunk), chunkDuration, rowsPerSec, totalRowsExtracted) + chChunksOut <- rowsChunk + rowsChunk = make([]UnknownRowValues, 0, chunkSize) + chunkStartTime = time.Now() + } + } + + if len(rowsChunk) > 0 { + chunkDuration := time.Since(chunkStartTime) + rowsPerSec := float64(len(rowsChunk)) / chunkDuration.Seconds() + log.Infof("Extracted final chunk: %d rows in %v (%.0f rows/sec) - Total: %d rows", + len(rowsChunk), chunkDuration, rowsPerSec, totalRowsExtracted) + chChunksOut <- rowsChunk + } + + if err := rows.Err(); err != nil { + if len(rowsChunk) == 0 { + exError := ExtractorError{ + Batch: batch, + HasLastId: false, + Msg: err.Error(), + } + chErrorsOut <- exError + return + } + + lastRow := rowsChunk[len(rowsChunk)-1] + chErrorsOut <- ExtractorErrorFromLastRowMssql(lastRow, indexPrimaryKey, &batch, err) + return + } + }() 
+ } } func extractFromPostgres(ctx context.Context, job MigrationJob, columns []ColumnType, chunkSize int, db *pgxpool.Pool, out chan<- []UnknownRowValues) error { diff --git a/cmd/go_migrate/main.go b/cmd/go_migrate/main.go index 0ab418f..a33b1dd 100644 --- a/cmd/go_migrate/main.go +++ b/cmd/go_migrate/main.go @@ -26,10 +26,12 @@ var migrationJobs []MigrationJob = []MigrationJob{ } const ( - NumExtractors int = 4 - NumLoaders int = 8 - ChunkSize int = 25000 - QueueSize int = 8 + NumExtractors int = 4 + NumLoaders int = 8 + ChunkSize int = 25000 + QueueSize int = 8 + ChunksPerBatch int = 16 + RowsPerBatch int64 = int64(ChunkSize * ChunksPerBatch) ) func main() { diff --git a/cmd/go_migrate/process.go b/cmd/go_migrate/process.go index 7ec0df1..99e96a0 100644 --- a/cmd/go_migrate/process.go +++ b/cmd/go_migrate/process.go @@ -25,39 +25,43 @@ func processMigrationJob(sourceDb *sql.DB, targetDb *pgxpool.Pool, job Migration logColumnTypes(targetColTypes, "Target col types") mssqlCtx := context.Background() - batchRanges, err := calculateBatchMetrics(mssqlCtx, sourceDb, job) + batches, err := batchGeneratorMssql(mssqlCtx, sourceDb, job) if err != nil { log.Error("Unexpected error calculating batch ranges: ", err) } - chBatchRanges := make(chan BatchRange, len(batchRanges)) + chGlobalErrors := make(chan error) + defer close(chGlobalErrors) - maxExtractors := min(NumExtractors, len(batchRanges)) - chRowsExtract := make(chan []UnknownRowValues, QueueSize) + chBatches := make(chan Batch, len(batches)) + chChunks := make(chan []UnknownRowValues, QueueSize) + chExtractorErrors := make(chan ExtractorError, len(batches)) + maxExtractors := min(NumExtractors, len(batches)) var wgMssqlExtractors sync.WaitGroup log.Infof("Starting %d MSSQL extractors...", maxExtractors) extractStartTime := time.Now() for range maxExtractors { wgMssqlExtractors.Go(func() { - for br := range chBatchRanges { - if err := extractFromMssql(mssqlCtx, sourceDb, job, sourceColTypes, ChunkSize, br, 
chRowsExtract); err != nil { - log.Error("Unexpected error extracting data from mssql: ", err) - } - } + extractFromMssql(mssqlCtx, sourceDb, job, sourceColTypes, ChunkSize, chBatches, chChunks, chExtractorErrors) }) } go func() { - for _, br := range batchRanges { - chBatchRanges <- br + for _, br := range batches { + chBatches <- br } - close(chBatchRanges) + close(chBatches) + }() + + go func() { + extractorErrorHandler(chExtractorErrors, chBatches, chGlobalErrors) }() go func() { wgMssqlExtractors.Wait() - close(chRowsExtract) + close(chChunks) + close(chExtractorErrors) log.Infof("Extraction completed in %v", time.Since(extractStartTime)) }() @@ -68,7 +72,7 @@ func processMigrationJob(sourceDb *sql.DB, targetDb *pgxpool.Pool, job Migration transformStartTime := time.Now() for range maxExtractors { wgMssqlTransformers.Go(func() { - transformRowsMssql(sourceColTypes, chRowsExtract, chRowsTransform) + transformRowsMssql(sourceColTypes, chChunks, chRowsTransform) }) } diff --git a/cmd/go_migrate/transformer.go b/cmd/go_migrate/transformer.go index fd68e02..ceb56d0 100644 --- a/cmd/go_migrate/transformer.go +++ b/cmd/go_migrate/transformer.go @@ -43,3 +43,20 @@ func transformRowsMssql(columns []ColumnType, in <-chan []UnknownRowValues, out out <- rows } } + +func ToInt64(v any) (int64, bool) { + switch t := v.(type) { + case int: + return int64(t), true + case int8: + return int64(t), true + case int16: + return int64(t), true + case int32: + return int64(t), true + case int64: + return int64(t), true + default: + return 0, false + } +}