package loaders

import (
	"context"
	"errors"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
	"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
	dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper"
	"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
	"github.com/google/uuid"
)

// testTimeout bounds every wait on a goroutine or WaitGroup so that a hang
// turns into a test failure with a message instead of a stuck test binary.
const testTimeout = 2 * time.Second

// mockResult scripts the outcome of one SaveMassive call on the mock.
type mockResult struct {
	err error
}

// mockDbWrapper is a test double for the database wrapper used by
// GenericLoader. Only SaveMassive has behavior; every other method is a no-op
// stub. callCount is guarded by mu.
type mockDbWrapper struct {
	mu        sync.Mutex
	callCount int
	results   []mockResult
}

// newMockDb returns a mock whose i-th SaveMassive call yields results[i].
// Calls beyond len(results) succeed.
func newMockDb(results ...mockResult) *mockDbWrapper {
	return &mockDbWrapper{results: results}
}

// SaveMassive records the call and returns the scripted error for this call
// index when one is set; otherwise it reports len(rows) rows saved.
func (m *mockDbWrapper) SaveMassive(_ context.Context, _ string, _ string, _ []string, rows [][]any) (int64, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	idx := m.callCount
	m.callCount++
	if idx < len(m.results) && m.results[idx].err != nil {
		return 0, m.results[idx].err
	}
	return int64(len(rows)), nil
}

// calls returns how many times SaveMassive has been invoked. It takes the
// same lock SaveMassive uses so assertions never read callCount unguarded.
func (m *mockDbWrapper) calls() int {
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.callCount
}

// Close is a no-op stub.
func (m *mockDbWrapper) Close() error { return nil }

// Connect is a no-op stub.
func (m *mockDbWrapper) Connect(_ context.Context, _ string) error { return nil }

// Exec is a no-op stub returning a zero ExecResult.
func (m *mockDbWrapper) Exec(_ context.Context, _ string, _ ...any) (dbwrapper.ExecResult, error) {
	return dbwrapper.ExecResult{}, nil
}

// GetDialect is a stub returning the empty string.
func (m *mockDbWrapper) GetDialect() string { return "" }

// Query is a no-op stub.
func (m *mockDbWrapper) Query(_ context.Context, _ string, _ ...any) (dbwrapper.RowsResult, error) {
	return nil, nil
}

// QueryRow is a no-op stub.
func (m *mockDbWrapper) QueryRow(_ context.Context, _ string, _ ...any) dbwrapper.RowResult {
	return nil
}

// QueryFromObject is a no-op stub.
func (m *mockDbWrapper) QueryFromObject(_ context.Context, _ dbwrapper.ExtractionQuery) (dbwrapper.RowsResult, error) {
	return nil, nil
}

// makeBatch builds a batch with a fresh UUID and numRows rows; each row holds
// its own index as its single value.
func makeBatch(numRows int) models.Batch {
	rows := make([]models.UnknownRowValues, numRows)
	for i := range rows {
		rows[i] = models.UnknownRowValues{i}
	}
	return models.Batch{Id: uuid.New(), Rows: rows}
}

// newLoader returns a GenericLoader backed by the given mock.
func newLoader(db *mockDbWrapper) GenericLoader {
	return GenericLoader{db: db}
}

// rc returns a RetryConfig with a single attempt and the given
// MaxFailedBatchesLoad.
func rc(maxFailed int) config.RetryConfig {
	return config.RetryConfig{Attempts: 1, MaxFailedBatchesLoad: maxFailed}
}

// sendBatch increments wg and then sends batch on chIn. The Add happens
// before the send so the counter is already raised when the receiver sees the
// batch.
func sendBatch(chIn chan<- models.Batch, batch models.Batch, wg *sync.WaitGroup) {
	wg.Add(1)
	chIn <- batch
}

// runConsume runs gl.Consume in its own goroutine with an empty
// TargetTableInfo and a nil third argument. The returned channel is closed
// once Consume returns.
func runConsume(
	ctx context.Context,
	gl GenericLoader,
	retryConfig config.RetryConfig,
	batchSize int,
	chIn <-chan models.Batch,
	chErr chan<- custom_errors.JobError,
	wg *sync.WaitGroup,
	rowsLoaded *int64,
	failedCount *int32,
) <-chan struct{} {
	done := make(chan struct{})
	go func() {
		gl.Consume(ctx, config.TargetTableInfo{}, nil, retryConfig, batchSize, chIn, chErr, wg, rowsLoaded, failedCount)
		close(done)
	}()
	return done
}

