diff --git a/internal/app/etl/extractors/main.go b/internal/app/etl/extractors/main.go index 3081103..26e1d1c 100644 --- a/internal/app/etl/extractors/main.go +++ b/internal/app/etl/extractors/main.go @@ -27,7 +27,6 @@ func sendBatch(ctx context.Context, chBatchesOut chan<- models.Batch, batch mode func flush( ctx context.Context, - partition *models.Partition, batchSize int, batchRows []models.UnknownRowValues, chBatchesOut chan<- models.Batch, @@ -36,7 +35,7 @@ func flush( return nil } - batch := models.Batch{Id: uuid.New(), PartitionId: partition.Id, Rows: batchRows} + batch := models.Batch{Id: uuid.New(), Rows: batchRows} batchRows = make([]models.UnknownRowValues, 0, batchSize) return sendBatch(ctx, chBatchesOut, batch) } diff --git a/internal/app/etl/extractors/process.go b/internal/app/etl/extractors/process.go index 1a56912..7adc01d 100644 --- a/internal/app/etl/extractors/process.go +++ b/internal/app/etl/extractors/process.go @@ -90,7 +90,7 @@ func (ex *GenericExtractor) ProcessPartition( return rowsRead, err } - if err := flush(ctx, &partition, batchSize, batchRows, chBatchesOut); err != nil { + if err := flush(ctx, batchSize, batchRows, chBatchesOut); err != nil { return rowsRead, err } @@ -102,7 +102,7 @@ func (ex *GenericExtractor) ProcessPartition( batchRows = append(batchRows, rowValues) if len(batchRows) >= batchSize { // logrus.Debugf("Batch size reached, flushing batch with %v rows (rowsRead=%v)", len(batchRows), rowsRead) - if err := flush(ctx, &partition, batchSize, batchRows, chBatchesOut); err != nil { + if err := flush(ctx, batchSize, batchRows, chBatchesOut); err != nil { // logrus.Warnf("Error flushing rows: %v", err) return rowsRead, err } @@ -110,7 +110,7 @@ func (ex *GenericExtractor) ProcessPartition( } } - if err := flush(ctx, &partition, batchSize, batchRows, chBatchesOut); err != nil { + if err := flush(ctx, batchSize, batchRows, chBatchesOut); err != nil { return rowsRead, err } diff --git a/internal/app/etl/loaders/consume.go b/internal/app/etl/loaders/consume.go index 0b93aa9..e4dfdf4 100644 --- a/internal/app/etl/loaders/consume.go +++ b/internal/app/etl/loaders/consume.go @@ -30,7 +30,7 @@ func (gl *GenericLoader) Consume( }) var accRows []models.UnknownRowValues - var parentBatchesId []uuid.UUID + var parentBatches []models.BatchRef pendingDone := 0 defer func() { @@ -43,11 +43,11 @@ func (gl *GenericLoader) Consume( if len(accRows) == 0 { return true } - count := len(parentBatchesId) + count := len(parentBatches) superBatch := models.Batch{ - Id: uuid.New(), - ParentBatchesId: parentBatchesId, - Rows: accRows, + Id: uuid.New(), + ParentBatches: parentBatches, + Rows: accRows, } processedRows, err := gl.ProcessBatchWithRetries(ctx, tableInfo, colNames, retryConfig, superBatch) for range count { @@ -55,7 +55,7 @@ func (gl *GenericLoader) Consume( } pendingDone -= count accRows = nil - parentBatchesId = nil + parentBatches = nil if err != nil { atomic.AddInt32(failedBatchesCount, 1) @@ -142,7 +142,7 @@ func (gl *GenericLoader) Consume( pendingDone++ accRows = append(accRows, batch.Rows...) - parentBatchesId = append(parentBatchesId, batch.Id) + parentBatches = append(parentBatches, models.BatchRef{Id: batch.Id}) if len(accRows) >= batchSize { if !flush() { diff --git a/internal/app/etl/transformers/consume.go b/internal/app/etl/transformers/consume.go index bd3a92d..ae65555 100644 --- a/internal/app/etl/transformers/consume.go +++ b/internal/app/etl/transformers/consume.go @@ -11,6 +11,58 @@ import ( "github.com/google/uuid" ) +type batchAccumulator struct { + batchSize int + rows []models.UnknownRowValues + parents []models.BatchRef +} + +func (a *batchAccumulator) add(batch models.Batch) { + a.rows = append(a.rows, batch.Rows...) + a.parents = append(a.parents, models.BatchRef{Id: batch.Id}) +} + +func (a *batchAccumulator) ready() bool { + return len(a.rows) >= a.batchSize +} + +func (a *batchAccumulator) flush(ctx context.Context, chOut chan<- models.Batch, wg *sync.WaitGroup) bool { + if len(a.rows) == 0 { + return true + } + out := models.Batch{ + Id: uuid.New(), + ParentBatches: a.parents, + Rows: a.rows, + } + wg.Add(1) + select { + case chOut <- out: + case <-ctx.Done(): + wg.Done() + return false + } + a.rows = nil + a.parents = nil + return true +} + +func sendTransformError(ctx context.Context, err error, ch chan<- custom_errors.JobError) { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return + } + var jobErr custom_errors.JobError + if je, ok := errors.AsType[*custom_errors.JobError](err); ok { + jobErr = *je + } else { + jobErr = custom_errors.JobError{ShouldCancelJob: true, Msg: "Transformation failed", Prev: err} + } + select { + case ch <- jobErr: + case <-ctx.Done(): + } +} + func (mssqlTr *MssqlTransformer) Consume( ctx context.Context, columns []models.ColumnType, @@ -25,90 +77,40 @@ func (mssqlTr *MssqlTransformer) Consume( storagePlan := computeStorageTransformationPlan(ctx, mssqlTr.azureClient, mssqlTr.toStorage, columns, mssqlTr.sourceTable) transformationPlan = append(transformationPlan, storagePlan...) - var accRows []models.UnknownRowValues - var parentBatchesId []uuid.UUID - var firstPartitionId uuid.UUID - - flush := func() bool { - if len(accRows) == 0 { - return true - } - out := models.Batch{ - Id: uuid.New(), - PartitionId: firstPartitionId, - ParentBatchesId: parentBatchesId, - Rows: accRows, - } - select { - case chBatchesOut <- out: - wgActiveBatches.Add(1) - case <-ctx.Done(): - return false - } - accRows = nil - parentBatchesId = nil - firstPartitionId = uuid.Nil - return true - } + acc := &batchAccumulator{batchSize: batchSize} for { - if ctx.Err() != nil { - return - } - select { case <-ctx.Done(): return case batch, ok := <-chBatchesIn: if !ok { - flush() + acc.flush(ctx, chBatchesOut, wgActiveBatches) return } if len(transformationPlan) > 0 { - err := ProcessBatchWithRetries(ctx, &batch, transformationPlan, retryConfig) - if err != nil { - if errors.Is(err, ctx.Err()) { - return - } - - if jobError, ok := errors.AsType[*custom_errors.JobError](err); ok { - select { - case chJobErrorsOut <- *jobError: - case <-ctx.Done(): - return - } - } else { - select { - case chJobErrorsOut <- custom_errors.JobError{ShouldCancelJob: true, Msg: "Transformation failed", Prev: err}: - case <-ctx.Done(): - return - } - } - + if err := ProcessBatchWithRetries(ctx, &batch, transformationPlan, retryConfig); err != nil { + sendTransformError(ctx, err, chJobErrorsOut) return } } if batchSize <= 0 { + wgActiveBatches.Add(1) select { case chBatchesOut <- batch: - wgActiveBatches.Add(1) case <-ctx.Done(): + wgActiveBatches.Done() return } continue } - if len(parentBatchesId) == 0 { - firstPartitionId = batch.PartitionId - } - accRows = append(accRows, batch.Rows...) - parentBatchesId = append(parentBatchesId, batch.Id) - - if len(accRows) >= batchSize { - if !flush() { + acc.add(batch) + if acc.ready() { + if !acc.flush(ctx, chBatchesOut, wgActiveBatches) { return } } diff --git a/internal/app/models/main.go b/internal/app/models/main.go index 5becf6a..42558a9 100644 --- a/internal/app/models/main.go +++ b/internal/app/models/main.go @@ -8,12 +8,16 @@ import ( type UnknownRowValues = []any +type BatchRef struct { + Id uuid.UUID + PartitionId uuid.UUID +} + type Batch struct { - Id uuid.UUID - PartitionId uuid.UUID - ParentBatchesId []uuid.UUID - Rows []UnknownRowValues - RetryCounter int + Id uuid.UUID + ParentBatches []BatchRef + Rows []UnknownRowValues + RetryCounter int } type PartitionRange struct {