From 604702ef4313e6195b38a1f0582ecf613d1bfa3b Mon Sep 17 00:00:00 2001
From: Kylesoda <249518290+kylesoda@users.noreply.github.com>
Date: Mon, 11 May 2026 08:36:54 -0500
Subject: [PATCH] refactor: add unit tests for loaderAccumulator and consume
 functions; enhance error handling and batch processing logic

---
 internal/app/etl/loaders/consume_test.go | 666 +++++++++++++++++++++++
 1 file changed, 666 insertions(+)
 create mode 100644 internal/app/etl/loaders/consume_test.go

diff --git a/internal/app/etl/loaders/consume_test.go b/internal/app/etl/loaders/consume_test.go
new file mode 100644
index 0000000..4515269
--- /dev/null
+++ b/internal/app/etl/loaders/consume_test.go
@@ -0,0 +1,666 @@
+package loaders
+
+import (
+	"context"
+	"errors"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
+	"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
+	dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper"
+	"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
+	"github.com/google/uuid"
+)
+
+// testTimeout bounds every wait on a goroutine or WaitGroup so that a hang
+// fails the test quickly instead of stalling the whole run.
+const testTimeout = 2 * time.Second
+
+// mockResult scripts the outcome of one SaveMassive call; a nil err means the
+// call succeeds.
+type mockResult struct {
+	err error
+}
+
+// mockDbWrapper is the database test double handed to GenericLoader. Only
+// SaveMassive has behavior; every other method is an inert stub.
+type mockDbWrapper struct {
+	mu        sync.Mutex
+	callCount int
+	results   []mockResult
+}
+
+// newMockDb returns a mock whose i-th SaveMassive call yields results[i].
+// Calls past the end of results succeed.
+func newMockDb(results ...mockResult) *mockDbWrapper {
+	return &mockDbWrapper{results: results}
+}
+
+// SaveMassive counts the call and returns the scripted error for this call
+// index, if any; otherwise it reports every row in rows as saved.
+func (m *mockDbWrapper) SaveMassive(_ context.Context, _ string, _ string, _ []string, rows [][]any) (int64, error) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	idx := m.callCount
+	m.callCount++
+	if idx < len(m.results) && m.results[idx].err != nil {
+		return 0, m.results[idx].err
+	}
+	return int64(len(rows)), nil
+}
+
+// calls returns how many times SaveMassive has been invoked, read under the
+// same mutex that guards the increment.
+func (m *mockDbWrapper) calls() int {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	return m.callCount
+}
+
+// The remaining methods are inert stubs, presumably present to satisfy the
+// interface behind GenericLoader.db; none of them records anything.
+func (m *mockDbWrapper) Close() error                              { return nil }
+func (m *mockDbWrapper) Connect(_ context.Context, _ string) error { return nil }
+func (m *mockDbWrapper) Exec(_ context.Context, _ string, _ ...any) (dbwrapper.ExecResult, error) {
+	return dbwrapper.ExecResult{}, nil
+}
+func (m *mockDbWrapper) GetDialect() string { return "" }
+func (m *mockDbWrapper) Query(_ context.Context, _ string, _ ...any) (dbwrapper.RowsResult, error) {
+	return nil, nil
+}
+func (m *mockDbWrapper) QueryRow(_ context.Context, _ string, _ ...any) dbwrapper.RowResult {
+	return nil
+}
+func (m *mockDbWrapper) QueryFromObject(_ context.Context, _ dbwrapper.ExtractionQuery) (dbwrapper.RowsResult, error) {
+	return nil, nil
+}
+
+// makeBatch builds a batch with a fresh ID and numRows single-value rows, each
+// holding its own index.
+func makeBatch(numRows int) models.Batch {
+	rows := make([]models.UnknownRowValues, numRows)
+	for i := range rows {
+		rows[i] = models.UnknownRowValues{i}
+	}
+	return models.Batch{Id: uuid.New(), Rows: rows}
+}
+
+// newLoader returns a GenericLoader backed by db, all other fields zero.
+func newLoader(db *mockDbWrapper) GenericLoader {
+	return GenericLoader{db: db}
+}
+
+// rc returns a single-attempt retry config with the given
+// MaxFailedBatchesLoad.
+func rc(maxFailed int) config.RetryConfig {
+	return config.RetryConfig{Attempts: 1, MaxFailedBatchesLoad: maxFailed}
+}
+
+// sendBatch registers batch on wg and then hands it to chIn, so each batch
+// sent must be matched by one wg.Done() from the code under test. The send
+// blocks while chIn has neither a free buffer slot nor a receiver.
+func sendBatch(chIn chan<- models.Batch, batch models.Batch, wg *sync.WaitGroup) {
+	wg.Add(1)
+	chIn <- batch
+}
+
+// runConsume starts gl.Consume in its own goroutine, passing a zero
+// TargetTableInfo and nil for the third argument, and returns a channel that
+// is closed once Consume returns.
+func runConsume(
+	ctx context.Context,
+	gl GenericLoader,
+	retryConfig config.RetryConfig,
+	batchSize int,
+	chIn <-chan models.Batch,
+	chErr chan<- custom_errors.JobError,
+	wg *sync.WaitGroup,
+	rowsLoaded *int64,
+	failedCount *int32,
+) <-chan struct{} {
+	done := make(chan struct{})
+	go func() {
+		gl.Consume(ctx, config.TargetTableInfo{}, nil, retryConfig, batchSize,
+			chIn, chErr, wg, rowsLoaded, failedCount)
+		close(done)
+	}()
+	return done
+}
+
+// waitWg adapts wg.Wait() to a channel so it can be raced against a timeout.
+func waitWg(wg *sync.WaitGroup) <-chan struct{} {
+	done := make(chan struct{})
+	go func() { wg.Wait(); close(done) }()
+	return done
+}
+
+// awaitOrFail blocks until done is closed and fails the test with msg if that
+// has not happened within testTimeout.
+func awaitOrFail(t *testing.T, done <-chan struct{}, msg string) {
+	t.Helper()
+	select {
+	case <-done:
+	case <-time.After(testTimeout):
+		t.Fatal(msg)
+	}
+}
+
+// dbError returns a plain error (not a JobError) for SaveMassive to report.
+func dbError() error { return errors.New("connection reset by peer") }
+
+// TestLoaderAccumulator_Add checks that add appends rows, records parents in
+// arrival order and counts one pending Done per batch.
+func TestLoaderAccumulator_Add(t *testing.T) {
+	acc := &loaderAccumulator{batchSize: 5}
+	b1 := makeBatch(2)
+	b2 := makeBatch(3)
+
+	acc.add(b1)
+	acc.add(b2)
+
+	if len(acc.rows) != 5 {
+		t.Errorf("expected 5 rows, got %d", len(acc.rows))
+	}
+	if len(acc.parents) != 2 {
+		t.Fatalf("expected 2 parents, got %d", len(acc.parents))
+	}
+	if acc.parents[0].Id != b1.Id || acc.parents[1].Id != b2.Id {
+		t.Error("parent IDs do not match source batch IDs in order")
+	}
+	if acc.pendingDone != 2 {
+		t.Errorf("expected pendingDone=2, got %d", acc.pendingDone)
+	}
+}
+
+// TestLoaderAccumulator_Ready checks that ready reports true once the
+// accumulated row count reaches batchSize.
+func TestLoaderAccumulator_Ready(t *testing.T) {
+	acc := &loaderAccumulator{batchSize: 3}
+	acc.add(makeBatch(2))
+	if acc.ready() {
+		t.Error("should not be ready with 2 rows and batchSize=3")
+	}
+	acc.add(makeBatch(1))
+	if !acc.ready() {
+		t.Error("should be ready with 3 rows and batchSize=3")
+	}
+}
+
+// TestLoaderAccumulator_DrainPending_ReleasesWg checks that drainPending
+// calls wg.Done() once per pending batch.
+func TestLoaderAccumulator_DrainPending_ReleasesWg(t *testing.T) {
+	acc := &loaderAccumulator{batchSize: 5, pendingDone: 3}
+	var wg sync.WaitGroup
+	wg.Add(3)
+
+	acc.drainPending(&wg)
+
+	awaitOrFail(t, waitWg(&wg), "wg.Wait() timed out: drainPending did not call Done() enough times")
+}
+
+// TestLoaderAccumulator_DrainPending_ZeroPending checks that drainPending
+// leaves wg alone when nothing is pending; a stray Done() on the zero counter
+// would panic.
+func TestLoaderAccumulator_DrainPending_ZeroPending(t *testing.T) {
+	acc := &loaderAccumulator{batchSize: 5, pendingDone: 0}
+	var wg sync.WaitGroup
+
+	acc.drainPending(&wg)
+
+	awaitOrFail(t, waitWg(&wg), "wg.Wait() timed out")
+}
+
+// TestSendLoadError_PlainError_WrappedAsNonFatal checks that a plain error is
+// counted and forwarded with ShouldCancelJob=false.
+func TestSendLoadError_PlainError_WrappedAsNonFatal(t *testing.T) {
+	ch := make(chan custom_errors.JobError, 2)
+	var failedCount int32
+
+	result := sendLoadError(context.Background(), errors.New("db error"), rc(10), &failedCount, ch)
+
+	if !result {
+		t.Error("expected true (below threshold)")
+	}
+	if got := atomic.LoadInt32(&failedCount); got != 1 {
+		t.Errorf("expected failedCount=1, got %d", got)
+	}
+	select {
+	case e := <-ch:
+		if e.ShouldCancelJob {
+			t.Error("plain error should be wrapped as ShouldCancelJob=false")
+		}
+	default:
+		t.Error("expected an error in the channel")
+	}
+}
+
+// TestSendLoadError_JobError_PassesThrough checks that an error that already
+// is a JobError reaches the channel with its fields intact.
+func TestSendLoadError_JobError_PassesThrough(t *testing.T) {
+	ch := make(chan custom_errors.JobError, 2)
+	var failedCount int32
+	original := &custom_errors.JobError{ShouldCancelJob: false, Msg: "custom msg"}
+
+	sendLoadError(context.Background(), original, rc(10), &failedCount, ch)
+
+	select {
+	case e := <-ch:
+		if e.Msg != "custom msg" || e.ShouldCancelJob {
+			t.Errorf("JobError should pass through unchanged, got %+v", e)
+		}
+	default:
+		t.Error("expected an error in the channel")
+	}
+}
+
+// TestSendLoadError_FatalJobError_BelowThreshold_ReturnsTrue checks that a
+// fatal JobError is forwarded as-is and that sendLoadError itself still
+// returns true while the failure count is below the threshold.
+func TestSendLoadError_FatalJobError_BelowThreshold_ReturnsTrue(t *testing.T) {
+	ch := make(chan custom_errors.JobError, 2)
+	var failedCount int32
+	fatal := &custom_errors.JobError{ShouldCancelJob: true, Msg: "unique constraint"}
+
+	result := sendLoadError(context.Background(), fatal, rc(10), &failedCount, ch)
+
+	if !result {
+		t.Error("below-threshold fatal error should return true (external cancel expected from JobErrorHandler)")
+	}
+	select {
+	case e := <-ch:
+		if !e.ShouldCancelJob {
+			t.Error("fatal JobError should be forwarded with ShouldCancelJob=true")
+		}
+	default:
+		t.Error("expected the fatal error in the channel")
+	}
+}
+
+// TestSendLoadError_ThresholdExceeded_ReturnsFalse checks that exceeding
+// MaxFailedBatchesLoad returns false and emits the batch error followed by a
+// fatal threshold error.
+func TestSendLoadError_ThresholdExceeded_ReturnsFalse(t *testing.T) {
+	ch := make(chan custom_errors.JobError, 2)
+	var failedCount int32
+
+	result := sendLoadError(context.Background(), errors.New("db error"), rc(0), &failedCount, ch)
+
+	if result {
+		t.Error("expected false when threshold exceeded")
+	}
+	if len(ch) != 2 {
+		t.Fatalf("expected 2 errors (batch error + fatal threshold error), got %d", len(ch))
+	}
+	<-ch // batch error
+	threshold := <-ch
+	if !threshold.ShouldCancelJob {
+		t.Error("second error should be the fatal threshold error (ShouldCancelJob=true)")
+	}
+}
+
+// TestSendLoadError_AtThresholdBoundary checks that with a limit of 2 the
+// first two failures return true and only the third returns false.
+func TestSendLoadError_AtThresholdBoundary(t *testing.T) {
+	ch := make(chan custom_errors.JobError, 6)
+	var failedCount int32
+
+	if !sendLoadError(context.Background(), errors.New("err"), rc(2), &failedCount, ch) {
+		t.Error("first failure: expected true (below threshold)")
+	}
+	if !sendLoadError(context.Background(), errors.New("err"), rc(2), &failedCount, ch) {
+		t.Error("second failure: expected true (at threshold, not exceeded)")
+	}
+	if sendLoadError(context.Background(), errors.New("err"), rc(2), &failedCount, ch) {
+		t.Error("third failure: expected false (threshold exceeded)")
+	}
+}
+
+// TestSendLoadError_ContextCancelled_ReturnsFalse checks that sendLoadError
+// returns false, without blocking on the error channel, when the context is
+// already cancelled.
+func TestSendLoadError_ContextCancelled_ReturnsFalse(t *testing.T) {
+	// Unbuffered and never received from: a send on ch cannot complete, so
+	// nothing can be delivered and len(ch) would always be 0.
+	ch := make(chan custom_errors.JobError)
+	var failedCount int32
+
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+
+	// Run in a goroutine so an implementation that blocks on the send shows
+	// up as a prompt failure rather than a hung test binary.
+	resultCh := make(chan bool, 1)
+	go func() {
+		resultCh <- sendLoadError(ctx, errors.New("db error"), rc(10), &failedCount, ch)
+	}()
+
+	select {
+	case result := <-resultCh:
+		if result {
+			t.Error("expected false when context is cancelled")
+		}
+	case <-time.After(testTimeout):
+		t.Fatal("sendLoadError blocked on the error channel despite a cancelled context")
+	}
+}
+
+// TestConsume_Passthrough_RowsLoaded checks that with batchSize=0 a single
+// batch results in one SaveMassive call and its rows being counted.
+func TestConsume_Passthrough_RowsLoaded(t *testing.T) {
+	db := newMockDb()
+	gl := newLoader(db)
+	chIn := make(chan models.Batch, 1)
+	chErr := make(chan custom_errors.JobError, 1)
+	var wg sync.WaitGroup
+	var rowsLoaded int64
+	var failedCount int32
+
+	sendBatch(chIn, makeBatch(5), &wg)
+	close(chIn)
+
+	done := runConsume(context.Background(), gl, rc(0), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
+	awaitOrFail(t, done, "Consume did not exit after input channel was closed")
+	awaitOrFail(t, waitWg(&wg), "wg.Wait() timed out")
+
+	if rowsLoaded != 5 {
+		t.Errorf("expected rowsLoaded=5, got %d", rowsLoaded)
+	}
+	if got := db.calls(); got != 1 {
+		t.Errorf("expected 1 SaveMassive call, got %d", got)
+	}
+}
+
+// TestConsume_Passthrough_MultipleBatches_RowsAccumulate checks that
+// rowsLoaded sums the rows of every batch that was saved.
+func TestConsume_Passthrough_MultipleBatches_RowsAccumulate(t *testing.T) {
+	db := newMockDb()
+	gl := newLoader(db)
+	chIn := make(chan models.Batch, 3)
+	chErr := make(chan custom_errors.JobError, 1)
+	var wg sync.WaitGroup
+	var rowsLoaded int64
+	var failedCount int32
+
+	sendBatch(chIn, makeBatch(3), &wg)
+	sendBatch(chIn, makeBatch(2), &wg)
+	sendBatch(chIn, makeBatch(4), &wg)
+	close(chIn)
+
+	done := runConsume(context.Background(), gl, rc(10), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
+	awaitOrFail(t, done, "Consume did not exit after input channel was closed")
+	awaitOrFail(t, waitWg(&wg), "wg.Wait() timed out")
+
+	if rowsLoaded != 9 {
+		t.Errorf("expected rowsLoaded=9, got %d", rowsLoaded)
+	}
+}
+
+// TestConsume_Passthrough_WgDoneBeforeErrorHandling checks that wg.Done() is
+// still called for a batch whose SaveMassive call fails.
+func TestConsume_Passthrough_WgDoneBeforeErrorHandling(t *testing.T) {
+	db := newMockDb(mockResult{err: dbError()})
+	gl := newLoader(db)
+	chIn := make(chan models.Batch, 1)
+	chErr := make(chan custom_errors.JobError, 2)
+	var wg sync.WaitGroup
+	var rowsLoaded int64
+	var failedCount int32
+
+	sendBatch(chIn, makeBatch(2), &wg)
+	close(chIn)
+
+	done := runConsume(context.Background(), gl, rc(10), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
+	awaitOrFail(t, done, "Consume did not exit after input channel was closed")
+	awaitOrFail(t, waitWg(&wg), "wg.Wait() timed out: Done() was not called even though processing failed")
+}
+
+// TestConsume_Passthrough_NonFatalError_Continues checks that one failed batch
+// is counted and reported while later batches are still loaded.
+func TestConsume_Passthrough_NonFatalError_Continues(t *testing.T) {
+	db := newMockDb(mockResult{err: dbError()})
+	gl := newLoader(db)
+	chIn := make(chan models.Batch, 2)
+	chErr := make(chan custom_errors.JobError, 3)
+	var wg sync.WaitGroup
+	var rowsLoaded int64
+	var failedCount int32
+
+	sendBatch(chIn, makeBatch(2), &wg)
+	sendBatch(chIn, makeBatch(3), &wg)
+	close(chIn)
+
+	done := runConsume(context.Background(), gl, rc(10), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
+	awaitOrFail(t, done, "Consume did not exit after input channel was closed")
+	awaitOrFail(t, waitWg(&wg), "wg.Wait() timed out")
+
+	if rowsLoaded != 3 {
+		t.Errorf("expected rowsLoaded=3 (only second batch succeeded), got %d", rowsLoaded)
+	}
+	if got := atomic.LoadInt32(&failedCount); got != 1 {
+		t.Errorf("expected failedCount=1, got %d", got)
+	}
+	if len(chErr) == 0 {
+		t.Error("expected at least one error in chErr for the failed batch")
+	}
+}
+
+// TestConsume_Passthrough_ThresholdExceeded_Exits checks that Consume returns
+// on its own, with wg balanced, once the failure threshold is exceeded.
+func TestConsume_Passthrough_ThresholdExceeded_Exits(t *testing.T) {
+	db := newMockDb(mockResult{err: dbError()})
+	gl := newLoader(db)
+	chIn := make(chan models.Batch, 1)
+	chErr := make(chan custom_errors.JobError, 3)
+	var wg sync.WaitGroup
+	var rowsLoaded int64
+	var failedCount int32
+
+	// chIn is deliberately left open: Consume has to stop because of the
+	// threshold, not because its input ended.
+	sendBatch(chIn, makeBatch(1), &wg)
+
+	done := runConsume(context.Background(), gl, rc(0), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
+	awaitOrFail(t, done, "Consume did not exit after threshold exceeded")
+	awaitOrFail(t, waitWg(&wg), "wg.Wait() timed out after threshold exit")
+}
+
+// TestConsume_Accumulation_FlushOnThreshold checks that three one-row batches
+// with batchSize=3 are written in a single SaveMassive call.
+func TestConsume_Accumulation_FlushOnThreshold(t *testing.T) {
+	db := newMockDb()
+	gl := newLoader(db)
+	chIn := make(chan models.Batch, 3)
+	chErr := make(chan custom_errors.JobError, 1)
+	var wg sync.WaitGroup
+	var rowsLoaded int64
+	var failedCount int32
+
+	sendBatch(chIn, makeBatch(1), &wg)
+	sendBatch(chIn, makeBatch(1), &wg)
+	sendBatch(chIn, makeBatch(1), &wg)
+	close(chIn)
+
+	done := runConsume(context.Background(), gl, rc(0), 3, chIn, chErr, &wg, &rowsLoaded, &failedCount)
+	awaitOrFail(t, done, "Consume did not exit after input channel was closed")
+	awaitOrFail(t, waitWg(&wg), "wg.Wait() timed out")
+
+	if rowsLoaded != 3 {
+		t.Errorf("expected rowsLoaded=3, got %d", rowsLoaded)
+	}
+	if got := db.calls(); got != 1 {
+		t.Errorf("expected 1 SaveMassive call, got %d", got)
+	}
+}
+
+// TestConsume_Accumulation_FlushOnClose checks that rows below batchSize are
+// flushed in one SaveMassive call when the input channel closes.
+func TestConsume_Accumulation_FlushOnClose(t *testing.T) {
+	db := newMockDb()
+	gl := newLoader(db)
+	chIn := make(chan models.Batch, 2)
+	chErr := make(chan custom_errors.JobError, 1)
+	var wg sync.WaitGroup
+	var rowsLoaded int64
+	var failedCount int32
+
+	sendBatch(chIn, makeBatch(2), &wg)
+	sendBatch(chIn, makeBatch(3), &wg)
+	close(chIn)
+
+	done := runConsume(context.Background(), gl, rc(0), 10, chIn, chErr, &wg, &rowsLoaded, &failedCount)
+	awaitOrFail(t, done, "Consume did not exit after input channel was closed")
+	awaitOrFail(t, waitWg(&wg), "wg.Wait() timed out")
+
+	if rowsLoaded != 5 {
+		t.Errorf("expected rowsLoaded=5, got %d", rowsLoaded)
+	}
+	if got := db.calls(); got != 1 {
+		t.Errorf("expected exactly 1 SaveMassive call (single flush on close), got %d", got)
+	}
+}
+
+// TestConsume_Accumulation_RowsLoadedCorrect checks row and call counts when
+// five two-row batches are accumulated with batchSize=4.
+func TestConsume_Accumulation_RowsLoadedCorrect(t *testing.T) {
+	db := newMockDb()
+	gl := newLoader(db)
+	chIn := make(chan models.Batch, 5)
+	chErr := make(chan custom_errors.JobError, 1)
+	var wg sync.WaitGroup
+	var rowsLoaded int64
+	var failedCount int32
+
+	for range 5 {
+		sendBatch(chIn, makeBatch(2), &wg)
+	}
+	close(chIn)
+
+	done := runConsume(context.Background(), gl, rc(0), 4, chIn, chErr, &wg, &rowsLoaded, &failedCount)
+	awaitOrFail(t, done, "Consume did not exit after input channel was closed")
+	awaitOrFail(t, waitWg(&wg), "wg.Wait() timed out")
+
+	if rowsLoaded != 10 {
+		t.Errorf("expected rowsLoaded=10, got %d", rowsLoaded)
+	}
+	if got := db.calls(); got != 3 {
+		t.Errorf("expected 3 SaveMassive calls (2 threshold flushes + 1 on close), got %d", got)
+	}
+}
+
+// TestConsume_Accumulation_WgBalanced_OnContextCancel checks that batches
+// still sitting in the accumulator are released on wg when ctx is cancelled.
+func TestConsume_Accumulation_WgBalanced_OnContextCancel(t *testing.T) {
+	db := newMockDb()
+	gl := newLoader(db)
+	chIn := make(chan models.Batch)
+	chErr := make(chan custom_errors.JobError, 1)
+	var wg sync.WaitGroup
+	var rowsLoaded int64
+	var failedCount int32
+
+	ctx, cancel := context.WithCancel(context.Background())
+	done := runConsume(ctx, gl, rc(0), 10, chIn, chErr, &wg, &rowsLoaded, &failedCount)
+
+	sendBatch(chIn, makeBatch(1), &wg)
+	sendBatch(chIn, makeBatch(1), &wg)
+	cancel()
+
+	awaitOrFail(t, done, "Consume did not exit after context cancellation")
+	awaitOrFail(t, waitWg(&wg), "wg.Wait() timed out: drainPending did not release accumulated batches on cancel")
+}
+
+// TestConsume_Accumulation_ErrorInFlush_WgStillBalanced checks that every
+// accumulated batch is released on wg even when the flush fails.
+func TestConsume_Accumulation_ErrorInFlush_WgStillBalanced(t *testing.T) {
+	db := newMockDb(mockResult{err: dbError()})
+	gl := newLoader(db)
+	chIn := make(chan models.Batch, 2)
+	chErr := make(chan custom_errors.JobError, 3)
+	var wg sync.WaitGroup
+	var rowsLoaded int64
+	var failedCount int32
+
+	sendBatch(chIn, makeBatch(1), &wg)
+	sendBatch(chIn, makeBatch(1), &wg)
+	close(chIn)
+
+	done := runConsume(context.Background(), gl, rc(10), 2, chIn, chErr, &wg, &rowsLoaded, &failedCount)
+	awaitOrFail(t, done, "Consume did not exit after input channel was closed")
+	awaitOrFail(t, waitWg(&wg), "wg.Wait() timed out: wg.Done() not called after flush error")
+}
+
+// TestConsume_Accumulation_MultipleFlushes_NonFatalErrors checks that two
+// failing flushes are both counted and that no rows are reported as loaded.
+func TestConsume_Accumulation_MultipleFlushes_NonFatalErrors(t *testing.T) {
+	db := newMockDb(mockResult{err: dbError()}, mockResult{err: dbError()})
+	gl := newLoader(db)
+	chIn := make(chan models.Batch, 4)
+	chErr := make(chan custom_errors.JobError, 6)
+	var wg sync.WaitGroup
+	var rowsLoaded int64
+	var failedCount int32
+
+	for range 4 {
+		sendBatch(chIn, makeBatch(1), &wg)
+	}
+	close(chIn)
+
+	done := runConsume(context.Background(), gl, rc(10), 2, chIn, chErr, &wg, &rowsLoaded, &failedCount)
+	awaitOrFail(t, done, "Consume did not exit after input channel was closed")
+	awaitOrFail(t, waitWg(&wg), "wg.Wait() timed out")
+
+	if got := atomic.LoadInt32(&failedCount); got != 2 {
+		t.Errorf("expected failedCount=2, got %d", got)
+	}
+	if rowsLoaded != 0 {
+		t.Errorf("expected rowsLoaded=0 (all batches failed), got %d", rowsLoaded)
+	}
+}
+
+// TestConsume_EmptyInput_NoProcessing checks that a closed, empty input makes
+// Consume return without calling SaveMassive.
+func TestConsume_EmptyInput_NoProcessing(t *testing.T) {
+	db := newMockDb()
+	gl := newLoader(db)
+	chIn := make(chan models.Batch)
+	chErr := make(chan custom_errors.JobError, 1)
+	var wg sync.WaitGroup
+	var rowsLoaded int64
+	var failedCount int32
+
+	close(chIn)
+
+	done := runConsume(context.Background(), gl, rc(0), 5, chIn, chErr, &wg, &rowsLoaded, &failedCount)
+	awaitOrFail(t, done, "Consume did not exit after empty input channel was closed")
+
+	if got := db.calls(); got != 0 {
+		t.Errorf("expected no SaveMassive calls, got %d", got)
+	}
+	if rowsLoaded != 0 {
+		t.Errorf("expected rowsLoaded=0, got %d", rowsLoaded)
+	}
+	awaitOrFail(t, waitWg(&wg), "wg.Wait() timed out")
+}
+
+// TestConsume_ContextCancellation_Exits checks that cancelling ctx makes
+// Consume return even though the input channel stays open.
+func TestConsume_ContextCancellation_Exits(t *testing.T) {
+	db := newMockDb()
+	gl := newLoader(db)
+	chIn := make(chan models.Batch)
+	chErr := make(chan custom_errors.JobError, 1)
+	var wg sync.WaitGroup
+	var rowsLoaded int64
+	var failedCount int32
+
+	ctx, cancel := context.WithCancel(context.Background())
+	done := runConsume(ctx, gl, rc(0), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
+
+	cancel()
+
+	awaitOrFail(t, done, "Consume did not exit after context cancellation")
+	awaitOrFail(t, waitWg(&wg), "wg.Wait() timed out")
+}