feat: implement batch processing for MSSQL with improved structure and logging

This commit is contained in:
2026-04-08 17:35:12 -05:00
parent 51d83661a4
commit d32d4df6e4
8 changed files with 343 additions and 167 deletions

View File

@@ -0,0 +1,110 @@
package main
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"time"

	"github.com/google/uuid"
)
// Batch describes one unit of extraction work for the MSSQL extractor.
// When ShouldUseRange is true the extraction is bounded by
// [LowerLimit, UpperLimit] on the job's primary key; otherwise the whole
// table is read in a single pass.
type Batch struct {
// Id uniquely identifies this batch; ParentId links a retry batch back to
// the batch it was derived from (zero UUID for original, non-retry work).
Id uuid.UUID
ParentId uuid.UUID
// Primary-key bounds, meaningful only when ShouldUseRange is true.
LowerLimit int64
UpperLimit int64
// IsLowerLimitInclusive selects >= (true) or > (false) for the lower bound,
// so a retry can resume strictly after the last row already extracted.
IsLowerLimitInclusive bool
ShouldUseRange bool
// RetryCounter counts how many times this work has been re-queued.
RetryCounter int
}
// estimateTotalRowsMssql returns an approximate row count for the job's
// table, read from the sys.partitions catalog (index_id 0 = heap,
// 1 = clustered index) — far cheaper than COUNT(*), though the catalog
// statistics may slightly lag the true count.
//
// An empty table (no matching catalog rows after GROUP BY) is reported as
// zero rows rather than as an error, so the caller can fall back to a
// single full-table batch.
func estimateTotalRowsMssql(ctx context.Context, db *sql.DB, job MigrationJob) (int64, error) {
	query := `
SELECT
SUM(p.rows) AS count
FROM sys.tables t
JOIN sys.schemas s ON t.schema_id = s.schema_id
JOIN sys.partitions p ON t.object_id = p.object_id
WHERE s.name = @schema AND t.name = @table AND p.index_id IN (0, 1)
GROUP BY t.name`
	// Catalog lookups should be near-instant; bound them so a wedged
	// connection cannot stall the whole migration job.
	ctxTimeout, cancel := context.WithTimeout(ctx, time.Second*20)
	defer cancel()
	var rowsCount int64
	err := db.QueryRowContext(ctxTimeout, query, sql.Named("schema", job.Schema), sql.Named("table", job.Table)).Scan(&rowsCount)
	if errors.Is(err, sql.ErrNoRows) {
		// Table exists with zero rows (or is unknown): treat as empty
		// instead of aborting the job here.
		return 0, nil
	}
	if err != nil {
		return 0, fmt.Errorf("estimating rows for [%s].[%s]: %w", job.Schema, job.Table, err)
	}
	return rowsCount, nil
}
// calculateBatchesMssql splits the job's table into batchCount contiguous
// primary-key ranges using NTILE, so each range holds roughly the same
// number of rows. Every returned Batch is range-bound with an inclusive
// lower limit.
//
// NOTE(security): schema/table/primary-key names are interpolated into the
// SQL text because identifiers cannot be bound as parameters; they must
// come from trusted job configuration, never from user input.
func calculateBatchesMssql(ctx context.Context, db *sql.DB, job MigrationJob, batchCount int64) ([]Batch, error) {
	query := fmt.Sprintf(`
SELECT
MIN([%s]) AS lower_limit,
MAX([%s]) AS upper_limit
FROM
(SELECT [%s], NTILE(@batchCount) OVER (ORDER BY [%s]) AS batch_id FROM [%s].[%s]) AS T
GROUP BY batch_id
ORDER BY batch_id`, job.PrimaryKey, job.PrimaryKey, job.PrimaryKey, job.PrimaryKey, job.Schema, job.Table)
	ctxTimeout, cancel := context.WithTimeout(ctx, time.Second*20)
	defer cancel()
	rows, err := db.QueryContext(ctxTimeout, query, sql.Named("batchCount", batchCount))
	if err != nil {
		return nil, fmt.Errorf("querying batch ranges for [%s].[%s]: %w", job.Schema, job.Table, err)
	}
	defer rows.Close()
	batches := make([]Batch, 0, batchCount)
	for rows.Next() {
		batch := Batch{
			Id:                    uuid.New(),
			ShouldUseRange:        true,
			RetryCounter:          0,
			IsLowerLimitInclusive: true,
		}
		if err := rows.Scan(&batch.LowerLimit, &batch.UpperLimit); err != nil {
			return nil, fmt.Errorf("scanning batch range: %w", err)
		}
		batches = append(batches, batch)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterating batch ranges: %w", err)
	}
	return batches, nil
}
// batchGeneratorMssql plans the extraction batches for a migration job.
// Small tables (at most RowsPerBatch rows, including empty ones) get one
// unranged full-table batch; larger tables are split into key ranges of
// roughly RowsPerBatch rows each.
func batchGeneratorMssql(ctx context.Context, db *sql.DB, job MigrationJob) ([]Batch, error) {
	rowsCount, err := estimateTotalRowsMssql(ctx, db, job)
	if err != nil {
		return nil, err
	}
	if rowsCount <= RowsPerBatch {
		// Single pass over the whole table; no key range needed.
		return []Batch{{
			Id:             uuid.New(),
			ShouldUseRange: false,
			RetryCounter:   0,
		}}, nil
	}
	// Ceiling division: the previous truncating division could leave each
	// batch holding up to almost twice RowsPerBatch rows (e.g. RowsPerBatch+
	// (RowsPerBatch-1) rows in a single batch).
	batchCount := (rowsCount + RowsPerBatch - 1) / RowsPerBatch
	return calculateBatchesMssql(ctx, db, job, batchCount)
}

