package loaders import ( "context" "errors" "sync" "sync/atomic" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models" "github.com/google/uuid" "github.com/sirupsen/logrus" ) type loaderAccumulator struct { batchSize int rows []models.UnknownRowValues parents []models.BatchRef pendingDone int } func (a *loaderAccumulator) add(batch models.Batch) { a.rows = append(a.rows, batch.Rows...) a.parents = append(a.parents, models.BatchRef{Id: batch.Id}) a.pendingDone++ } func (a *loaderAccumulator) ready() bool { return len(a.rows) >= a.batchSize } func (a *loaderAccumulator) drainPending(wg *sync.WaitGroup) { for range a.pendingDone { wg.Done() } } func sendLoadError( ctx context.Context, err error, retryConfig config.RetryConfig, failedBatchesCount *int32, chErrorsOut chan<- custom_errors.JobError, ) bool { atomic.AddInt32(failedBatchesCount, 1) var jobErr custom_errors.JobError if je, ok := errors.AsType[*custom_errors.JobError](err); ok { jobErr = *je } else { jobErr = custom_errors.JobError{ShouldCancelJob: false, Msg: err.Error(), Prev: err} } select { case <-ctx.Done(): return false case chErrorsOut <- jobErr: } if atomic.LoadInt32(failedBatchesCount) > int32(retryConfig.MaxFailedBatchesLoad) { select { case <-ctx.Done(): case chErrorsOut <- custom_errors.JobError{ShouldCancelJob: true, Msg: "Max failed batches (load) reached"}: } return false } return true } func (gl *GenericLoader) Consume( ctx context.Context, tableInfo config.TargetTableInfo, columns []models.ColumnType, retryConfig config.RetryConfig, batchSize int, chBatchesIn <-chan models.Batch, chErrorsOut chan<- custom_errors.JobError, wgActiveBatches *sync.WaitGroup, rowsLoaded *int64, failedBatchesCount *int32, ) { colNames := mapSlice(columns, func(col models.ColumnType) string { return col.Name() }) acc := &loaderAccumulator{batchSize: batchSize} defer acc.drainPending(wgActiveBatches) flush := func() bool { if len(acc.rows) == 0 { return true } count := len(acc.parents) superBatch := models.Batch{ Id: uuid.New(), ParentBatches: acc.parents, Rows: acc.rows, } processedRows, err := gl.ProcessBatchWithRetries(ctx, tableInfo, colNames, retryConfig, superBatch) for range count { wgActiveBatches.Done() } acc.pendingDone -= count acc.rows = nil acc.parents = nil if err != nil { return sendLoadError(ctx, err, retryConfig, failedBatchesCount, chErrorsOut) } current := atomic.LoadInt64(rowsLoaded) logrus.Debugf("Rows loaded (batch loaded): +%v [current=%v] (%s.%s)", processedRows, current, tableInfo.Schema, tableInfo.Table) atomic.AddInt64(rowsLoaded, int64(processedRows)) return true } for { select { case <-ctx.Done(): return case batch, ok := <-chBatchesIn: if !ok { flush() return } if batchSize <= 0 { processedRows, err := gl.ProcessBatchWithRetries(ctx, tableInfo, colNames, retryConfig, batch) wgActiveBatches.Done() if err != nil { if !sendLoadError(ctx, err, retryConfig, failedBatchesCount, chErrorsOut) { return } continue } current := atomic.LoadInt64(rowsLoaded) logrus.Debugf("Rows loaded: +%v [current=%v] (%s.%s)", processedRows, current, tableInfo.Schema, tableInfo.Table) atomic.AddInt64(rowsLoaded, int64(processedRows)) continue } acc.add(batch) if acc.ready() { if !flush() { return } } } } }