package transformers

import (
	"context"
	"errors"
	"sync"
	"testing"
	"time"

	"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
	"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
	"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
	"github.com/google/uuid"
)

// testTimeout bounds every wait on a Consume goroutine or an output channel so
// that a misbehaving implementation fails the test instead of hanging it.
const testTimeout = 2 * time.Second

// makeBatch returns a batch with a fresh random ID and numRows single-value
// rows, where row i holds the int i.
func makeBatch(numRows int) models.Batch {
	rows := make([]models.UnknownRowValues, numRows)
	for i := range rows {
		rows[i] = models.UnknownRowValues{i}
	}
	return models.Batch{Id: uuid.New(), Rows: rows}
}

// noRetry returns a retry configuration with Attempts set to 1.
func noRetry() config.RetryConfig {
	return config.RetryConfig{Attempts: 1}
}

// newTransformer returns a zero-value MssqlTransformer.
func newTransformer() *MssqlTransformer {
	return &MssqlTransformer{}
}

// uuidColumn builds a "col_uuid" column descriptor of SQL type
// "uniqueidentifier". The tests below feed it a 3-byte value to provoke a
// transform error, and a nil value to check that nils pass through.
func uuidColumn() models.ColumnType {
	return models.NewColumnType("col_uuid", false, false, "uniqueidentifier", "uniqueidentifier", "string", false, 0, 0, 0)
}

// runConsume starts tr.Consume (with noRetry()) in its own goroutine and
// returns a channel that is closed once Consume returns.
func runConsume(
	ctx context.Context,
	tr *MssqlTransformer,
	columns []models.ColumnType,
	batchSize int,
	chIn <-chan models.Batch,
	chOut chan<- models.Batch,
	chErr chan<- custom_errors.JobError,
	wg *sync.WaitGroup,
) <-chan struct{} {
	done := make(chan struct{})
	go func() {
		tr.Consume(ctx, columns, noRetry(), batchSize, chIn, chOut, chErr, wg)
		close(done)
	}()
	return done
}

// waitDone blocks until done is closed, failing the test if that takes longer
// than testTimeout. A bare `<-done` would instead hang until the go test
// -timeout kills the whole binary, hiding which test was stuck. what completes
// the failure message, e.g. "after context cancellation".
func waitDone(t *testing.T, done <-chan struct{}, what string) {
	t.Helper()
	select {
	case <-done:
	case <-time.After(testTimeout):
		t.Fatalf("timeout: Consume did not exit %s", what)
	}
}

// drainOut collects every batch currently buffered in chOut without blocking
// and calls wg.Done once per received batch. This presumably balances a
// per-batch wg.Add performed by the sender — confirm against Consume/flush.
// Call it only after the producer has finished, otherwise batches may be missed.
func drainOut(chOut <-chan models.Batch, wg *sync.WaitGroup) []models.Batch {
	var batches []models.Batch
	for {
		select {
		case b := <-chOut:
			batches = append(batches, b)
			wg.Done()
		default:
			return batches
		}
	}
}

// TestBatchAccumulator_Add checks that add appends rows and records each
// source batch as a parent, in insertion order.
func TestBatchAccumulator_Add(t *testing.T) {
	acc := &batchAccumulator{batchSize: 5}
	b1 := makeBatch(2)
	b2 := makeBatch(3)

	acc.add(b1)
	acc.add(b2)

	if len(acc.rows) != 5 {
		t.Errorf("expected 5 rows, got %d", len(acc.rows))
	}
	if len(acc.parents) != 2 {
		t.Fatalf("expected 2 parents, got %d", len(acc.parents))
	}
	if acc.parents[0].Id != b1.Id || acc.parents[1].Id != b2.Id {
		t.Error("parent IDs do not match source batch IDs")
	}
}

// TestBatchAccumulator_Ready checks that ready flips once the row count
// reaches batchSize.
func TestBatchAccumulator_Ready(t *testing.T) {
	acc := &batchAccumulator{batchSize: 3}

	acc.add(makeBatch(2))
	if acc.ready() {
		t.Error("should not be ready with 2 rows and batchSize=3")
	}

	acc.add(makeBatch(1))
	if !acc.ready() {
		t.Error("should be ready with 3 rows and batchSize=3")
	}
}

