refactor: simplify batch processing by removing partition dependency and introducing batch accumulator

This commit is contained in:
2026-05-11 00:38:42 -05:00
parent 16217f6ee2
commit ab9a2d8694
5 changed files with 81 additions and 76 deletions

View File

@@ -27,7 +27,6 @@ func sendBatch(ctx context.Context, chBatchesOut chan<- models.Batch, batch mode
func flush( func flush(
ctx context.Context, ctx context.Context,
partition *models.Partition,
batchSize int, batchSize int,
batchRows []models.UnknownRowValues, batchRows []models.UnknownRowValues,
chBatchesOut chan<- models.Batch, chBatchesOut chan<- models.Batch,
@@ -36,7 +35,7 @@ func flush(
return nil return nil
} }
batch := models.Batch{Id: uuid.New(), PartitionId: partition.Id, Rows: batchRows} batch := models.Batch{Id: uuid.New(), Rows: batchRows}
batchRows = make([]models.UnknownRowValues, 0, batchSize) batchRows = make([]models.UnknownRowValues, 0, batchSize)
return sendBatch(ctx, chBatchesOut, batch) return sendBatch(ctx, chBatchesOut, batch)
} }

View File

@@ -90,7 +90,7 @@ func (ex *GenericExtractor) ProcessPartition(
return rowsRead, err return rowsRead, err
} }
if err := flush(ctx, &partition, batchSize, batchRows, chBatchesOut); err != nil { if err := flush(ctx, batchSize, batchRows, chBatchesOut); err != nil {
return rowsRead, err return rowsRead, err
} }
@@ -102,7 +102,7 @@ func (ex *GenericExtractor) ProcessPartition(
batchRows = append(batchRows, rowValues) batchRows = append(batchRows, rowValues)
if len(batchRows) >= batchSize { if len(batchRows) >= batchSize {
// logrus.Debugf("Batch size reached, flushing batch with %v rows (rowsRead=%v)", len(batchRows), rowsRead) // logrus.Debugf("Batch size reached, flushing batch with %v rows (rowsRead=%v)", len(batchRows), rowsRead)
if err := flush(ctx, &partition, batchSize, batchRows, chBatchesOut); err != nil { if err := flush(ctx, batchSize, batchRows, chBatchesOut); err != nil {
// logrus.Warnf("Error flushing rows: %v", err) // logrus.Warnf("Error flushing rows: %v", err)
return rowsRead, err return rowsRead, err
} }
@@ -110,7 +110,7 @@ func (ex *GenericExtractor) ProcessPartition(
} }
} }
if err := flush(ctx, &partition, batchSize, batchRows, chBatchesOut); err != nil { if err := flush(ctx, batchSize, batchRows, chBatchesOut); err != nil {
return rowsRead, err return rowsRead, err
} }

View File