// waitWg returns a channel that is closed once wg.Wait() returns, so a
// WaitGroup can be awaited inside a select with a timeout.
func waitWg(wg *sync.WaitGroup) <-chan struct{} {
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()
	return done
}

// awaitSignal blocks until done is closed (or delivers a value) and aborts the
// test with msg if that does not happen within testTimeout.
func awaitSignal(t *testing.T, done <-chan struct{}, msg string) {
	t.Helper()
	select {
	case <-done:
	case <-time.After(testTimeout):
		t.Fatal(msg)
	}
}

// dbError returns a generic error used to script SaveMassive failures.
func dbError() error {
	return errors.New("connection reset by peer")
}

// TestLoaderAccumulator_Add checks that add appends rows, records the parent
// batches in order, and counts one pending Done per added batch.
func TestLoaderAccumulator_Add(t *testing.T) {
	acc := &loaderAccumulator{batchSize: 5}
	b1 := makeBatch(2)
	b2 := makeBatch(3)

	acc.add(b1)
	acc.add(b2)

	if len(acc.rows) != 5 {
		t.Errorf("expected 5 rows, got %d", len(acc.rows))
	}
	// Fatalf: the index accesses below require exactly two parents.
	if len(acc.parents) != 2 {
		t.Fatalf("expected 2 parents, got %d", len(acc.parents))
	}
	if acc.parents[0].Id != b1.Id || acc.parents[1].Id != b2.Id {
		t.Error("parent IDs do not match source batch IDs in order")
	}
	if acc.pendingDone != 2 {
		t.Errorf("expected pendingDone=2, got %d", acc.pendingDone)
	}
}

// TestLoaderAccumulator_Ready checks that ready flips to true once the
// accumulated row count reaches batchSize.
func TestLoaderAccumulator_Ready(t *testing.T) {
	acc := &loaderAccumulator{batchSize: 3}

	acc.add(makeBatch(2))
	if acc.ready() {
		t.Error("should not be ready with 2 rows and batchSize=3")
	}

	acc.add(makeBatch(1))
	if !acc.ready() {
		t.Error("should be ready with 3 rows and batchSize=3")
	}
}

// TestLoaderAccumulator_DrainPending_ReleasesWg checks that drainPending
// releases the WaitGroup once per pending batch.
func TestLoaderAccumulator_DrainPending_ReleasesWg(t *testing.T) {
	acc := &loaderAccumulator{batchSize: 5, pendingDone: 3}
	var wg sync.WaitGroup
	wg.Add(3)

	acc.drainPending(&wg)

	awaitSignal(t, waitWg(&wg), "wg.Wait() timed out: drainPending did not call Done() enough times")
}

// TestLoaderAccumulator_DrainPending_ZeroPending checks that drainPending
// with nothing pending leaves an untouched WaitGroup waitable.
func TestLoaderAccumulator_DrainPending_ZeroPending(t *testing.T) {
	acc := &loaderAccumulator{batchSize: 5, pendingDone: 0}
	var wg sync.WaitGroup

	acc.drainPending(&wg)

	awaitSignal(t, waitWg(&wg), "wg.Wait() timed out")
}

// TestSendLoadError_PlainError_WrappedAsNonFatal checks that a plain error is
// counted as one failure and forwarded with ShouldCancelJob=false.
func TestSendLoadError_PlainError_WrappedAsNonFatal(t *testing.T) {
	ch := make(chan custom_errors.JobError, 2)
	var failedCount int32

	result := sendLoadError(context.Background(), errors.New("db error"), rc(10), &failedCount, ch)

	if !result {
		t.Error("expected true (below threshold)")
	}
	if got := atomic.LoadInt32(&failedCount); got != 1 {
		t.Errorf("expected failedCount=1, got %d", got)
	}
	select {
	case e := <-ch:
		if e.ShouldCancelJob {
			t.Error("plain error should be wrapped as ShouldCancelJob=false")
		}
	default:
		t.Error("expected an error in the channel")
	}
}