// TestBatchAccumulator_Flush_Empty checks that flushing an empty accumulator
// succeeds and emits nothing.
func TestBatchAccumulator_Flush_Empty(t *testing.T) {
	acc := &batchAccumulator{batchSize: 5}
	chOut := make(chan models.Batch, 1)
	var wg sync.WaitGroup

	if !acc.flush(context.Background(), chOut, &wg) {
		t.Error("flush on empty accumulator should return true")
	}
	if len(chOut) != 0 {
		t.Error("flush on empty accumulator should send nothing")
	}
}

// TestBatchAccumulator_Flush_Success checks that flush emits one batch
// carrying all rows and parent references, then resets the accumulator.
func TestBatchAccumulator_Flush_Success(t *testing.T) {
	acc := &batchAccumulator{batchSize: 2}
	b := makeBatch(2)
	acc.add(b)

	chOut := make(chan models.Batch, 1)
	var wg sync.WaitGroup

	if !acc.flush(context.Background(), chOut, &wg) {
		t.Fatal("flush should return true on success")
	}

	select {
	case out := <-chOut:
		wg.Done()
		if len(out.Rows) != 2 {
			t.Errorf("expected 2 rows in flushed batch, got %d", len(out.Rows))
		}
		if len(out.ParentBatches) != 1 || out.ParentBatches[0].Id != b.Id {
			t.Error("flushed batch should reference the source batch as parent")
		}
	default:
		t.Error("expected a batch in chOut after flush")
	}

	if len(acc.rows) != 0 || len(acc.parents) != 0 {
		t.Error("accumulator state should be reset after flush")
	}
	wg.Wait()
}

// TestBatchAccumulator_Flush_ContextCancelled checks that flush reports
// failure when the context is already cancelled and the unbuffered chOut has
// no reader, and that it leaves the WaitGroup balanced.
func TestBatchAccumulator_Flush_ContextCancelled(t *testing.T) {
	acc := &batchAccumulator{batchSize: 2}
	acc.add(makeBatch(2))

	chOut := make(chan models.Batch)
	var wg sync.WaitGroup

	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	if acc.flush(ctx, chOut, &wg) {
		t.Error("flush should return false when context is cancelled")
	}
	wg.Wait()
}

// TestSendTransformError_PlainError checks that a plain error is wrapped into
// a JobError that cancels the job.
func TestSendTransformError_PlainError(t *testing.T) {
	ch := make(chan custom_errors.JobError, 1)
	sendTransformError(context.Background(), errors.New("something broke"), ch)

	select {
	case e := <-ch:
		if !e.ShouldCancelJob {
			t.Error("plain error should produce ShouldCancelJob=true")
		}
	default:
		t.Error("expected a job error in the channel")
	}
}

// TestSendTransformError_JobError_Passthrough checks that an error which is
// already a *JobError is forwarded with its fields unchanged.
func TestSendTransformError_JobError_Passthrough(t *testing.T) {
	ch := make(chan custom_errors.JobError, 1)
	original := &custom_errors.JobError{ShouldCancelJob: false, Msg: "custom msg"}
	sendTransformError(context.Background(), original, ch)

	select {
	case e := <-ch:
		if e.ShouldCancelJob || e.Msg != "custom msg" {
			t.Errorf("JobError should pass through unchanged, got %+v", e)
		}
	default:
		t.Error("expected a job error in the channel")
	}
}

// TestSendTransformError_ContextCancelled_Silent checks that context.Canceled
// is not reported on the error channel.
func TestSendTransformError_ContextCancelled_Silent(t *testing.T) {
	ch := make(chan custom_errors.JobError, 1)
	// NOTE(review): ctx is also cancelled, so this cannot tell apart dropping
	// by error value from suppressing because ctx is done — confirm which one
	// sendTransformError relies on.
	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	sendTransformError(ctx, context.Canceled, ch)

	if len(ch) != 0 {
		t.Error("context.Canceled should be silently dropped")
	}
}

