From 8afdb45318e5bad3e16ddad99cca1e489fc6ca56 Mon Sep 17 00:00:00 2001
From: Kylesoda <249518290+kylesoda@users.noreply.github.com>
Date: Tue, 7 Apr 2026 23:39:55 -0500
Subject: [PATCH] feat: implement batch processing for MSSQL extraction and
 transformation with range handling

---
Review notes (text in this area is not part of the commit message or diff):
- Line breaks and indentation were rebuilt from a flattened copy of this
  patch. Run `git apply --check` before applying; whitespace-only mismatches
  in context/removed lines are possible (see main.go, where removed and added
  lines now look identical).
- chunk-planner.go: errors are wrapped with context, a batch count below 1 is
  rejected before it reaches NTILE, and the else-after-return in
  calculateBatchMetrics is replaced by an early return.
- process.go: when range planning fails or yields no ranges, fall back to one
  unbounded range. Previously zero ranges meant zero extractors, so the job
  finished without copying any rows.
- The index lines still carry the blob ids of the original patch.

 cmd/go_migrate/build-extract-query.go | 28 ++++----
 cmd/go_migrate/chunk-planner.go       | 95 +++++++++++++++++++++++++++
 cmd/go_migrate/extractor.go           | 14 +++-
 cmd/go_migrate/main.go                |  8 +--
 cmd/go_migrate/process.go             | 55 +++++++++++++---
 5 files changed, 173 insertions(+), 27 deletions(-)
 create mode 100644 cmd/go_migrate/chunk-planner.go

diff --git a/cmd/go_migrate/build-extract-query.go b/cmd/go_migrate/build-extract-query.go
index b2a8c4c..ff45665 100644
--- a/cmd/go_migrate/build-extract-query.go
+++ b/cmd/go_migrate/build-extract-query.go
@@ -5,30 +5,36 @@ import (
 	"strings"
 )
 
