diff --git a/internal/app/etl/transformers/consume_test.go b/internal/app/etl/transformers/consume_test.go new file mode 100644 index 0000000..15b565d --- /dev/null +++ b/internal/app/etl/transformers/consume_test.go @@ -0,0 +1,545 @@ +package transformers + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config" + "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors" + "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models" + "github.com/google/uuid" +) + +const testTimeout = 2 * time.Second + +func makeBatch(numRows int) models.Batch { + rows := make([]models.UnknownRowValues, numRows) + for i := range rows { + rows[i] = models.UnknownRowValues{i} + } + return models.Batch{Id: uuid.New(), Rows: rows} +} + +func noRetry() config.RetryConfig { + return config.RetryConfig{Attempts: 1} +} + +func newTransformer() *MssqlTransformer { + return &MssqlTransformer{} +} + +func uuidColumn() models.ColumnType { + return models.NewColumnType("col_uuid", false, false, "uniqueidentifier", "uniqueidentifier", "string", false, 0, 0, 0) +} + +func runConsume( + ctx context.Context, + tr *MssqlTransformer, + columns []models.ColumnType, + batchSize int, + chIn <-chan models.Batch, + chOut chan<- models.Batch, + chErr chan<- custom_errors.JobError, + wg *sync.WaitGroup, +) <-chan struct{} { + done := make(chan struct{}) + go func() { + tr.Consume(ctx, columns, noRetry(), batchSize, chIn, chOut, chErr, wg) + close(done) + }() + return done +} + +func drainOut(chOut <-chan models.Batch, wg *sync.WaitGroup) []models.Batch { + var batches []models.Batch + for { + select { + case b := <-chOut: + batches = append(batches, b) + wg.Done() + default: + return batches + } + } +} +func TestBatchAccumulator_Add(t *testing.T) { + acc := &batchAccumulator{batchSize: 5} + b1 := makeBatch(2) + b2 := makeBatch(3) + + acc.add(b1) + acc.add(b2) + + if len(acc.rows) != 5 { + t.Errorf("expected 5 rows, got %d", len(acc.rows)) + } + if len(acc.parents) != 2 { + t.Fatalf("expected 2 parents, got %d", len(acc.parents)) + } + if acc.parents[0].Id != b1.Id || acc.parents[1].Id != b2.Id { + t.Error("parent IDs do not match source batch IDs") + } +} + +func TestBatchAccumulator_Ready(t *testing.T) { + acc := &batchAccumulator{batchSize: 3} + acc.add(makeBatch(2)) + if acc.ready() { + t.Error("should not be ready with 2 rows and batchSize=3") + } + acc.add(makeBatch(1)) + if !acc.ready() { + t.Error("should be ready with 3 rows and batchSize=3") + } +} + +func TestBatchAccumulator_Flush_Empty(t *testing.T) { + acc := &batchAccumulator{batchSize: 5} + chOut := make(chan models.Batch, 1) + var wg sync.WaitGroup + + if !acc.flush(context.Background(), chOut, &wg) { + t.Error("flush on empty accumulator should return true") + } + if len(chOut) != 0 { + t.Error("flush on empty accumulator should send nothing") + } +} + +func TestBatchAccumulator_Flush_Success(t *testing.T) { + acc := &batchAccumulator{batchSize: 2} + b := makeBatch(2) + acc.add(b) + + chOut := make(chan models.Batch, 1) + var wg sync.WaitGroup + + if !acc.flush(context.Background(), chOut, &wg) { + t.Fatal("flush should return true on success") + } + + select { + case out := <-chOut: + wg.Done() + if len(out.Rows) != 2 { + t.Errorf("expected 2 rows in flushed batch, got %d", len(out.Rows)) + } + if len(out.ParentBatches) != 1 || out.ParentBatches[0].Id != b.Id { + t.Error("flushed batch should reference the source batch as parent") + } + default: + t.Error("expected a batch in chOut after flush") + } + + if len(acc.rows) != 0 || len(acc.parents) != 0 { + t.Error("accumulator state should be reset after flush") + } + wg.Wait() +} + +func TestBatchAccumulator_Flush_ContextCancelled(t *testing.T) { + acc := &batchAccumulator{batchSize: 2} + acc.add(makeBatch(2)) + + chOut := make(chan models.Batch) + var wg sync.WaitGroup + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + if acc.flush(ctx, chOut, &wg) { + t.Error("flush should return false when context is cancelled") + } + + wg.Wait() +} + +func TestSendTransformError_PlainError(t *testing.T) { + ch := make(chan custom_errors.JobError, 1) + + sendTransformError(context.Background(), errors.New("something broke"), ch) + + select { + case e := <-ch: + if !e.ShouldCancelJob { + t.Error("plain error should produce ShouldCancelJob=true") + } + default: + t.Error("expected a job error in the channel") + } +} + +func TestSendTransformError_JobError_Passthrough(t *testing.T) { + ch := make(chan custom_errors.JobError, 1) + original := &custom_errors.JobError{ShouldCancelJob: false, Msg: "custom msg"} + + sendTransformError(context.Background(), original, ch) + + select { + case e := <-ch: + if e.ShouldCancelJob != false || e.Msg != "custom msg" { + t.Errorf("JobError should pass through unchanged, got %+v", e) + } + default: + t.Error("expected a job error in the channel") + } +} + +func TestSendTransformError_ContextCancelled_Silent(t *testing.T) { + ch := make(chan custom_errors.JobError, 1) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + sendTransformError(ctx, context.Canceled, ch) + + if len(ch) != 0 { + t.Error("context.Canceled should be silently dropped") + } +} + +func TestSendTransformError_DeadlineExceeded_Silent(t *testing.T) { + ch := make(chan custom_errors.JobError, 1) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + sendTransformError(ctx, context.DeadlineExceeded, ch) + + if len(ch) != 0 { + t.Error("context.DeadlineExceeded should be silently dropped") + } +} + +func TestConsume_Passthrough_PreservesOriginalBatch(t *testing.T) { + tr := newTransformer() + chIn := make(chan models.Batch, 1) + chOut := make(chan models.Batch, 1) + chErr := make(chan custom_errors.JobError, 1) + var wg sync.WaitGroup + + batch := makeBatch(3) + chIn <- batch + close(chIn) + + done := runConsume(context.Background(), tr, nil, 0, chIn, chOut, chErr, &wg) + + select { + case got := <-chOut: + wg.Done() + if got.Id != batch.Id { + t.Error("passthrough should preserve the original batch ID") + } + if len(got.Rows) != 3 { + t.Errorf("expected 3 rows, got %d", len(got.Rows)) + } + case <-time.After(testTimeout): + t.Fatal("timeout waiting for output batch") + } + + <-done + wg.Wait() +} + +func TestConsume_Passthrough_WaitGroupBalanced(t *testing.T) { + tr := newTransformer() + chIn := make(chan models.Batch, 3) + chOut := make(chan models.Batch, 3) + chErr := make(chan custom_errors.JobError, 1) + var wg sync.WaitGroup + + for range 3 { + chIn <- makeBatch(1) + } + close(chIn) + + done := runConsume(context.Background(), tr, nil, 0, chIn, chOut, chErr, &wg) + <-done + + batches := drainOut(chOut, &wg) + if len(batches) != 3 { + t.Errorf("expected 3 output batches, got %d", len(batches)) + } + + wg.Wait() +} + +func TestConsume_Accumulation_FlushOnThreshold(t *testing.T) { + tr := newTransformer() + chIn := make(chan models.Batch, 3) + chOut := make(chan models.Batch, 2) + chErr := make(chan custom_errors.JobError, 1) + var wg sync.WaitGroup + + for range 3 { + chIn <- makeBatch(1) + } + close(chIn) + + done := runConsume(context.Background(), tr, nil, 3, chIn, chOut, chErr, &wg) + <-done + + batches := drainOut(chOut, &wg) + if len(batches) != 1 { + t.Fatalf("expected 1 accumulated batch, got %d", len(batches)) + } + if len(batches[0].Rows) != 3 { + t.Errorf("expected 3 rows in accumulated batch, got %d", len(batches[0].Rows)) + } + wg.Wait() +} + +func TestConsume_Accumulation_FlushOnClose(t *testing.T) { + tr := newTransformer() + chIn := make(chan models.Batch, 2) + chOut := make(chan models.Batch, 2) + chErr := make(chan custom_errors.JobError, 1) + var wg sync.WaitGroup + + chIn <- makeBatch(1) + chIn <- makeBatch(1) + close(chIn) + + done := runConsume(context.Background(), tr, nil, 10, chIn, chOut, chErr, &wg) + <-done + + batches := drainOut(chOut, &wg) + if len(batches) != 1 { + t.Fatalf("expected 1 batch flushed on close, got %d", len(batches)) + } + if len(batches[0].Rows) != 2 { + t.Errorf("expected 2 rows, got %d", len(batches[0].Rows)) + } + wg.Wait() +} + +func TestConsume_Accumulation_TracksAllParentBatches(t *testing.T) { + tr := newTransformer() + chIn := make(chan models.Batch, 2) + chOut := make(chan models.Batch, 2) + chErr := make(chan custom_errors.JobError, 1) + var wg sync.WaitGroup + + b1 := makeBatch(1) + b2 := makeBatch(1) + chIn <- b1 + chIn <- b2 + close(chIn) + + done := runConsume(context.Background(), tr, nil, 10, chIn, chOut, chErr, &wg) + <-done + + batches := drainOut(chOut, &wg) + if len(batches) != 1 { + t.Fatalf("expected 1 output batch, got %d", len(batches)) + } + parents := batches[0].ParentBatches + if len(parents) != 2 { + t.Fatalf("expected 2 parent refs, got %d", len(parents)) + } + if parents[0].Id != b1.Id || parents[1].Id != b2.Id { + t.Error("parent IDs should match source batch IDs in order") + } + wg.Wait() +} + +func TestConsume_Accumulation_MultipleFlushes(t *testing.T) { + tr := newTransformer() + chIn := make(chan models.Batch, 5) + chOut := make(chan models.Batch, 5) + chErr := make(chan custom_errors.JobError, 1) + var wg sync.WaitGroup + + for range 5 { + chIn <- makeBatch(1) + } + close(chIn) + + done := runConsume(context.Background(), tr, nil, 2, chIn, chOut, chErr, &wg) + <-done + + batches := drainOut(chOut, &wg) + if len(batches) != 3 { + t.Fatalf("expected 3 output batches (2+2+1 rows), got %d", len(batches)) + } + totalRows := 0 + for _, b := range batches { + totalRows += len(b.Rows) + } + if totalRows != 5 { + t.Errorf("expected 5 total rows across all batches, got %d", totalRows) + } + wg.Wait() +} + +func TestConsume_EmptyInput_NoOutput(t *testing.T) { + tr := newTransformer() + chIn := make(chan models.Batch) + chOut := make(chan models.Batch, 1) + chErr := make(chan custom_errors.JobError, 1) + var wg sync.WaitGroup + + close(chIn) + + done := runConsume(context.Background(), tr, nil, 5, chIn, chOut, chErr, &wg) + + select { + case <-done: + case <-time.After(testTimeout): + t.Fatal("timeout: Consume did not exit after empty input channel was closed") + } + + if len(chOut) != 0 { + t.Error("expected no output for empty input") + } + wg.Wait() +} + +func TestConsume_TransformError_SendsJobError(t *testing.T) { + tr := newTransformer() + col := uuidColumn() + + chIn := make(chan models.Batch, 1) + chOut := make(chan models.Batch, 1) + chErr := make(chan custom_errors.JobError, 1) + var wg sync.WaitGroup + + batch := models.Batch{ + Id: uuid.New(), + Rows: []models.UnknownRowValues{{[]byte{1, 2, 3}}}, + } + chIn <- batch + + done := runConsume(context.Background(), tr, []models.ColumnType{col}, 0, chIn, chOut, chErr, &wg) + + select { + case err := <-chErr: + if !err.ShouldCancelJob { + t.Error("transform error should set ShouldCancelJob=true") + } + case <-time.After(testTimeout): + t.Fatal("timeout: expected a job error from transform failure") + } + + <-done + wg.Wait() +} + +func TestConsume_TransformError_NoOutputForwarded(t *testing.T) { + tr := newTransformer() + col := uuidColumn() + + chIn := make(chan models.Batch, 1) + chOut := make(chan models.Batch, 1) + chErr := make(chan custom_errors.JobError, 1) + var wg sync.WaitGroup + + batch := models.Batch{ + Id: uuid.New(), + Rows: []models.UnknownRowValues{{[]byte{1, 2, 3}}}, + } + chIn <- batch + + done := runConsume(context.Background(), tr, []models.ColumnType{col}, 0, chIn, chOut, chErr, &wg) + <-done + + if len(chOut) != 0 { + t.Error("no batch should be forwarded when transformation fails") + } + wg.Wait() +} + +func TestConsume_ContextCancellation_Exits(t *testing.T) { + tr := newTransformer() + chIn := make(chan models.Batch) + chOut := make(chan models.Batch, 1) + chErr := make(chan custom_errors.JobError, 1) + var wg sync.WaitGroup + + ctx, cancel := context.WithCancel(context.Background()) + done := runConsume(ctx, tr, nil, 0, chIn, chOut, chErr, &wg) + + cancel() + + select { + case <-done: + case <-time.After(testTimeout): + t.Fatal("timeout: Consume did not exit after context cancellation") + } + wg.Wait() +} + +func TestConsume_Transform_DatetimeConvertedToUTC(t *testing.T) { + tr := newTransformer() + col := models.NewColumnType("col_dt", false, false, "datetime", "datetime", "timestamp", false, 0, 0, 0) + + chIn := make(chan models.Batch, 1) + chOut := make(chan models.Batch, 1) + chErr := make(chan custom_errors.JobError, 1) + var wg sync.WaitGroup + + nonUTC := time.Date(2024, 1, 15, 12, 0, 0, 0, time.FixedZone("EST", -5*3600)) + batch := models.Batch{ + Id: uuid.New(), + Rows: []models.UnknownRowValues{{nonUTC}}, + } + chIn <- batch + close(chIn) + + done := runConsume(context.Background(), tr, []models.ColumnType{col}, 0, chIn, chOut, chErr, &wg) + <-done + + select { + case got := <-chOut: + wg.Done() + result, ok := got.Rows[0][0].(time.Time) + if !ok { + t.Fatal("expected time.Time in output row") + } + if result.Location() != time.UTC { + t.Errorf("expected UTC location after transform, got %v", result.Location()) + } + default: + t.Error("expected an output batch") + } + + wg.Wait() +} + +func TestConsume_Transform_NilValueSkipped(t *testing.T) { + tr := newTransformer() + col := uuidColumn() + + chIn := make(chan models.Batch, 1) + chOut := make(chan models.Batch, 1) + chErr := make(chan custom_errors.JobError, 1) + var wg sync.WaitGroup + + batch := models.Batch{ + Id: uuid.New(), + Rows: []models.UnknownRowValues{{nil}}, + } + chIn <- batch + close(chIn) + + done := runConsume(context.Background(), tr, []models.ColumnType{col}, 0, chIn, chOut, chErr, &wg) + <-done + + select { + case got := <-chOut: + wg.Done() + if got.Rows[0][0] != nil { + t.Error("nil value should pass through unchanged") + } + default: + t.Error("expected an output batch even when value is nil") + } + + if len(chErr) != 0 { + t.Error("nil value should not produce an error") + } + wg.Wait() +}