// TestSendTransformError_DeadlineExceeded_Silent checks that
// context.DeadlineExceeded is not reported on the error channel.
func TestSendTransformError_DeadlineExceeded_Silent(t *testing.T) {
	ch := make(chan custom_errors.JobError, 1)
	// NOTE(review): as above, ctx is cancelled (not deadline-expired); the
	// assertion passes either if the error value or the done ctx is what
	// suppresses the send.
	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	sendTransformError(ctx, context.DeadlineExceeded, ch)

	if len(ch) != 0 {
		t.Error("context.DeadlineExceeded should be silently dropped")
	}
}

// TestConsume_Passthrough_PreservesOriginalBatch checks that with batchSize 0
// the input batch is forwarded with its ID and rows intact.
func TestConsume_Passthrough_PreservesOriginalBatch(t *testing.T) {
	tr := newTransformer()
	chIn := make(chan models.Batch, 1)
	chOut := make(chan models.Batch, 1)
	chErr := make(chan custom_errors.JobError, 1)
	var wg sync.WaitGroup

	batch := makeBatch(3)
	chIn <- batch
	close(chIn)

	done := runConsume(context.Background(), tr, nil, 0, chIn, chOut, chErr, &wg)

	select {
	case got := <-chOut:
		wg.Done()
		if got.Id != batch.Id {
			t.Error("passthrough should preserve the original batch ID")
		}
		if len(got.Rows) != 3 {
			t.Errorf("expected 3 rows, got %d", len(got.Rows))
		}
	case <-time.After(testTimeout):
		t.Fatal("timeout waiting for output batch")
	}

	waitDone(t, done, "after its input channel was closed")
	wg.Wait()
}

// TestConsume_Passthrough_WaitGroupBalanced checks that every passthrough
// batch is emitted and that draining them balances the WaitGroup.
func TestConsume_Passthrough_WaitGroupBalanced(t *testing.T) {
	tr := newTransformer()
	chIn := make(chan models.Batch, 3)
	chOut := make(chan models.Batch, 3)
	chErr := make(chan custom_errors.JobError, 1)
	var wg sync.WaitGroup

	for range 3 {
		chIn <- makeBatch(1)
	}
	close(chIn)

	done := runConsume(context.Background(), tr, nil, 0, chIn, chOut, chErr, &wg)
	waitDone(t, done, "after its input channel was closed")

	batches := drainOut(chOut, &wg)
	if len(batches) != 3 {
		t.Errorf("expected 3 output batches, got %d", len(batches))
	}
	wg.Wait()
}

// TestConsume_Accumulation_FlushOnThreshold checks that rows are merged into
// a single batch once batchSize rows have accumulated.
func TestConsume_Accumulation_FlushOnThreshold(t *testing.T) {
	tr := newTransformer()
	chIn := make(chan models.Batch, 3)
	chOut := make(chan models.Batch, 2)
	chErr := make(chan custom_errors.JobError, 1)
	var wg sync.WaitGroup

	for range 3 {
		chIn <- makeBatch(1)
	}
	close(chIn)

	done := runConsume(context.Background(), tr, nil, 3, chIn, chOut, chErr, &wg)
	waitDone(t, done, "after its input channel was closed")

	batches := drainOut(chOut, &wg)
	if len(batches) != 1 {
		t.Fatalf("expected 1 accumulated batch, got %d", len(batches))
	}
	if len(batches[0].Rows) != 3 {
		t.Errorf("expected 3 rows in accumulated batch, got %d", len(batches[0].Rows))
	}
	wg.Wait()
}

// TestConsume_Accumulation_FlushOnClose checks that a partial accumulation
// below batchSize is still flushed when the input channel closes.
func TestConsume_Accumulation_FlushOnClose(t *testing.T) {
	tr := newTransformer()
	chIn := make(chan models.Batch, 2)
	chOut := make(chan models.Batch, 2)
	chErr := make(chan custom_errors.JobError, 1)
	var wg sync.WaitGroup

	chIn <- makeBatch(1)
	chIn <- makeBatch(1)
	close(chIn)

	done := runConsume(context.Background(), tr, nil, 10, chIn, chOut, chErr, &wg)
	waitDone(t, done, "after its input channel was closed")

	batches := drainOut(chOut, &wg)
	if len(batches) != 1 {
		t.Fatalf("expected 1 batch flushed on close, got %d", len(batches))
	}
	if len(batches[0].Rows) != 2 {
		t.Errorf("expected 2 rows, got %d", len(batches[0].Rows))
	}
	wg.Wait()
}