-func buildExtractQueryMssql(job MigrationJob, columns []ColumnType) string {
-	var sbColumns strings.Builder
+func buildExtractQueryMssql(job MigrationJob, columns []ColumnType, includeRange bool) string {
+	var sbQuery strings.Builder
+
+	sbQuery.WriteString("SELECT ")
 
 	if len(columns) == 0 {
-		sbColumns.WriteString("*")
+		sbQuery.WriteString("*")
 	} else {
 		for i, col := range columns {
-			sbColumns.WriteString("[")
-			sbColumns.WriteString(col.name)
-			sbColumns.WriteString("]")
+			fmt.Fprintf(&sbQuery, "[%s]", col.name)
 
 			if col.unifiedType == "GEOMETRY" {
-				sbColumns.WriteString(".STAsBinary() AS [")
-				sbColumns.WriteString(col.name)
-				sbColumns.WriteString("]")
+				fmt.Fprintf(&sbQuery, ".STAsBinary() AS [%s]", col.name)
 			}
 
 			if i < len(columns)-1 {
-				sbColumns.WriteString(", ")
+				sbQuery.WriteString(", ")
 			}
 		}
 	}
 
-	return fmt.Sprintf(`SELECT %s FROM [%s].[%s] ORDER BY [%s] ASC`, sbColumns.String(), job.Schema, job.Table, job.PrimaryKey)
+	fmt.Fprintf(&sbQuery, " FROM [%s].[%s]", job.Schema, job.Table)
+
+	if includeRange {
+		fmt.Fprintf(&sbQuery, " WHERE [%s] BETWEEN @minRange AND @maxRange", job.PrimaryKey)
+	}
+
+	fmt.Fprintf(&sbQuery, " ORDER BY [%s] ASC", job.PrimaryKey)
+
+	return sbQuery.String()
 }
 
 func buildExtractQueryPostgres(job MigrationJob, columns []ColumnType) string {
diff --git a/cmd/go_migrate/chunk-planner.go b/cmd/go_migrate/chunk-planner.go
new file mode 100644
index 0000000..673092b
--- /dev/null
+++ b/cmd/go_migrate/chunk-planner.go
@@ -0,0 +1,95 @@
+package main
+
+import (
+	"context"
+	"database/sql"
+	"fmt"
+)
+
+// BatchRange is an inclusive primary-key interval read by one extraction
+// query. When validRange is false the limits are ignored and the table is
+// read without a range filter.
+type BatchRange struct {
+	LowerLimit int
+	UpperLimit int
+	validRange bool
+}
+
+// estimateTotalRowsMssql returns the row count recorded in sys.partitions for
+// the job's table (heap or clustered index only).
+func estimateTotalRowsMssql(ctx context.Context, db *sql.DB, job MigrationJob) (int, error) {
+	query := `
+SELECT
+	SUM(p.rows) AS count
+FROM sys.tables t
+JOIN sys.schemas s ON t.schema_id = s.schema_id
+JOIN sys.partitions p ON t.object_id = p.object_id
+WHERE s.name = @schema AND t.name = @table AND p.index_id IN (0, 1)
+GROUP BY t.name`
+
+	var rowsCount int
+	err := db.QueryRowContext(ctx, query, sql.Named("schema", job.Schema), sql.Named("table", job.Table)).Scan(&rowsCount)
+	if err != nil {
+		return 0, fmt.Errorf("estimating rows of [%s].[%s]: %w", job.Schema, job.Table, err)
+	}
+
+	return rowsCount, nil
+}
+
+// calculateChunkRangesMssql splits the table into batchCount NTILE groups
+// ordered by primary key and returns the MIN/MAX key of each group.
+func calculateChunkRangesMssql(ctx context.Context, db *sql.DB, job MigrationJob, batchCount int) ([]BatchRange, error) {
+	if batchCount < 1 {
+		return nil, fmt.Errorf("invalid batch count %d", batchCount)
+	}
+
+	query := fmt.Sprintf(`
+SELECT
+	MIN([%s]) AS lower_limit,
+	MAX([%s]) AS upper_limit
+FROM
+	(SELECT [%s], NTILE(@batchCount) OVER (ORDER BY [%s]) AS chunk_id FROM [%s].[%s]) AS T
+GROUP BY chunk_id
+ORDER BY chunk_id`, job.PrimaryKey, job.PrimaryKey, job.PrimaryKey, job.PrimaryKey, job.Schema, job.Table)
+
+	rows, err := db.QueryContext(ctx, query, sql.Named("batchCount", batchCount))
+	if err != nil {
+		return nil, fmt.Errorf("planning chunk ranges: %w", err)
+	}
+	defer rows.Close()
+
+	batchRanges := make([]BatchRange, 0, batchCount)
+
+	for rows.Next() {
+		br := BatchRange{validRange: true}
+		if err := rows.Scan(&br.LowerLimit, &br.UpperLimit); err != nil {
+			return nil, fmt.Errorf("scanning chunk range: %w", err)
+		}
+
+		batchRanges = append(batchRanges, br)
+	}
+
+	if err := rows.Err(); err != nil {
+		return nil, fmt.Errorf("reading chunk ranges: %w", err)
+	}
+
+	return batchRanges, nil
+}
+
+// estimatedRowsPerBatch is the target number of rows per extraction range.
+const estimatedRowsPerBatch = 100_000
+
+// calculateBatchMetrics plans the extraction ranges for a job. Tables at or
+// below estimatedRowsPerBatch rows get one unbounded range.
+func calculateBatchMetrics(ctx context.Context, db *sql.DB, job MigrationJob) ([]BatchRange, error) {
+	rowsCount, err := estimateTotalRowsMssql(ctx, db, job)
+	if err != nil {
+		return nil, err
+	}
+
+	if rowsCount <= estimatedRowsPerBatch {
+		return []BatchRange{{validRange: false}}, nil
+	}
+
+	return calculateChunkRangesMssql(ctx, db, job, rowsCount/estimatedRowsPerBatch)
+}
diff --git a/cmd/go_migrate/extractor.go b/cmd/go_migrate/extractor.go
index 68095d2..aef13e9 100644
--- a/cmd/go_migrate/extractor.go
+++ b/cmd/go_migrate/extractor.go
@@ -12,12 +12,20 @@ import (
 
 type UnknownRowValues = []any
 
-func extractFromMssql(ctx context.Context, job MigrationJob, columns []ColumnType, chunkSize int, db *sql.DB, out chan<- []UnknownRowValues) error {
-	query := buildExtractQueryMssql(job, columns)
+func extractFromMssql(ctx context.Context, db *sql.DB, job MigrationJob, columns []ColumnType, chunkSize int, batchRange BatchRange, out chan<- []UnknownRowValues) error {
+	query := buildExtractQueryMssql(job, columns, batchRange.validRange)
 	log.Debug("Query used to extract data from mssql: ", query)
 
+	var queryArgs []any
+	if batchRange.validRange {
+		queryArgs = append(queryArgs,
+			sql.Named("minRange", batchRange.LowerLimit),
+			sql.Named("maxRange", batchRange.UpperLimit),
+		)
+	}
+
 	queryStartTime := time.Now()
-	rows, err := db.QueryContext(ctx, query)
+	rows, err := db.QueryContext(ctx, query, queryArgs...)
 	if err != nil {
 		return err
 	}
diff --git a/cmd/go_migrate/main.go b/cmd/go_migrate/main.go
index 48d3c32..757888b 100644
--- a/cmd/go_migrate/main.go
+++ b/cmd/go_migrate/main.go
@@ -21,10 +21,10 @@ var migrationJobs []MigrationJob = []MigrationJob{
 }
 
 const (
-	NumExtractors int = 1
-	NumLoaders    int = 4
-	ChunkSize     int = 50000
-	QueueSize     int = 10
+	NumExtractors int = 4
+	NumLoaders    int = 4
+	ChunkSize     int = 50000
+	QueueSize     int = 10
 )
 
 func main() {
diff --git a/cmd/go_migrate/process.go b/cmd/go_migrate/process.go
index e6af049..30bd8f0 100644
--- a/cmd/go_migrate/process.go
+++ b/cmd/go_migrate/process.go
@@ -24,24 +24,61 @@ func processMigrationJob(sourceDb *sql.DB, targetDb *pgxpool.Pool, job Migration
 	logColumnTypes(sourceColTypes, "Source col types")
 	logColumnTypes(targetColTypes, "Target col types")
 
-	chRowsExtract := make(chan []UnknownRowValues, QueueSize)
-	chRowsTransform := make(chan []UnknownRowValues)
 	mssqlCtx := context.Background()
+	batchRanges, err := calculateBatchMetrics(mssqlCtx, sourceDb, job)
+	if err != nil {
+		log.Error("Unexpected error calculating batch ranges: ", err)
+	}
+	if len(batchRanges) == 0 {
+		// With no ranges no extractor would start and nothing would be
+		// copied, so fall back to a single unbounded extraction.
+		batchRanges = []BatchRange{{validRange: false}}
+	}
+
+	chBatchRanges := make(chan BatchRange, len(batchRanges))
+
+	maxExtractors := min(NumExtractors, len(batchRanges))
+	chRowsExtract := make(chan []UnknownRowValues, QueueSize)
+	var wgMssqlExtractors sync.WaitGroup
+
+	log.Infof("Starting %d MSSQL extractors...", maxExtractors)
+	extractStartTime := time.Now()
+	for range maxExtractors {
+		wgMssqlExtractors.Go(func() {
+			for br := range chBatchRanges {
+				if err := extractFromMssql(mssqlCtx, sourceDb, job, sourceColTypes, ChunkSize, br, chRowsExtract); err != nil {
+					log.Error("Unexpected error extracting data from mssql: ", err)
+				}
+			}
+		})
+	}
 
 	go func() {
-		log.Info("Starting extraction from MSSQL...")
-		extractStartTime := time.Now()
-		if err := extractFromMssql(mssqlCtx, job, sourceColTypes, ChunkSize, sourceDb, chRowsExtract); err != nil {
-			log.Error("Unexpected error extracting data from mssql: ", err)
+		for _, br := range batchRanges {
+			chBatchRanges <- br
 		}
+		close(chBatchRanges)
+	}()
+
+	go func() {
+		wgMssqlExtractors.Wait()
 		close(chRowsExtract)
 		log.Infof("Extraction completed in %v", time.Since(extractStartTime))
 	}()
 
+	chRowsTransform := make(chan []UnknownRowValues, QueueSize)
+	var wgMssqlTransformers sync.WaitGroup
+
+	log.Infof("Starting %d MSSQL transformers...", maxExtractors)
+	transformStartTime := time.Now()
+	for range maxExtractors {
+		wgMssqlTransformers.Go(func() {
+			transformRowsMssql(sourceColTypes, chRowsExtract, chRowsTransform)
+		})
+	}
+
 	go func() {
-		log.Info("Starting transformation of rows...")
-		transformStartTime := time.Now()
-		transformRowsMssql(sourceColTypes, chRowsExtract, chRowsTransform)
+		wgMssqlTransformers.Wait()
 		close(chRowsTransform)
 		log.Infof("Transformation completed in %v", time.Since(transformStartTime))
 	}()