@@ -30,7 +30,7 @@ func (gl *GenericLoader) Consume(
}) })
var accRows []models.UnknownRowValues var accRows []models.UnknownRowValues
var parentBatchesId []uuid.UUID var parentBatches []models.BatchRef
pendingDone := 0 pendingDone := 0
defer func() { defer func() {
@@ -43,11 +43,11 @@ func (gl *GenericLoader) Consume(
if len(accRows) == 0 { if len(accRows) == 0 {
return true return true
} }
count := len(parentBatchesId) count := len(parentBatches)
superBatch := models.Batch{ superBatch := models.Batch{
Id: uuid.New(), Id: uuid.New(),
ParentBatchesId: parentBatchesId, ParentBatches: parentBatches,
Rows: accRows, Rows: accRows,
} }
processedRows, err := gl.ProcessBatchWithRetries(ctx, tableInfo, colNames, retryConfig, superBatch) processedRows, err := gl.ProcessBatchWithRetries(ctx, tableInfo, colNames, retryConfig, superBatch)
for range count { for range count {
@@ -55,7 +55,7 @@ func (gl *GenericLoader) Consume(
} }
pendingDone -= count pendingDone -= count
accRows = nil accRows = nil
parentBatchesId = nil parentBatches = nil
if err != nil { if err != nil {
atomic.AddInt32(failedBatchesCount, 1) atomic.AddInt32(failedBatchesCount, 1)
@@ -142,7 +142,7 @@ func (gl *GenericLoader) Consume(
pendingDone++ pendingDone++
accRows = append(accRows, batch.Rows...) accRows = append(accRows, batch.Rows...)
parentBatchesId = append(parentBatchesId, batch.Id) parentBatches = append(parentBatches, models.BatchRef{Id: batch.Id})
if len(accRows) >= batchSize { if len(accRows) >= batchSize {
if !flush() { if !flush() {

View File

@@ -11,6 +11,58 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
) )
type batchAccumulator struct {
batchSize int
rows []models.UnknownRowValues
parents []models.BatchRef
}
func (a *batchAccumulator) add(batch models.Batch) {
a.rows = append(a.rows, batch.Rows...)
a.parents = append(a.parents, models.BatchRef{Id: batch.Id})
}
func (a *batchAccumulator) ready() bool {
return len(a.rows) >= a.batchSize
}
func (a *batchAccumulator) flush(ctx context.Context, chOut chan<- models.Batch, wg *sync.WaitGroup) bool {
if len(a.rows) == 0 {
return true
}
out := models.Batch{
Id: uuid.New(),
ParentBatches: a.parents,
Rows: a.rows,
}
wg.Add(1)
select {
case chOut <- out:
case <-ctx.Done():
wg.Done()
return false
}
a.rows = nil
a.parents = nil
return true
}
func sendTransformError(ctx context.Context, err error, ch chan<- custom_errors.JobError) {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return
}
var jobErr custom_errors.JobError
if je, ok := errors.AsType[*custom_errors.JobError](err); ok {
jobErr = *je
} else {
jobErr = custom_errors.JobError{ShouldCancelJob: true, Msg: "Transformation failed", Prev: err}
}
select {
case ch <- jobErr:
case <-ctx.Done():
}
}
func (mssqlTr *MssqlTransformer) Consume( func (mssqlTr *MssqlTransformer) Consume(
ctx context.Context, ctx context.Context,
columns []models.ColumnType, columns []models.ColumnType,
@@ -25,90 +77,40 @@ func (mssqlTr *MssqlTransformer) Consume(
storagePlan := computeStorageTransformationPlan(ctx, mssqlTr.azureClient, mssqlTr.toStorage, columns, mssqlTr.sourceTable) storagePlan := computeStorageTransformationPlan(ctx, mssqlTr.azureClient, mssqlTr.toStorage, columns, mssqlTr.sourceTable)
transformationPlan = append(transformationPlan, storagePlan...) transformationPlan = append(transformationPlan, storagePlan...)
var accRows []models.UnknownRowValues acc := &batchAccumulator{batchSize: batchSize}
var parentBatchesId []uuid.UUID
var firstPartitionId uuid.UUID
flush := func() bool {
if len(accRows) == 0 {
return true
}
out := models.Batch{
Id: uuid.New(),
PartitionId: firstPartitionId,
ParentBatchesId: parentBatchesId,
Rows: accRows,
}
select {
case chBatchesOut <- out:
wgActiveBatches.Add(1)
case <-ctx.Done():
return false
}
accRows = nil
parentBatchesId = nil
firstPartitionId = uuid.Nil
return true
}
for { for {
if ctx.Err() != nil {
return
}
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
case batch, ok := <-chBatchesIn: case batch, ok := <-chBatchesIn:
if !ok { if !ok {
flush() acc.flush(ctx, chBatchesOut, wgActiveBatches)
return return
} }
if len(transformationPlan) > 0 { if len(transformationPlan) > 0 {
err := ProcessBatchWithRetries(ctx, &batch, transformationPlan, retryConfig) if err := ProcessBatchWithRetries(ctx, &batch, transformationPlan, retryConfig); err != nil {
if err != nil { sendTransformError(ctx, err, chJobErrorsOut)
if errors.Is(err, ctx.Err()) {
return
}
if jobError, ok := errors.AsType[*custom_errors.JobError](err); ok {
select {
case chJobErrorsOut <- *jobError:
case <-ctx.Done():
return
}
} else {
select {
case chJobErrorsOut <- custom_errors.JobError{ShouldCancelJob: true, Msg: "Transformation failed", Prev: err}:
case <-ctx.Done():
return
}
}
return return
} }
} }
if batchSize <= 0 { if batchSize <= 0 {
wgActiveBatches.Add(1)
select { select {
case chBatchesOut <- batch: case chBatchesOut <- batch:
wgActiveBatches.Add(1)
case <-ctx.Done(): case <-ctx.Done():
wgActiveBatches.Done()
return return
} }
continue continue
} }
if len(parentBatchesId) == 0 { acc.add(batch)
firstPartitionId = batch.PartitionId if acc.ready() {
} if !acc.flush(ctx, chBatchesOut, wgActiveBatches) {
accRows = append(accRows, batch.Rows...)
parentBatchesId = append(parentBatchesId, batch.Id)
if len(accRows) >= batchSize {
if !flush() {
return return
} }
} }

View File

@@ -8,12 +8,16 @@ import (
type UnknownRowValues = []any type UnknownRowValues = []any
type BatchRef struct {
Id uuid.UUID
PartitionId uuid.UUID
}
type Batch struct { type Batch struct {
Id uuid.UUID Id uuid.UUID
PartitionId uuid.UUID ParentBatches []BatchRef
ParentBatchesId []uuid.UUID Rows []UnknownRowValues
Rows []UnknownRowValues RetryCounter int
RetryCounter int
} }
type PartitionRange struct { type PartitionRange struct {