// TestConsume_Accumulation_TracksAllParentBatches checks that an accumulated
// batch references every contributing input batch, in arrival order.
func TestConsume_Accumulation_TracksAllParentBatches(t *testing.T) {
	tr := newTransformer()
	chIn := make(chan models.Batch, 2)
	chOut := make(chan models.Batch, 2)
	chErr := make(chan custom_errors.JobError, 1)
	var wg sync.WaitGroup

	b1 := makeBatch(1)
	b2 := makeBatch(1)
	chIn <- b1
	chIn <- b2
	close(chIn)

	done := runConsume(context.Background(), tr, nil, 10, chIn, chOut, chErr, &wg)
	waitDone(t, done, "after its input channel was closed")

	batches := drainOut(chOut, &wg)
	if len(batches) != 1 {
		t.Fatalf("expected 1 output batch, got %d", len(batches))
	}

	parents := batches[0].ParentBatches
	if len(parents) != 2 {
		t.Fatalf("expected 2 parent refs, got %d", len(parents))
	}
	if parents[0].Id != b1.Id || parents[1].Id != b2.Id {
		t.Error("parent IDs should match source batch IDs in order")
	}
	wg.Wait()
}

// TestConsume_Accumulation_MultipleFlushes checks that 5 rows with batchSize 2
// yield three batches (2+2+1) without losing rows.
func TestConsume_Accumulation_MultipleFlushes(t *testing.T) {
	tr := newTransformer()
	chIn := make(chan models.Batch, 5)
	chOut := make(chan models.Batch, 5)
	chErr := make(chan custom_errors.JobError, 1)
	var wg sync.WaitGroup

	for range 5 {
		chIn <- makeBatch(1)
	}
	close(chIn)

	done := runConsume(context.Background(), tr, nil, 2, chIn, chOut, chErr, &wg)
	waitDone(t, done, "after its input channel was closed")

	batches := drainOut(chOut, &wg)
	if len(batches) != 3 {
		t.Fatalf("expected 3 output batches (2+2+1 rows), got %d", len(batches))
	}

	totalRows := 0
	for _, b := range batches {
		totalRows += len(b.Rows)
	}
	if totalRows != 5 {
		t.Errorf("expected 5 total rows across all batches, got %d", totalRows)
	}
	wg.Wait()
}

// TestConsume_EmptyInput_NoOutput checks that Consume returns promptly and
// emits nothing when its input channel is closed without any batches.
func TestConsume_EmptyInput_NoOutput(t *testing.T) {
	tr := newTransformer()
	chIn := make(chan models.Batch)
	chOut := make(chan models.Batch, 1)
	chErr := make(chan custom_errors.JobError, 1)
	var wg sync.WaitGroup

	close(chIn)

	done := runConsume(context.Background(), tr, nil, 5, chIn, chOut, chErr, &wg)
	waitDone(t, done, "after empty input channel was closed")

	if len(chOut) != 0 {
		t.Error("expected no output for empty input")
	}
	wg.Wait()
}

// TestConsume_TransformError_SendsJobError checks that a value the uuid
// column cannot transform produces a job-cancelling JobError.
func TestConsume_TransformError_SendsJobError(t *testing.T) {
	tr := newTransformer()
	col := uuidColumn()
	chIn := make(chan models.Batch, 1)
	chOut := make(chan models.Batch, 1)
	chErr := make(chan custom_errors.JobError, 1)
	var wg sync.WaitGroup

	batch := models.Batch{
		Id:   uuid.New(),
		Rows: []models.UnknownRowValues{{[]byte{1, 2, 3}}},
	}
	// chIn is deliberately left open, so Consume must return on its own after
	// the failure. waitDone turns a violation into a failure, not a hang.
	chIn <- batch

	done := runConsume(context.Background(), tr, []models.ColumnType{col}, 0, chIn, chOut, chErr, &wg)

	select {
	case err := <-chErr:
		if !err.ShouldCancelJob {
			t.Error("transform error should set ShouldCancelJob=true")
		}
	case <-time.After(testTimeout):
		t.Fatal("timeout: expected a job error from transform failure")
	}

	waitDone(t, done, "after a transform error")
	wg.Wait()
}