// TestSendLoadError_JobError_PassesThrough checks that an existing JobError
// reaches the channel with its Msg and ShouldCancelJob intact.
func TestSendLoadError_JobError_PassesThrough(t *testing.T) {
	ch := make(chan custom_errors.JobError, 2)
	var failedCount int32
	original := &custom_errors.JobError{ShouldCancelJob: false, Msg: "custom msg"}

	sendLoadError(context.Background(), original, rc(10), &failedCount, ch)

	select {
	case e := <-ch:
		if e.Msg != "custom msg" || e.ShouldCancelJob {
			t.Errorf("JobError should pass through unchanged, got %+v", e)
		}
	default:
		t.Error("expected an error in the channel")
	}
}

// TestSendLoadError_FatalJobError_BelowThreshold_ReturnsTrue checks that a
// fatal JobError below the failure threshold is forwarded unchanged while
// sendLoadError still returns true.
func TestSendLoadError_FatalJobError_BelowThreshold_ReturnsTrue(t *testing.T) {
	ch := make(chan custom_errors.JobError, 2)
	var failedCount int32
	fatal := &custom_errors.JobError{ShouldCancelJob: true, Msg: "unique constraint"}

	result := sendLoadError(context.Background(), fatal, rc(10), &failedCount, ch)

	if !result {
		t.Error("below-threshold fatal error should return true (external cancel expected from JobErrorHandler)")
	}
	select {
	case e := <-ch:
		if !e.ShouldCancelJob {
			t.Error("fatal JobError should be forwarded with ShouldCancelJob=true")
		}
	default:
		t.Error("expected the fatal error in the channel")
	}
}

// TestSendLoadError_ThresholdExceeded_ReturnsFalse checks that with a
// threshold of 0 the first failure returns false and queues two errors, the
// second of which has ShouldCancelJob=true.
func TestSendLoadError_ThresholdExceeded_ReturnsFalse(t *testing.T) {
	ch := make(chan custom_errors.JobError, 2)
	var failedCount int32

	result := sendLoadError(context.Background(), errors.New("db error"), rc(0), &failedCount, ch)

	if result {
		t.Error("expected false when threshold exceeded")
	}
	// Fatalf: the two receives below would block if fewer errors were queued.
	if len(ch) != 2 {
		t.Fatalf("expected 2 errors (batch error + fatal threshold error), got %d", len(ch))
	}
	<-ch // batch error
	threshold := <-ch
	if !threshold.ShouldCancelJob {
		t.Error("second error should be the fatal threshold error (ShouldCancelJob=true)")
	}
}

// TestSendLoadError_AtThresholdBoundary checks the boundary with a threshold
// of 2: the first and second failures return true, the third returns false.
func TestSendLoadError_AtThresholdBoundary(t *testing.T) {
	ch := make(chan custom_errors.JobError, 6)
	var failedCount int32

	if !sendLoadError(context.Background(), errors.New("err"), rc(2), &failedCount, ch) {
		t.Error("first failure: expected true (below threshold)")
	}
	if !sendLoadError(context.Background(), errors.New("err"), rc(2), &failedCount, ch) {
		t.Error("second failure: expected true (at threshold, not exceeded)")
	}
	if sendLoadError(context.Background(), errors.New("err"), rc(2), &failedCount, ch) {
		t.Error("third failure: expected false (threshold exceeded)")
	}
}

// TestSendLoadError_ContextCancelled_ReturnsFalse checks that sendLoadError
// returns false when the context is already cancelled.
func TestSendLoadError_ContextCancelled_ReturnsFalse(t *testing.T) {
	// ch is unbuffered and has no receiver, so a send on it can never
	// complete; the call can only return by observing the cancelled context.
	ch := make(chan custom_errors.JobError)
	var failedCount int32
	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	result := sendLoadError(ctx, errors.New("db error"), rc(10), &failedCount, ch)

	if result {
		t.Error("expected false when context is cancelled")
	}
	// NOTE(review): len of an unbuffered channel is always 0, so this check
	// cannot fail; it is kept as a statement of intent.
	if len(ch) != 0 {
		t.Error("no error should be sent when context is cancelled")
	}
}