View File

@@ -5,7 +5,7 @@ import (
"strings" "strings"
) )
func buildExtractQueryMssql(job MigrationJob, columns []ColumnType, includeRange bool) string { func buildExtractQueryMssql(job MigrationJob, columns []ColumnType, includeRange bool, isMinInclusive bool) string {
var sbQuery strings.Builder var sbQuery strings.Builder
sbQuery.WriteString("SELECT ") sbQuery.WriteString("SELECT ")
@@ -29,7 +29,14 @@ func buildExtractQueryMssql(job MigrationJob, columns []ColumnType, includeRange
fmt.Fprintf(&sbQuery, " FROM [%s].[%s]", job.Schema, job.Table) fmt.Fprintf(&sbQuery, " FROM [%s].[%s]", job.Schema, job.Table)
if includeRange { if includeRange {
fmt.Fprintf(&sbQuery, " WHERE [%s] BETWEEN @minRange AND @maxRange", job.PrimaryKey) fmt.Fprintf(&sbQuery, " WHERE [%s]", job.PrimaryKey)
if isMinInclusive {
sbQuery.WriteString(" >=")
} else {
sbQuery.WriteString(" >")
}
fmt.Fprintf(&sbQuery, " @min AND [%s] <= @max", job.PrimaryKey)
} }
fmt.Fprintf(&sbQuery, " ORDER BY [%s] ASC", job.PrimaryKey) fmt.Fprintf(&sbQuery, " ORDER BY [%s] ASC", job.PrimaryKey)

View File

@@ -1,91 +0,0 @@
package main
import (
"context"
"database/sql"
"fmt"
)
// BatchRange is a primary-key range for one extraction batch.
// A BatchRange with validRange == false is a sentinel meaning "scan the
// whole table without key bounds" (used when the table is small).
type BatchRange struct {
LowerLimit int
UpperLimit int
// validRange reports whether the limits above are meaningful.
validRange bool
}
// estimateTotalRowsMssql returns an approximate row count for the job's
// table, read from the sys.partitions catalog (index_id 0 = heap,
// 1 = clustered index) rather than an expensive COUNT(*).
func estimateTotalRowsMssql(ctx context.Context, db *sql.DB, job MigrationJob) (int, error) {
	const query = `
SELECT
SUM(p.rows) AS count
FROM sys.tables t
JOIN sys.schemas s ON t.schema_id = s.schema_id
JOIN sys.partitions p ON t.object_id = p.object_id
WHERE s.name = @schema AND t.name = @table AND p.index_id IN (0, 1)
GROUP BY t.name`
	var total int
	row := db.QueryRowContext(ctx, query, sql.Named("schema", job.Schema), sql.Named("table", job.Table))
	if err := row.Scan(&total); err != nil {
		return 0, err
	}
	return total, nil
}
// calculateChunkRangesMssql partitions the job's table into batchCount
// contiguous primary-key ranges holding roughly equal row counts, computed
// with NTILE over the primary key. Identifier names are interpolated (they
// cannot be bound as parameters) and must come from trusted configuration.
func calculateChunkRangesMssql(ctx context.Context, db *sql.DB, job MigrationJob, batchCount int) ([]BatchRange, error) {
	pk := job.PrimaryKey
	query := fmt.Sprintf(`
SELECT
MIN([%s]) AS lower_limit,
MAX([%s]) AS upper_limit
FROM
(SELECT [%s], NTILE(@batchCount) OVER (ORDER BY [%s]) AS chunk_id FROM [%s].[%s]) AS T
GROUP BY chunk_id
ORDER BY chunk_id`, pk, pk, pk, pk, job.Schema, job.Table)
	result, err := db.QueryContext(ctx, query, sql.Named("batchCount", batchCount))
	if err != nil {
		return nil, err
	}
	defer result.Close()
	ranges := make([]BatchRange, 0, batchCount)
	for result.Next() {
		current := BatchRange{validRange: true}
		if err := result.Scan(&current.LowerLimit, &current.UpperLimit); err != nil {
			return nil, err
		}
		ranges = append(ranges, current)
	}
	if err := result.Err(); err != nil {
		return nil, err
	}
	return ranges, nil
}
// estimatedRowsPerBatch is the target number of rows per extraction batch.
const estimatedRowsPerBatch = 100_000

// calculateBatchMetrics plans the batch ranges for a migration job: tables
// with at most estimatedRowsPerBatch rows get a single sentinel range
// (validRange == false, meaning full-table scan); larger tables are split
// into roughly equal key ranges.
func calculateBatchMetrics(ctx context.Context, db *sql.DB, job MigrationJob) ([]BatchRange, error) {
	total, err := estimateTotalRowsMssql(ctx, db, job)
	if err != nil {
		return nil, err
	}
	if total <= estimatedRowsPerBatch {
		// Small enough for one unranged pass over the whole table.
		return []BatchRange{{validRange: false}}, nil
	}
	return calculateChunkRangesMssql(ctx, db, job, total/estimatedRowsPerBatch)
}

View File

@@ -0,0 +1,66 @@
package main
import (
"fmt"
"github.com/google/uuid"
)
// ExtractorError reports a failed extraction attempt for an embedded Batch.
// When HasLastId is true, LastId holds the primary-key value of the last row
// read successfully, so a retry can resume just past that point.
type ExtractorError struct {
Batch
LastId int64
HasLastId bool
// Msg is the human-readable cause; returned by Error().
Msg string
}
// Error implements the error interface by returning the stored message.
func (e *ExtractorError) Error() string {
return e.Msg
}
// maxRetryAttempts caps how many times a batch is re-queued before its
// failure is reported as a terminal (global) error.
const maxRetryAttempts = 3
// extractorErrorHandler consumes extractor failures: batches that still have
// retry budget are re-queued on chBatchesOut (resuming after the last read
// key when one is known), while exhausted batches are reported as terminal
// errors on chGlobalErrorsOut. Runs until chErrorsIn is closed.
func extractorErrorHandler(chErrorsIn <-chan ExtractorError, chBatchesOut chan<- Batch, chGlobalErrorsOut chan<- error) {
	for extErr := range chErrorsIn {
		if extErr.RetryCounter >= maxRetryAttempts {
			chGlobalErrorsOut <- fmt.Errorf("batch %v reached max retries (%d): %s", extErr.Id, maxRetryAttempts, extErr.Msg)
			continue
		}
		retry := extErr.Batch
		retry.RetryCounter++
		if extErr.HasLastId {
			// Resume strictly after the last successfully-read key, as a
			// child batch linked to the failed one.
			retry.ParentId = extErr.Id
			retry.Id = uuid.New()
			retry.LowerLimit = extErr.LastId
			retry.IsLowerLimitInclusive = false
		}
		chBatchesOut <- retry
	}
}
// ExtractorErrorFromLastRowMssql builds an ExtractorError that allows a
// retry to resume after the last row successfully read. indexPrimaryKey is
// the column index of the job's primary key within lastRow; previousError
// is the extraction failure being reported.
//
// If the primary-key value cannot be converted to int64 the error is marked
// non-resumable (RetryCounter forced to maxRetryAttempts and HasLastId
// false) so the handler escalates it instead of retrying from a bogus
// lower bound.
func ExtractorErrorFromLastRowMssql(lastRow UnknownRowValues, indexPrimaryKey int, batch *Batch, previousError error) ExtractorError {
	lastId, ok := ToInt64(lastRow[indexPrimaryKey])
	if !ok {
		failedBatch := *batch
		failedBatch.RetryCounter = maxRetryAttempts
		return ExtractorError{
			Batch: failedBatch,
			// Bug fix: this branch has no usable last id, so it must not
			// advertise HasLastId (the original set it to true).
			HasLastId: false,
			Msg:       fmt.Sprintf("Couldn't cast last id value as int: %s", previousError.Error()),
		}
	}
	return ExtractorError{
		Batch:     *batch,
		HasLastId: true,
		LastId:    lastId,
		Msg:       previousError.Error(),
	}
}

View File

@@ -3,6 +3,8 @@ package main
import ( import (
"context" "context"
"database/sql" "database/sql"
"slices"
"strings"
"time" "time"
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
@@ -12,67 +14,126 @@ import (
type UnknownRowValues = []any type UnknownRowValues = []any
func extractFromMssql(ctx context.Context, db *sql.DB, job MigrationJob, columns []ColumnType, chunkSize int, batchRange BatchRange, out chan<- []UnknownRowValues) error { func extractFromMssql(
query := buildExtractQueryMssql(job, columns, batchRange.validRange) ctx context.Context,
log.Debug("Query used to extract data from mssql: ", query) db *sql.DB,
job MigrationJob,
columns []ColumnType,
chunkSize int,
chBatchesIn <-chan Batch,
chChunksOut chan<- []UnknownRowValues,
chErrorsOut chan<- ExtractorError,
) {
indexPrimaryKey := slices.IndexFunc(columns, func(col ColumnType) bool {
return strings.EqualFold(col.name, job.PrimaryKey)
})
var queryArgs []any if indexPrimaryKey == -1 {
if batchRange.validRange { exError := ExtractorError{
queryArgs = append(queryArgs, Batch: Batch{
sql.Named("minRange", batchRange.LowerLimit), RetryCounter: maxRetryAttempts,
sql.Named("maxRange", batchRange.UpperLimit), },
) HasLastId: false,
} Msg: "Primary key not found in columns provided",
queryStartTime := time.Now()
rows, err := db.QueryContext(ctx, query, queryArgs...)
if err != nil {
return err
}
defer rows.Close()
log.Debugf("Query executed in %v", time.Since(queryStartTime))
rowsChunk := make([]UnknownRowValues, 0, chunkSize)
totalRowsExtracted := 0
chunkCount := 0
chunkStartTime := time.Now()
for rows.Next() {
values := make([]any, len(columns))
scanArgs := make([]any, len(columns))
for i := range values {
scanArgs[i] = &values[i]
}
if err := rows.Scan(scanArgs...); err != nil {
return err
}
rowsChunk = append(rowsChunk, values)
totalRowsExtracted++
if len(rowsChunk) >= chunkSize {
chunkCount++
chunkDuration := time.Since(chunkStartTime)
rowsPerSec := float64(chunkSize) / chunkDuration.Seconds()
log.Infof("Extracted chunk #%d: %d rows in %v (%.0f rows/sec) - Total: %d rows", chunkCount, len(rowsChunk), chunkDuration, rowsPerSec, totalRowsExtracted)
out <- rowsChunk
rowsChunk = make([]UnknownRowValues, 0, chunkSize)
chunkStartTime = time.Now()
} }
chErrorsOut <- exError
return
} }
if len(rowsChunk) > 0 { for batch := range chBatchesIn {
chunkCount++ func() {
chunkDuration := time.Since(chunkStartTime) query := buildExtractQueryMssql(job, columns, batch.ShouldUseRange, batch.IsLowerLimitInclusive)
rowsPerSec := float64(len(rowsChunk)) / chunkDuration.Seconds() log.Debug("Query used to extract data from mssql: ", query)
log.Infof("Extracted final chunk #%d: %d rows in %v (%.0f rows/sec) - Total: %d rows",
chunkCount, len(rowsChunk), chunkDuration, rowsPerSec, totalRowsExtracted)
out <- rowsChunk
}
return rows.Err() var queryArgs []any
if batch.ShouldUseRange {
queryArgs = append(queryArgs,
sql.Named("min", batch.LowerLimit),
sql.Named("max", batch.UpperLimit),
)
}
queryStartTime := time.Now()
rows, err := db.QueryContext(ctx, query, queryArgs...)
if err != nil {
exError := ExtractorError{
Batch: batch,
HasLastId: false,
Msg: err.Error(),
}
chErrorsOut <- exError
return
}
defer rows.Close()
log.Debugf("Query executed in %v", time.Since(queryStartTime))
rowsChunk := make([]UnknownRowValues, 0, chunkSize)
totalRowsExtracted := 0
chunkStartTime := time.Now()
for rows.Next() {
values := make([]any, len(columns))
scanArgs := make([]any, len(columns))
for i := range values {
scanArgs[i] = &values[i]
}
if err := rows.Scan(scanArgs...); err != nil {
if len(rowsChunk) == 0 {
exError := ExtractorError{
Batch: batch,
HasLastId: false,
Msg: err.Error(),
}
chErrorsOut <- exError
return
}
lastRow := rowsChunk[len(rowsChunk)-1]
chErrorsOut <- ExtractorErrorFromLastRowMssql(lastRow, indexPrimaryKey, &batch, err)
return
}
rowsChunk = append(rowsChunk, values)
totalRowsExtracted++
if len(rowsChunk) >= chunkSize {
chunkDuration := time.Since(chunkStartTime)
rowsPerSec := float64(chunkSize) / chunkDuration.Seconds()
log.Infof("Extracted chunk: %d rows in %v (%.0f rows/sec) - Total: %d rows",
len(rowsChunk), chunkDuration, rowsPerSec, totalRowsExtracted)
chChunksOut <- rowsChunk
rowsChunk = make([]UnknownRowValues, 0, chunkSize)
chunkStartTime = time.Now()
}
}
if len(rowsChunk) > 0 {
chunkDuration := time.Since(chunkStartTime)
rowsPerSec := float64(len(rowsChunk)) / chunkDuration.Seconds()
log.Infof("Extracted final chunk: %d rows in %v (%.0f rows/sec) - Total: %d rows",
len(rowsChunk), chunkDuration, rowsPerSec, totalRowsExtracted)
chChunksOut <- rowsChunk
}
if err := rows.Err(); err != nil {
if len(rowsChunk) == 0 {
exError := ExtractorError{
Batch: batch,
HasLastId: false,
Msg: err.Error(),
}
chErrorsOut <- exError
return
}
lastRow := rowsChunk[len(rowsChunk)-1]
chErrorsOut <- ExtractorErrorFromLastRowMssql(lastRow, indexPrimaryKey, &batch, err)
return
}
}()
}
} }
func extractFromPostgres(ctx context.Context, job MigrationJob, columns []ColumnType, chunkSize int, db *pgxpool.Pool, out chan<- []UnknownRowValues) error { func extractFromPostgres(ctx context.Context, job MigrationJob, columns []ColumnType, chunkSize int, db *pgxpool.Pool, out chan<- []UnknownRowValues) error {

View File

@@ -26,10 +26,12 @@ var migrationJobs []MigrationJob = []MigrationJob{
} }
const ( const (
NumExtractors int = 4 NumExtractors int = 4
NumLoaders int = 8 NumLoaders int = 8
ChunkSize int = 25000 ChunkSize int = 25000
QueueSize int = 8 QueueSize int = 8
ChunksPerBatch int = 16
RowsPerBatch int64 = int64(ChunkSize * ChunksPerBatch)
) )
func main() { func main() {

View File

@@ -25,39 +25,43 @@ func processMigrationJob(sourceDb *sql.DB, targetDb *pgxpool.Pool, job Migration
logColumnTypes(targetColTypes, "Target col types") logColumnTypes(targetColTypes, "Target col types")
mssqlCtx := context.Background() mssqlCtx := context.Background()
batchRanges, err := calculateBatchMetrics(mssqlCtx, sourceDb, job) batches, err := batchGeneratorMssql(mssqlCtx, sourceDb, job)
if err != nil { if err != nil {
log.Error("Unexpected error calculating batch ranges: ", err) log.Error("Unexpected error calculating batch ranges: ", err)
} }
chBatchRanges := make(chan BatchRange, len(batchRanges)) chGlobalErrors := make(chan error)
defer close(chGlobalErrors)
maxExtractors := min(NumExtractors, len(batchRanges)) chBatches := make(chan Batch, len(batches))
chRowsExtract := make(chan []UnknownRowValues, QueueSize) chChunks := make(chan []UnknownRowValues, QueueSize)
chExtractorErrors := make(chan ExtractorError, len(batches))
maxExtractors := min(NumExtractors, len(batches))
var wgMssqlExtractors sync.WaitGroup var wgMssqlExtractors sync.WaitGroup
log.Infof("Starting %d MSSQL extractors...", maxExtractors) log.Infof("Starting %d MSSQL extractors...", maxExtractors)
extractStartTime := time.Now() extractStartTime := time.Now()
for range maxExtractors { for range maxExtractors {
wgMssqlExtractors.Go(func() { wgMssqlExtractors.Go(func() {
for br := range chBatchRanges { extractFromMssql(mssqlCtx, sourceDb, job, sourceColTypes, ChunkSize, chBatches, chChunks, chExtractorErrors)
if err := extractFromMssql(mssqlCtx, sourceDb, job, sourceColTypes, ChunkSize, br, chRowsExtract); err != nil {
log.Error("Unexpected error extracting data from mssql: ", err)
}
}
}) })
} }
go func() { go func() {
for _, br := range batchRanges { for _, br := range batches {
chBatchRanges <- br chBatches <- br
} }
close(chBatchRanges) close(chBatches)
close(chExtractorErrors)
}()
go func() {
extractorErrorHandler(chExtractorErrors, chBatches, chGlobalErrors)
}() }()
go func() { go func() {
wgMssqlExtractors.Wait() wgMssqlExtractors.Wait()
close(chRowsExtract) close(chChunks)
log.Infof("Extraction completed in %v", time.Since(extractStartTime)) log.Infof("Extraction completed in %v", time.Since(extractStartTime))
}() }()
@@ -68,7 +72,7 @@ func processMigrationJob(sourceDb *sql.DB, targetDb *pgxpool.Pool, job Migration
transformStartTime := time.Now() transformStartTime := time.Now()
for range maxExtractors { for range maxExtractors {
wgMssqlTransformers.Go(func() { wgMssqlTransformers.Go(func() {
transformRowsMssql(sourceColTypes, chRowsExtract, chRowsTransform) transformRowsMssql(sourceColTypes, chChunks, chRowsTransform)
}) })
} }

View File

@@ -43,3 +43,20 @@ func transformRowsMssql(columns []ColumnType, in <-chan []UnknownRowValues, out
out <- rows out <- rows
} }
} }
// ToInt64 converts any Go integer value (signed or unsigned, as commonly
// produced by database/sql drivers) to int64. It reports false when v is
// not an integer type, or when an unsigned value does not fit in int64.
func ToInt64(v any) (int64, bool) {
	switch t := v.(type) {
	case int64:
		return t, true
	case int:
		return int64(t), true
	case int8:
		return int64(t), true
	case int16:
		return int64(t), true
	case int32:
		return int64(t), true
	// Generalization: also accept unsigned integer types, guarding the two
	// widths that can exceed math.MaxInt64.
	case uint8:
		return int64(t), true
	case uint16:
		return int64(t), true
	case uint32:
		return int64(t), true
	case uint:
		if uint64(t) > 1<<63-1 {
			return 0, false
		}
		return int64(t), true
	case uint64:
		if t > 1<<63-1 {
			return 0, false
		}
		return int64(t), true
	default:
		return 0, false
	}
}