// TestConsume_TransformError_NoOutputForwarded checks that a batch that fails
// transformation is not forwarded to chOut.
func TestConsume_TransformError_NoOutputForwarded(t *testing.T) {
	tr := newTransformer()
	col := uuidColumn()
	chIn := make(chan models.Batch, 1)
	chOut := make(chan models.Batch, 1)
	chErr := make(chan custom_errors.JobError, 1)
	var wg sync.WaitGroup

	batch := models.Batch{
		Id:   uuid.New(),
		Rows: []models.UnknownRowValues{{[]byte{1, 2, 3}}},
	}
	// chIn is deliberately left open; see TestConsume_TransformError_SendsJobError.
	chIn <- batch

	done := runConsume(context.Background(), tr, []models.ColumnType{col}, 0, chIn, chOut, chErr, &wg)
	waitDone(t, done, "after a transform error")

	if len(chOut) != 0 {
		t.Error("no batch should be forwarded when transformation fails")
	}
	wg.Wait()
}

// TestConsume_ContextCancellation_Exits checks that Consume returns when its
// context is cancelled while it is idle on an open, empty input channel.
func TestConsume_ContextCancellation_Exits(t *testing.T) {
	tr := newTransformer()
	chIn := make(chan models.Batch)
	chOut := make(chan models.Batch, 1)
	chErr := make(chan custom_errors.JobError, 1)
	var wg sync.WaitGroup

	ctx, cancel := context.WithCancel(context.Background())
	done := runConsume(ctx, tr, nil, 0, chIn, chOut, chErr, &wg)
	cancel()

	waitDone(t, done, "after context cancellation")
	wg.Wait()
}

// TestConsume_Transform_DatetimeConvertedToUTC checks that a datetime value in
// a non-UTC zone comes out with the UTC location.
func TestConsume_Transform_DatetimeConvertedToUTC(t *testing.T) {
	tr := newTransformer()
	col := models.NewColumnType("col_dt", false, false, "datetime", "datetime", "timestamp", false, 0, 0, 0)
	chIn := make(chan models.Batch, 1)
	chOut := make(chan models.Batch, 1)
	chErr := make(chan custom_errors.JobError, 1)
	var wg sync.WaitGroup

	nonUTC := time.Date(2024, 1, 15, 12, 0, 0, 0, time.FixedZone("EST", -5*3600))
	batch := models.Batch{
		Id:   uuid.New(),
		Rows: []models.UnknownRowValues{{nonUTC}},
	}
	chIn <- batch
	close(chIn)

	done := runConsume(context.Background(), tr, []models.ColumnType{col}, 0, chIn, chOut, chErr, &wg)
	waitDone(t, done, "after its input channel was closed")

	select {
	case got := <-chOut:
		wg.Done()
		result, ok := got.Rows[0][0].(time.Time)
		if !ok {
			t.Fatal("expected time.Time in output row")
		}
		if result.Location() != time.UTC {
			t.Errorf("expected UTC location after transform, got %v", result.Location())
		}
	default:
		t.Error("expected an output batch")
	}
	wg.Wait()
}

// TestConsume_Transform_NilValueSkipped checks that a nil value in a uuid
// column passes through unchanged and produces no error.
func TestConsume_Transform_NilValueSkipped(t *testing.T) {
	tr := newTransformer()
	col := uuidColumn()
	chIn := make(chan models.Batch, 1)
	chOut := make(chan models.Batch, 1)
	chErr := make(chan custom_errors.JobError, 1)
	var wg sync.WaitGroup

	batch := models.Batch{
		Id:   uuid.New(),
		Rows: []models.UnknownRowValues{{nil}},
	}
	chIn <- batch
	close(chIn)

	done := runConsume(context.Background(), tr, []models.ColumnType{col}, 0, chIn, chOut, chErr, &wg)
	waitDone(t, done, "after its input channel was closed")

	select {
	case got := <-chOut:
		wg.Done()
		if got.Rows[0][0] != nil {
			t.Error("nil value should pass through unchanged")
		}
	default:
		t.Error("expected an output batch even when value is nil")
	}

	if len(chErr) != 0 {
		t.Error("nil value should not produce an error")
	}
	wg.Wait()
}