// TestConsume_Passthrough_RowsLoaded runs Consume with batchSize 0 on a single
// 5-row batch and expects 5 rows loaded through exactly one SaveMassive call.
func TestConsume_Passthrough_RowsLoaded(t *testing.T) {
	db := newMockDb()
	gl := newLoader(db)
	chIn := make(chan models.Batch, 1)
	chErr := make(chan custom_errors.JobError, 1)
	var wg sync.WaitGroup
	var rowsLoaded int64
	var failedCount int32

	sendBatch(chIn, makeBatch(5), &wg)
	close(chIn)

	<-runConsume(context.Background(), gl, rc(0), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
	wg.Wait()

	if got := atomic.LoadInt64(&rowsLoaded); got != 5 {
		t.Errorf("expected rowsLoaded=5, got %d", got)
	}
	if got := db.calls(); got != 1 {
		t.Errorf("expected 1 SaveMassive call, got %d", got)
	}
}

// TestConsume_Passthrough_MultipleBatches_RowsAccumulate runs Consume with
// batchSize 0 on batches of 3, 2 and 4 rows and expects rowsLoaded to total 9.
func TestConsume_Passthrough_MultipleBatches_RowsAccumulate(t *testing.T) {
	db := newMockDb()
	gl := newLoader(db)
	chIn := make(chan models.Batch, 3)
	chErr := make(chan custom_errors.JobError, 1)
	var wg sync.WaitGroup
	var rowsLoaded int64
	var failedCount int32

	sendBatch(chIn, makeBatch(3), &wg)
	sendBatch(chIn, makeBatch(2), &wg)
	sendBatch(chIn, makeBatch(4), &wg)
	close(chIn)

	<-runConsume(context.Background(), gl, rc(10), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
	wg.Wait()

	if got := atomic.LoadInt64(&rowsLoaded); got != 9 {
		t.Errorf("expected rowsLoaded=9, got %d", got)
	}
}

// TestConsume_Passthrough_WgDoneBeforeErrorHandling checks that the WaitGroup
// is released for a batch whose SaveMassive call fails (batchSize 0).
func TestConsume_Passthrough_WgDoneBeforeErrorHandling(t *testing.T) {
	db := newMockDb(mockResult{err: dbError()})
	gl := newLoader(db)
	chIn := make(chan models.Batch, 1)
	chErr := make(chan custom_errors.JobError, 2)
	var wg sync.WaitGroup
	var rowsLoaded int64
	var failedCount int32

	sendBatch(chIn, makeBatch(2), &wg)
	close(chIn)

	<-runConsume(context.Background(), gl, rc(10), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)

	awaitSignal(t, waitWg(&wg), "wg.Wait() timed out: Done() was not called even though processing failed")
}

// TestConsume_Passthrough_NonFatalError_Continues fails the first of two
// batches and expects the second to load, one failure to be counted, and an
// error to be reported on chErr.
func TestConsume_Passthrough_NonFatalError_Continues(t *testing.T) {
	db := newMockDb(mockResult{err: dbError()})
	gl := newLoader(db)
	chIn := make(chan models.Batch, 2)
	chErr := make(chan custom_errors.JobError, 3)
	var wg sync.WaitGroup
	var rowsLoaded int64
	var failedCount int32

	sendBatch(chIn, makeBatch(2), &wg)
	sendBatch(chIn, makeBatch(3), &wg)
	close(chIn)

	<-runConsume(context.Background(), gl, rc(10), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
	wg.Wait()

	if got := atomic.LoadInt64(&rowsLoaded); got != 3 {
		t.Errorf("expected rowsLoaded=3 (only second batch succeeded), got %d", got)
	}
	if got := atomic.LoadInt32(&failedCount); got != 1 {
		t.Errorf("expected failedCount=1, got %d", got)
	}
	if len(chErr) == 0 {
		t.Error("expected at least one error in chErr for the failed batch")
	}
}

// TestConsume_Passthrough_ThresholdExceeded_Exits checks that, with a failure
// threshold of 0, a single failing batch makes Consume return even though
// chIn is never closed, and that the WaitGroup is still released.
func TestConsume_Passthrough_ThresholdExceeded_Exits(t *testing.T) {
	db := newMockDb(mockResult{err: dbError()})
	gl := newLoader(db)
	chIn := make(chan models.Batch, 1)
	chErr := make(chan custom_errors.JobError, 3)
	var wg sync.WaitGroup
	var rowsLoaded int64
	var failedCount int32

	sendBatch(chIn, makeBatch(1), &wg)

	done := runConsume(context.Background(), gl, rc(0), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)

	awaitSignal(t, done, "Consume did not exit after threshold exceeded")
	awaitSignal(t, waitWg(&wg), "wg.Wait() timed out after threshold exit")
}

// TestConsume_Accumulation_FlushOnThreshold feeds three 1-row batches with
// batchSize 3 and expects all 3 rows to be saved in one SaveMassive call.
func TestConsume_Accumulation_FlushOnThreshold(t *testing.T) {
	db := newMockDb()
	gl := newLoader(db)
	chIn := make(chan models.Batch, 3)
	chErr := make(chan custom_errors.JobError, 1)
	var wg sync.WaitGroup
	var rowsLoaded int64
	var failedCount int32

	sendBatch(chIn, makeBatch(1), &wg)
	sendBatch(chIn, makeBatch(1), &wg)
	sendBatch(chIn, makeBatch(1), &wg)
	close(chIn)

	<-runConsume(context.Background(), gl, rc(0), 3, chIn, chErr, &wg, &rowsLoaded, &failedCount)
	wg.Wait()

	if got := atomic.LoadInt64(&rowsLoaded); got != 3 {
		t.Errorf("expected rowsLoaded=3, got %d", got)
	}
	if got := db.calls(); got != 1 {
		t.Errorf("expected 1 SaveMassive call, got %d", got)
	}
}

// TestConsume_Accumulation_FlushOnClose feeds 5 rows with batchSize 10 and
// expects them to be saved in one SaveMassive call once chIn is closed.
func TestConsume_Accumulation_FlushOnClose(t *testing.T) {
	db := newMockDb()
	gl := newLoader(db)
	chIn := make(chan models.Batch, 2)
	chErr := make(chan custom_errors.JobError, 1)
	var wg sync.WaitGroup
	var rowsLoaded int64
	var failedCount int32

	sendBatch(chIn, makeBatch(2), &wg)
	sendBatch(chIn, makeBatch(3), &wg)
	close(chIn)

	<-runConsume(context.Background(), gl, rc(0), 10, chIn, chErr, &wg, &rowsLoaded, &failedCount)
	wg.Wait()

	if got := atomic.LoadInt64(&rowsLoaded); got != 5 {
		t.Errorf("expected rowsLoaded=5, got %d", got)
	}
	if got := db.calls(); got != 1 {
		t.Errorf("expected exactly 1 SaveMassive call (single flush on close), got %d", got)
	}
}

// TestConsume_Accumulation_RowsLoadedCorrect feeds five 2-row batches with
// batchSize 4 and expects 10 rows loaded across three SaveMassive calls.
func TestConsume_Accumulation_RowsLoadedCorrect(t *testing.T) {
	db := newMockDb()
	gl := newLoader(db)
	chIn := make(chan models.Batch, 5)
	chErr := make(chan custom_errors.JobError, 1)
	var wg sync.WaitGroup
	var rowsLoaded int64
	var failedCount int32

	for range 5 {
		sendBatch(chIn, makeBatch(2), &wg)
	}
	close(chIn)

	<-runConsume(context.Background(), gl, rc(0), 4, chIn, chErr, &wg, &rowsLoaded, &failedCount)
	wg.Wait()

	if got := atomic.LoadInt64(&rowsLoaded); got != 10 {
		t.Errorf("expected rowsLoaded=10, got %d", got)
	}
	if got := db.calls(); got != 3 {
		t.Errorf("expected 3 SaveMassive calls (2 threshold flushes + 1 on close), got %d", got)
	}
}

// TestConsume_Accumulation_WgBalanced_OnContextCancel hands Consume two
// batches that stay below batchSize, cancels the context, and expects Consume
// to return with the WaitGroup fully released.
func TestConsume_Accumulation_WgBalanced_OnContextCancel(t *testing.T) {
	db := newMockDb()
	gl := newLoader(db)
	// Unbuffered: each sendBatch returns only after Consume has received the
	// batch, so both batches are in Consume's hands before cancel().
	chIn := make(chan models.Batch)
	chErr := make(chan custom_errors.JobError, 1)
	var wg sync.WaitGroup
	var rowsLoaded int64
	var failedCount int32
	ctx, cancel := context.WithCancel(context.Background())

	done := runConsume(ctx, gl, rc(0), 10, chIn, chErr, &wg, &rowsLoaded, &failedCount)

	sendBatch(chIn, makeBatch(1), &wg)
	sendBatch(chIn, makeBatch(1), &wg)
	cancel()

	awaitSignal(t, done, "Consume did not exit after context cancellation")
	awaitSignal(t, waitWg(&wg), "wg.Wait() timed out: drainPending did not release accumulated batches on cancel")
}

// TestConsume_Accumulation_ErrorInFlush_WgStillBalanced checks that the
// WaitGroup is released for both accumulated batches when their SaveMassive
// call fails.
func TestConsume_Accumulation_ErrorInFlush_WgStillBalanced(t *testing.T) {
	db := newMockDb(mockResult{err: dbError()})
	gl := newLoader(db)
	chIn := make(chan models.Batch, 2)
	chErr := make(chan custom_errors.JobError, 3)
	var wg sync.WaitGroup
	var rowsLoaded int64
	var failedCount int32

	sendBatch(chIn, makeBatch(1), &wg)
	sendBatch(chIn, makeBatch(1), &wg)
	close(chIn)

	<-runConsume(context.Background(), gl, rc(10), 2, chIn, chErr, &wg, &rowsLoaded, &failedCount)

	awaitSignal(t, waitWg(&wg), "wg.Wait() timed out: wg.Done() not called after flush error")
}

// TestConsume_Accumulation_MultipleFlushes_NonFatalErrors feeds four 1-row
// batches with batchSize 2 while the first two SaveMassive calls fail, and
// expects two counted failures, no rows loaded, and a released WaitGroup.
func TestConsume_Accumulation_MultipleFlushes_NonFatalErrors(t *testing.T) {
	db := newMockDb(mockResult{err: dbError()}, mockResult{err: dbError()})
	gl := newLoader(db)
	chIn := make(chan models.Batch, 4)
	chErr := make(chan custom_errors.JobError, 6)
	var wg sync.WaitGroup
	var rowsLoaded int64
	var failedCount int32

	for range 4 {
		sendBatch(chIn, makeBatch(1), &wg)
	}
	close(chIn)

	<-runConsume(context.Background(), gl, rc(10), 2, chIn, chErr, &wg, &rowsLoaded, &failedCount)

	awaitSignal(t, waitWg(&wg), "wg.Wait() timed out")

	if got := atomic.LoadInt32(&failedCount); got != 2 {
		t.Errorf("expected failedCount=2, got %d", got)
	}
	if got := atomic.LoadInt64(&rowsLoaded); got != 0 {
		t.Errorf("expected rowsLoaded=0 (all batches failed), got %d", got)
	}
}

// TestConsume_EmptyInput_NoProcessing closes chIn before sending anything and
// expects Consume to return without calling SaveMassive or loading rows.
func TestConsume_EmptyInput_NoProcessing(t *testing.T) {
	db := newMockDb()
	gl := newLoader(db)
	chIn := make(chan models.Batch)
	chErr := make(chan custom_errors.JobError, 1)
	var wg sync.WaitGroup
	var rowsLoaded int64
	var failedCount int32

	close(chIn)

	done := runConsume(context.Background(), gl, rc(0), 5, chIn, chErr, &wg, &rowsLoaded, &failedCount)

	awaitSignal(t, done, "Consume did not exit after empty input channel was closed")

	if got := db.calls(); got != 0 {
		t.Errorf("expected no SaveMassive calls, got %d", got)
	}
	if got := atomic.LoadInt64(&rowsLoaded); got != 0 {
		t.Errorf("expected rowsLoaded=0, got %d", got)
	}
	wg.Wait()
}

// TestConsume_ContextCancellation_Exits checks that Consume returns when its
// context is cancelled while chIn is open and empty.
func TestConsume_ContextCancellation_Exits(t *testing.T) {
	db := newMockDb()
	gl := newLoader(db)
	chIn := make(chan models.Batch)
	chErr := make(chan custom_errors.JobError, 1)
	var wg sync.WaitGroup
	var rowsLoaded int64
	var failedCount int32
	ctx, cancel := context.WithCancel(context.Background())

	done := runConsume(ctx, gl, rc(0), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
	cancel()

	awaitSignal(t, done, "Consume did not exit after context cancellation")
	wg.Wait()
}