4 Commits

7 changed files with 1294 additions and 142 deletions

View File

@@ -27,7 +27,6 @@ func sendBatch(ctx context.Context, chBatchesOut chan<- models.Batch, batch mode
func flush(
ctx context.Context,
partition *models.Partition,
batchSize int,
batchRows []models.UnknownRowValues,
chBatchesOut chan<- models.Batch,
@@ -36,7 +35,7 @@ func flush(
return nil
}
batch := models.Batch{Id: uuid.New(), PartitionId: partition.Id, Rows: batchRows}
batch := models.Batch{Id: uuid.New(), Rows: batchRows}
batchRows = make([]models.UnknownRowValues, 0, batchSize)
return sendBatch(ctx, chBatchesOut, batch)
}

View File

@@ -90,7 +90,7 @@ func (ex *GenericExtractor) ProcessPartition(
return rowsRead, err
}
if err := flush(ctx, &partition, batchSize, batchRows, chBatchesOut); err != nil {
if err := flush(ctx, batchSize, batchRows, chBatchesOut); err != nil {
return rowsRead, err
}
@@ -102,7 +102,7 @@ func (ex *GenericExtractor) ProcessPartition(
batchRows = append(batchRows, rowValues)
if len(batchRows) >= batchSize {
// logrus.Debugf("Batch size reached, flushing batch with %v rows (rowsRead=%v)", len(batchRows), rowsRead)
if err := flush(ctx, &partition, batchSize, batchRows, chBatchesOut); err != nil {
if err := flush(ctx, batchSize, batchRows, chBatchesOut); err != nil {
// logrus.Warnf("Error flushing rows: %v", err)
return rowsRead, err
}
@@ -110,7 +110,7 @@ func (ex *GenericExtractor) ProcessPartition(
}
}
if err := flush(ctx, &partition, batchSize, batchRows, chBatchesOut); err != nil {
if err := flush(ctx, batchSize, batchRows, chBatchesOut); err != nil {
return rowsRead, err
}

View File

@@ -13,6 +13,62 @@ import (
"github.com/sirupsen/logrus"
)
type loaderAccumulator struct {
batchSize int
rows []models.UnknownRowValues
parents []models.BatchRef
pendingDone int
}
func (a *loaderAccumulator) add(batch models.Batch) {
a.rows = append(a.rows, batch.Rows...)
a.parents = append(a.parents, models.BatchRef{Id: batch.Id})
a.pendingDone++
}
func (a *loaderAccumulator) ready() bool {
return len(a.rows) >= a.batchSize
}
func (a *loaderAccumulator) drainPending(wg *sync.WaitGroup) {
for range a.pendingDone {
wg.Done()
}
}
func sendLoadError(
ctx context.Context,
err error,
retryConfig config.RetryConfig,
failedBatchesCount *int32,
chErrorsOut chan<- custom_errors.JobError,
) bool {
atomic.AddInt32(failedBatchesCount, 1)
var jobErr custom_errors.JobError
if je, ok := errors.AsType[*custom_errors.JobError](err); ok {
jobErr = *je
} else {
jobErr = custom_errors.JobError{ShouldCancelJob: false, Msg: err.Error(), Prev: err}
}
select {
case <-ctx.Done():
return false
case chErrorsOut <- jobErr:
}
if atomic.LoadInt32(failedBatchesCount) > int32(retryConfig.MaxFailedBatchesLoad) {
select {
case <-ctx.Done():
case chErrorsOut <- custom_errors.JobError{ShouldCancelJob: true, Msg: "Max failed batches (load) reached"}:
}
return false
}
return true
}
func (gl *GenericLoader) Consume(
ctx context.Context,
tableInfo config.TargetTableInfo,
@@ -29,58 +85,29 @@ func (gl *GenericLoader) Consume(
return col.Name()
})
var accRows []models.UnknownRowValues
var parentBatchesId []uuid.UUID
pendingDone := 0
defer func() {
for range pendingDone {
wgActiveBatches.Done()
}
}()
acc := &loaderAccumulator{batchSize: batchSize}
defer acc.drainPending(wgActiveBatches)
flush := func() bool {
if len(accRows) == 0 {
if len(acc.rows) == 0 {
return true
}
count := len(parentBatchesId)
count := len(acc.parents)
superBatch := models.Batch{
Id: uuid.New(),
ParentBatchesId: parentBatchesId,
Rows: accRows,
Id: uuid.New(),
ParentBatches: acc.parents,
Rows: acc.rows,
}
processedRows, err := gl.ProcessBatchWithRetries(ctx, tableInfo, colNames, retryConfig, superBatch)
for range count {
wgActiveBatches.Done()
}
pendingDone -= count
accRows = nil
parentBatchesId = nil
acc.pendingDone -= count
acc.rows = nil
acc.parents = nil
if err != nil {
atomic.AddInt32(failedBatchesCount, 1)
if jobError, ok := errors.AsType[*custom_errors.JobError](err); ok {
select {
case <-ctx.Done():
return false
case chErrorsOut <- *jobError:
}
} else {
select {
case <-ctx.Done():
return false
case chErrorsOut <- custom_errors.JobError{ShouldCancelJob: false, Msg: err.Error(), Prev: err}:
}
}
if atomic.LoadInt32(failedBatchesCount) > int32(retryConfig.MaxFailedBatchesLoad) {
select {
case <-ctx.Done():
case chErrorsOut <- custom_errors.JobError{ShouldCancelJob: true, Msg: "Max failed batches (load) reached"}:
}
return false
}
return true
return sendLoadError(ctx, err, retryConfig, failedBatchesCount, chErrorsOut)
}
current := atomic.LoadInt64(rowsLoaded)
@@ -90,13 +117,10 @@ func (gl *GenericLoader) Consume(
}
for {
if ctx.Err() != nil {
return
}
select {
case <-ctx.Done():
return
case batch, ok := <-chBatchesIn:
if !ok {
flush()
@@ -106,45 +130,20 @@ func (gl *GenericLoader) Consume(
if batchSize <= 0 {
processedRows, err := gl.ProcessBatchWithRetries(ctx, tableInfo, colNames, retryConfig, batch)
wgActiveBatches.Done()
if err != nil {
atomic.AddInt32(failedBatchesCount, 1)
if jobError, ok := errors.AsType[*custom_errors.JobError](err); ok {
select {
case <-ctx.Done():
return
case chErrorsOut <- *jobError:
}
} else {
select {
case <-ctx.Done():
return
case chErrorsOut <- custom_errors.JobError{ShouldCancelJob: false, Msg: err.Error(), Prev: err}:
}
}
if atomic.LoadInt32(failedBatchesCount) > int32(retryConfig.MaxFailedBatchesLoad) {
select {
case <-ctx.Done():
return
case chErrorsOut <- custom_errors.JobError{ShouldCancelJob: true, Msg: "Max failed batches (load) reached"}:
return
}
if !sendLoadError(ctx, err, retryConfig, failedBatchesCount, chErrorsOut) {
return
}
continue
}
current := atomic.LoadInt64(rowsLoaded)
logrus.Debugf("Rows loaded: +%v [current=%v] (%s.%s)", processedRows, current, tableInfo.Schema, tableInfo.Table)
atomic.AddInt64(rowsLoaded, int64(processedRows))
continue
}
pendingDone++
accRows = append(accRows, batch.Rows...)
parentBatchesId = append(parentBatchesId, batch.Id)
if len(accRows) >= batchSize {
acc.add(batch)
if acc.ready() {
if !flush() {
return
}

View File

@@ -0,0 +1,603 @@
package loaders
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
"github.com/google/uuid"
)
const testTimeout = 2 * time.Second
type mockResult struct {
err error
}
type mockDbWrapper struct {
mu sync.Mutex
callCount int
results []mockResult
}
func newMockDb(results ...mockResult) *mockDbWrapper {
return &mockDbWrapper{results: results}
}
func (m *mockDbWrapper) SaveMassive(_ context.Context, _ string, _ string, _ []string, rows [][]any) (int64, error) {
m.mu.Lock()
defer m.mu.Unlock()
idx := m.callCount
m.callCount++
if idx < len(m.results) && m.results[idx].err != nil {
return 0, m.results[idx].err
}
return int64(len(rows)), nil
}
func (m *mockDbWrapper) Close() error { return nil }
func (m *mockDbWrapper) Connect(_ context.Context, _ string) error { return nil }
func (m *mockDbWrapper) Exec(_ context.Context, _ string, _ ...any) (dbwrapper.ExecResult, error) {
return dbwrapper.ExecResult{}, nil
}
func (m *mockDbWrapper) GetDialect() string { return "" }
func (m *mockDbWrapper) Query(_ context.Context, _ string, _ ...any) (dbwrapper.RowsResult, error) {
return nil, nil
}
func (m *mockDbWrapper) QueryRow(_ context.Context, _ string, _ ...any) dbwrapper.RowResult {
return nil
}
func (m *mockDbWrapper) QueryFromObject(_ context.Context, _ dbwrapper.ExtractionQuery) (dbwrapper.RowsResult, error) {
return nil, nil
}
func makeBatch(numRows int) models.Batch {
rows := make([]models.UnknownRowValues, numRows)
for i := range rows {
rows[i] = models.UnknownRowValues{i}
}
return models.Batch{Id: uuid.New(), Rows: rows}
}
func newLoader(db *mockDbWrapper) GenericLoader {
return GenericLoader{db: db}
}
func rc(maxFailed int) config.RetryConfig {
return config.RetryConfig{Attempts: 1, MaxFailedBatchesLoad: maxFailed}
}
func sendBatch(chIn chan<- models.Batch, batch models.Batch, wg *sync.WaitGroup) {
wg.Add(1)
chIn <- batch
}
func runConsume(
ctx context.Context,
gl GenericLoader,
retryConfig config.RetryConfig,
batchSize int,
chIn <-chan models.Batch,
chErr chan<- custom_errors.JobError,
wg *sync.WaitGroup,
rowsLoaded *int64,
failedCount *int32,
) <-chan struct{} {
done := make(chan struct{})
go func() {
gl.Consume(ctx, config.TargetTableInfo{}, nil, retryConfig, batchSize,
chIn, chErr, wg, rowsLoaded, failedCount)
close(done)
}()
return done
}
func waitWg(wg *sync.WaitGroup) <-chan struct{} {
done := make(chan struct{})
go func() { wg.Wait(); close(done) }()
return done
}
func dbError() error { return errors.New("connection reset by peer") }
func TestLoaderAccumulator_Add(t *testing.T) {
acc := &loaderAccumulator{batchSize: 5}
b1 := makeBatch(2)
b2 := makeBatch(3)
acc.add(b1)
acc.add(b2)
if len(acc.rows) != 5 {
t.Errorf("expected 5 rows, got %d", len(acc.rows))
}
if len(acc.parents) != 2 {
t.Fatalf("expected 2 parents, got %d", len(acc.parents))
}
if acc.parents[0].Id != b1.Id || acc.parents[1].Id != b2.Id {
t.Error("parent IDs do not match source batch IDs in order")
}
if acc.pendingDone != 2 {
t.Errorf("expected pendingDone=2, got %d", acc.pendingDone)
}
}
func TestLoaderAccumulator_Ready(t *testing.T) {
acc := &loaderAccumulator{batchSize: 3}
acc.add(makeBatch(2))
if acc.ready() {
t.Error("should not be ready with 2 rows and batchSize=3")
}
acc.add(makeBatch(1))
if !acc.ready() {
t.Error("should be ready with 3 rows and batchSize=3")
}
}
func TestLoaderAccumulator_DrainPending_ReleasesWg(t *testing.T) {
acc := &loaderAccumulator{batchSize: 5, pendingDone: 3}
var wg sync.WaitGroup
wg.Add(3)
acc.drainPending(&wg)
select {
case <-waitWg(&wg):
case <-time.After(testTimeout):
t.Fatal("wg.Wait() timed out: drainPending did not call Done() enough times")
}
}
func TestLoaderAccumulator_DrainPending_ZeroPending(t *testing.T) {
acc := &loaderAccumulator{batchSize: 5, pendingDone: 0}
var wg sync.WaitGroup
acc.drainPending(&wg)
select {
case <-waitWg(&wg):
case <-time.After(testTimeout):
t.Fatal("wg.Wait() timed out")
}
}
func TestSendLoadError_PlainError_WrappedAsNonFatal(t *testing.T) {
ch := make(chan custom_errors.JobError, 2)
var failedCount int32
result := sendLoadError(context.Background(), errors.New("db error"), rc(10), &failedCount, ch)
if !result {
t.Error("expected true (below threshold)")
}
if atomic.LoadInt32(&failedCount) != 1 {
t.Errorf("expected failedCount=1, got %d", failedCount)
}
select {
case e := <-ch:
if e.ShouldCancelJob {
t.Error("plain error should be wrapped as ShouldCancelJob=false")
}
default:
t.Error("expected an error in the channel")
}
}
func TestSendLoadError_JobError_PassesThrough(t *testing.T) {
ch := make(chan custom_errors.JobError, 2)
var failedCount int32
original := &custom_errors.JobError{ShouldCancelJob: false, Msg: "custom msg"}
sendLoadError(context.Background(), original, rc(10), &failedCount, ch)
select {
case e := <-ch:
if e.Msg != "custom msg" || e.ShouldCancelJob {
t.Errorf("JobError should pass through unchanged, got %+v", e)
}
default:
t.Error("expected an error in the channel")
}
}
func TestSendLoadError_FatalJobError_BelowThreshold_ReturnsTrue(t *testing.T) {
ch := make(chan custom_errors.JobError, 2)
var failedCount int32
fatal := &custom_errors.JobError{ShouldCancelJob: true, Msg: "unique constraint"}
result := sendLoadError(context.Background(), fatal, rc(10), &failedCount, ch)
if !result {
t.Error("below-threshold fatal error should return true (external cancel expected from JobErrorHandler)")
}
select {
case e := <-ch:
if !e.ShouldCancelJob {
t.Error("fatal JobError should be forwarded with ShouldCancelJob=true")
}
default:
t.Error("expected the fatal error in the channel")
}
}
func TestSendLoadError_ThresholdExceeded_ReturnsFalse(t *testing.T) {
ch := make(chan custom_errors.JobError, 2)
var failedCount int32
result := sendLoadError(context.Background(), errors.New("db error"), rc(0), &failedCount, ch)
if result {
t.Error("expected false when threshold exceeded")
}
if len(ch) != 2 {
t.Fatalf("expected 2 errors (batch error + fatal threshold error), got %d", len(ch))
}
<-ch // batch error
threshold := <-ch
if !threshold.ShouldCancelJob {
t.Error("second error should be the fatal threshold error (ShouldCancelJob=true)")
}
}
func TestSendLoadError_AtThresholdBoundary(t *testing.T) {
ch := make(chan custom_errors.JobError, 6)
var failedCount int32
if !sendLoadError(context.Background(), errors.New("err"), rc(2), &failedCount, ch) {
t.Error("first failure: expected true (below threshold)")
}
if !sendLoadError(context.Background(), errors.New("err"), rc(2), &failedCount, ch) {
t.Error("second failure: expected true (at threshold, not exceeded)")
}
if sendLoadError(context.Background(), errors.New("err"), rc(2), &failedCount, ch) {
t.Error("third failure: expected false (threshold exceeded)")
}
}
func TestSendLoadError_ContextCancelled_ReturnsFalse(t *testing.T) {
ch := make(chan custom_errors.JobError)
var failedCount int32
ctx, cancel := context.WithCancel(context.Background())
cancel()
result := sendLoadError(ctx, errors.New("db error"), rc(10), &failedCount, ch)
if result {
t.Error("expected false when context is cancelled")
}
if len(ch) != 0 {
t.Error("no error should be sent when context is cancelled")
}
}
func TestConsume_Passthrough_RowsLoaded(t *testing.T) {
db := newMockDb()
gl := newLoader(db)
chIn := make(chan models.Batch, 1)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
var rowsLoaded int64
var failedCount int32
sendBatch(chIn, makeBatch(5), &wg)
close(chIn)
<-runConsume(context.Background(), gl, rc(0), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
wg.Wait()
if rowsLoaded != 5 {
t.Errorf("expected rowsLoaded=5, got %d", rowsLoaded)
}
if db.callCount != 1 {
t.Errorf("expected 1 SaveMassive call, got %d", db.callCount)
}
}
func TestConsume_Passthrough_MultipleBatches_RowsAccumulate(t *testing.T) {
db := newMockDb()
gl := newLoader(db)
chIn := make(chan models.Batch, 3)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
var rowsLoaded int64
var failedCount int32
sendBatch(chIn, makeBatch(3), &wg)
sendBatch(chIn, makeBatch(2), &wg)
sendBatch(chIn, makeBatch(4), &wg)
close(chIn)
<-runConsume(context.Background(), gl, rc(10), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
wg.Wait()
if rowsLoaded != 9 {
t.Errorf("expected rowsLoaded=9, got %d", rowsLoaded)
}
}
func TestConsume_Passthrough_WgDoneBeforeErrorHandling(t *testing.T) {
db := newMockDb(mockResult{err: dbError()})
gl := newLoader(db)
chIn := make(chan models.Batch, 1)
chErr := make(chan custom_errors.JobError, 2)
var wg sync.WaitGroup
var rowsLoaded int64
var failedCount int32
sendBatch(chIn, makeBatch(2), &wg)
close(chIn)
<-runConsume(context.Background(), gl, rc(10), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
select {
case <-waitWg(&wg):
case <-time.After(testTimeout):
t.Fatal("wg.Wait() timed out: Done() was not called even though processing failed")
}
}
func TestConsume_Passthrough_NonFatalError_Continues(t *testing.T) {
db := newMockDb(mockResult{err: dbError()})
gl := newLoader(db)
chIn := make(chan models.Batch, 2)
chErr := make(chan custom_errors.JobError, 3)
var wg sync.WaitGroup
var rowsLoaded int64
var failedCount int32
sendBatch(chIn, makeBatch(2), &wg)
sendBatch(chIn, makeBatch(3), &wg)
close(chIn)
<-runConsume(context.Background(), gl, rc(10), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
wg.Wait()
if rowsLoaded != 3 {
t.Errorf("expected rowsLoaded=3 (only second batch succeeded), got %d", rowsLoaded)
}
if atomic.LoadInt32(&failedCount) != 1 {
t.Errorf("expected failedCount=1, got %d", failedCount)
}
if len(chErr) == 0 {
t.Error("expected at least one error in chErr for the failed batch")
}
}
func TestConsume_Passthrough_ThresholdExceeded_Exits(t *testing.T) {
db := newMockDb(mockResult{err: dbError()})
gl := newLoader(db)
chIn := make(chan models.Batch, 1)
chErr := make(chan custom_errors.JobError, 3)
var wg sync.WaitGroup
var rowsLoaded int64
var failedCount int32
sendBatch(chIn, makeBatch(1), &wg)
done := runConsume(context.Background(), gl, rc(0), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
select {
case <-done:
case <-time.After(testTimeout):
t.Fatal("Consume did not exit after threshold exceeded")
}
select {
case <-waitWg(&wg):
case <-time.After(testTimeout):
t.Fatal("wg.Wait() timed out after threshold exit")
}
}
func TestConsume_Accumulation_FlushOnThreshold(t *testing.T) {
db := newMockDb()
gl := newLoader(db)
chIn := make(chan models.Batch, 3)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
var rowsLoaded int64
var failedCount int32
sendBatch(chIn, makeBatch(1), &wg)
sendBatch(chIn, makeBatch(1), &wg)
sendBatch(chIn, makeBatch(1), &wg)
close(chIn)
<-runConsume(context.Background(), gl, rc(0), 3, chIn, chErr, &wg, &rowsLoaded, &failedCount)
wg.Wait()
if rowsLoaded != 3 {
t.Errorf("expected rowsLoaded=3, got %d", rowsLoaded)
}
if db.callCount != 1 {
t.Errorf("expected 1 SaveMassive call, got %d", db.callCount)
}
}
func TestConsume_Accumulation_FlushOnClose(t *testing.T) {
db := newMockDb()
gl := newLoader(db)
chIn := make(chan models.Batch, 2)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
var rowsLoaded int64
var failedCount int32
sendBatch(chIn, makeBatch(2), &wg)
sendBatch(chIn, makeBatch(3), &wg)
close(chIn)
<-runConsume(context.Background(), gl, rc(0), 10, chIn, chErr, &wg, &rowsLoaded, &failedCount)
wg.Wait()
if rowsLoaded != 5 {
t.Errorf("expected rowsLoaded=5, got %d", rowsLoaded)
}
if db.callCount != 1 {
t.Errorf("expected exactly 1 SaveMassive call (single flush on close), got %d", db.callCount)
}
}
func TestConsume_Accumulation_RowsLoadedCorrect(t *testing.T) {
db := newMockDb()
gl := newLoader(db)
chIn := make(chan models.Batch, 5)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
var rowsLoaded int64
var failedCount int32
for range 5 {
sendBatch(chIn, makeBatch(2), &wg)
}
close(chIn)
<-runConsume(context.Background(), gl, rc(0), 4, chIn, chErr, &wg, &rowsLoaded, &failedCount)
wg.Wait()
if rowsLoaded != 10 {
t.Errorf("expected rowsLoaded=10, got %d", rowsLoaded)
}
if db.callCount != 3 {
t.Errorf("expected 3 SaveMassive calls (2 threshold flushes + 1 on close), got %d", db.callCount)
}
}
func TestConsume_Accumulation_WgBalanced_OnContextCancel(t *testing.T) {
db := newMockDb()
gl := newLoader(db)
chIn := make(chan models.Batch)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
var rowsLoaded int64
var failedCount int32
ctx, cancel := context.WithCancel(context.Background())
done := runConsume(ctx, gl, rc(0), 10, chIn, chErr, &wg, &rowsLoaded, &failedCount)
sendBatch(chIn, makeBatch(1), &wg)
sendBatch(chIn, makeBatch(1), &wg)
cancel()
select {
case <-done:
case <-time.After(testTimeout):
t.Fatal("Consume did not exit after context cancellation")
}
select {
case <-waitWg(&wg):
case <-time.After(testTimeout):
t.Fatal("wg.Wait() timed out: drainPending did not release accumulated batches on cancel")
}
}
func TestConsume_Accumulation_ErrorInFlush_WgStillBalanced(t *testing.T) {
db := newMockDb(mockResult{err: dbError()})
gl := newLoader(db)
chIn := make(chan models.Batch, 2)
chErr := make(chan custom_errors.JobError, 3)
var wg sync.WaitGroup
var rowsLoaded int64
var failedCount int32
sendBatch(chIn, makeBatch(1), &wg)
sendBatch(chIn, makeBatch(1), &wg)
close(chIn)
<-runConsume(context.Background(), gl, rc(10), 2, chIn, chErr, &wg, &rowsLoaded, &failedCount)
select {
case <-waitWg(&wg):
case <-time.After(testTimeout):
t.Fatal("wg.Wait() timed out: wg.Done() not called after flush error")
}
}
func TestConsume_Accumulation_MultipleFlushes_NonFatalErrors(t *testing.T) {
db := newMockDb(mockResult{err: dbError()}, mockResult{err: dbError()})
gl := newLoader(db)
chIn := make(chan models.Batch, 4)
chErr := make(chan custom_errors.JobError, 6)
var wg sync.WaitGroup
var rowsLoaded int64
var failedCount int32
for range 4 {
sendBatch(chIn, makeBatch(1), &wg)
}
close(chIn)
<-runConsume(context.Background(), gl, rc(10), 2, chIn, chErr, &wg, &rowsLoaded, &failedCount)
select {
case <-waitWg(&wg):
case <-time.After(testTimeout):
t.Fatal("wg.Wait() timed out")
}
if atomic.LoadInt32(&failedCount) != 2 {
t.Errorf("expected failedCount=2, got %d", failedCount)
}
if rowsLoaded != 0 {
t.Errorf("expected rowsLoaded=0 (all batches failed), got %d", rowsLoaded)
}
}
func TestConsume_EmptyInput_NoProcessing(t *testing.T) {
db := newMockDb()
gl := newLoader(db)
chIn := make(chan models.Batch)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
var rowsLoaded int64
var failedCount int32
close(chIn)
done := runConsume(context.Background(), gl, rc(0), 5, chIn, chErr, &wg, &rowsLoaded, &failedCount)
select {
case <-done:
case <-time.After(testTimeout):
t.Fatal("Consume did not exit after empty input channel was closed")
}
if db.callCount != 0 {
t.Errorf("expected no SaveMassive calls, got %d", db.callCount)
}
if rowsLoaded != 0 {
t.Errorf("expected rowsLoaded=0, got %d", rowsLoaded)
}
wg.Wait()
}
func TestConsume_ContextCancellation_Exits(t *testing.T) {
db := newMockDb()
gl := newLoader(db)
chIn := make(chan models.Batch)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
var rowsLoaded int64
var failedCount int32
ctx, cancel := context.WithCancel(context.Background())
done := runConsume(ctx, gl, rc(0), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
cancel()
select {
case <-done:
case <-time.After(testTimeout):
t.Fatal("Consume did not exit after context cancellation")
}
wg.Wait()
}

View File

@@ -11,6 +11,58 @@ import (
"github.com/google/uuid"
)
type batchAccumulator struct {
batchSize int
rows []models.UnknownRowValues
parents []models.BatchRef
}
func (a *batchAccumulator) add(batch models.Batch) {
a.rows = append(a.rows, batch.Rows...)
a.parents = append(a.parents, models.BatchRef{Id: batch.Id})
}
func (a *batchAccumulator) ready() bool {
return len(a.rows) >= a.batchSize
}
func (a *batchAccumulator) flush(ctx context.Context, chOut chan<- models.Batch, wg *sync.WaitGroup) bool {
if len(a.rows) == 0 {
return true
}
out := models.Batch{
Id: uuid.New(),
ParentBatches: a.parents,
Rows: a.rows,
}
wg.Add(1)
select {
case chOut <- out:
case <-ctx.Done():
wg.Done()
return false
}
a.rows = nil
a.parents = nil
return true
}
func sendTransformError(ctx context.Context, err error, ch chan<- custom_errors.JobError) {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return
}
var jobErr custom_errors.JobError
if je, ok := errors.AsType[*custom_errors.JobError](err); ok {
jobErr = *je
} else {
jobErr = custom_errors.JobError{ShouldCancelJob: true, Msg: "Transformation failed", Prev: err}
}
select {
case ch <- jobErr:
case <-ctx.Done():
}
}
func (mssqlTr *MssqlTransformer) Consume(
ctx context.Context,
columns []models.ColumnType,
@@ -25,90 +77,40 @@ func (mssqlTr *MssqlTransformer) Consume(
storagePlan := computeStorageTransformationPlan(ctx, mssqlTr.azureClient, mssqlTr.toStorage, columns, mssqlTr.sourceTable)
transformationPlan = append(transformationPlan, storagePlan...)
var accRows []models.UnknownRowValues
var parentBatchesId []uuid.UUID
var firstPartitionId uuid.UUID
flush := func() bool {
if len(accRows) == 0 {
return true
}
out := models.Batch{
Id: uuid.New(),
PartitionId: firstPartitionId,
ParentBatchesId: parentBatchesId,
Rows: accRows,
}
select {
case chBatchesOut <- out:
wgActiveBatches.Add(1)
case <-ctx.Done():
return false
}
accRows = nil
parentBatchesId = nil
firstPartitionId = uuid.Nil
return true
}
acc := &batchAccumulator{batchSize: batchSize}
for {
if ctx.Err() != nil {
return
}
select {
case <-ctx.Done():
return
case batch, ok := <-chBatchesIn:
if !ok {
flush()
acc.flush(ctx, chBatchesOut, wgActiveBatches)
return
}
if len(transformationPlan) > 0 {
err := ProcessBatchWithRetries(ctx, &batch, transformationPlan, retryConfig)
if err != nil {
if errors.Is(err, ctx.Err()) {
return
}
if jobError, ok := errors.AsType[*custom_errors.JobError](err); ok {
select {
case chJobErrorsOut <- *jobError:
case <-ctx.Done():
return
}
} else {
select {
case chJobErrorsOut <- custom_errors.JobError{ShouldCancelJob: true, Msg: "Transformation failed", Prev: err}:
case <-ctx.Done():
return
}
}
if err := ProcessBatchWithRetries(ctx, &batch, transformationPlan, retryConfig); err != nil {
sendTransformError(ctx, err, chJobErrorsOut)
return
}
}
if batchSize <= 0 {
wgActiveBatches.Add(1)
select {
case chBatchesOut <- batch:
wgActiveBatches.Add(1)
case <-ctx.Done():
wgActiveBatches.Done()
return
}
continue
}
if len(parentBatchesId) == 0 {
firstPartitionId = batch.PartitionId
}
accRows = append(accRows, batch.Rows...)
parentBatchesId = append(parentBatchesId, batch.Id)
if len(accRows) >= batchSize {
if !flush() {
acc.add(batch)
if acc.ready() {
if !acc.flush(ctx, chBatchesOut, wgActiveBatches) {
return
}
}

View File

@@ -0,0 +1,545 @@
package transformers
import (
"context"
"errors"
"sync"
"testing"
"time"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
"github.com/google/uuid"
)
const testTimeout = 2 * time.Second
func makeBatch(numRows int) models.Batch {
rows := make([]models.UnknownRowValues, numRows)
for i := range rows {
rows[i] = models.UnknownRowValues{i}
}
return models.Batch{Id: uuid.New(), Rows: rows}
}
func noRetry() config.RetryConfig {
return config.RetryConfig{Attempts: 1}
}
func newTransformer() *MssqlTransformer {
return &MssqlTransformer{}
}
func uuidColumn() models.ColumnType {
return models.NewColumnType("col_uuid", false, false, "uniqueidentifier", "uniqueidentifier", "string", false, 0, 0, 0)
}
func runConsume(
ctx context.Context,
tr *MssqlTransformer,
columns []models.ColumnType,
batchSize int,
chIn <-chan models.Batch,
chOut chan<- models.Batch,
chErr chan<- custom_errors.JobError,
wg *sync.WaitGroup,
) <-chan struct{} {
done := make(chan struct{})
go func() {
tr.Consume(ctx, columns, noRetry(), batchSize, chIn, chOut, chErr, wg)
close(done)
}()
return done
}
func drainOut(chOut <-chan models.Batch, wg *sync.WaitGroup) []models.Batch {
var batches []models.Batch
for {
select {
case b := <-chOut:
batches = append(batches, b)
wg.Done()
default:
return batches
}
}
}
func TestBatchAccumulator_Add(t *testing.T) {
acc := &batchAccumulator{batchSize: 5}
b1 := makeBatch(2)
b2 := makeBatch(3)
acc.add(b1)
acc.add(b2)
if len(acc.rows) != 5 {
t.Errorf("expected 5 rows, got %d", len(acc.rows))
}
if len(acc.parents) != 2 {
t.Fatalf("expected 2 parents, got %d", len(acc.parents))
}
if acc.parents[0].Id != b1.Id || acc.parents[1].Id != b2.Id {
t.Error("parent IDs do not match source batch IDs")
}
}
func TestBatchAccumulator_Ready(t *testing.T) {
acc := &batchAccumulator{batchSize: 3}
acc.add(makeBatch(2))
if acc.ready() {
t.Error("should not be ready with 2 rows and batchSize=3")
}
acc.add(makeBatch(1))
if !acc.ready() {
t.Error("should be ready with 3 rows and batchSize=3")
}
}
func TestBatchAccumulator_Flush_Empty(t *testing.T) {
acc := &batchAccumulator{batchSize: 5}
chOut := make(chan models.Batch, 1)
var wg sync.WaitGroup
if !acc.flush(context.Background(), chOut, &wg) {
t.Error("flush on empty accumulator should return true")
}
if len(chOut) != 0 {
t.Error("flush on empty accumulator should send nothing")
}
}
func TestBatchAccumulator_Flush_Success(t *testing.T) {
acc := &batchAccumulator{batchSize: 2}
b := makeBatch(2)
acc.add(b)
chOut := make(chan models.Batch, 1)
var wg sync.WaitGroup
if !acc.flush(context.Background(), chOut, &wg) {
t.Fatal("flush should return true on success")
}
select {
case out := <-chOut:
wg.Done()
if len(out.Rows) != 2 {
t.Errorf("expected 2 rows in flushed batch, got %d", len(out.Rows))
}
if len(out.ParentBatches) != 1 || out.ParentBatches[0].Id != b.Id {
t.Error("flushed batch should reference the source batch as parent")
}
default:
t.Error("expected a batch in chOut after flush")
}
if len(acc.rows) != 0 || len(acc.parents) != 0 {
t.Error("accumulator state should be reset after flush")
}
wg.Wait()
}
func TestBatchAccumulator_Flush_ContextCancelled(t *testing.T) {
acc := &batchAccumulator{batchSize: 2}
acc.add(makeBatch(2))
chOut := make(chan models.Batch)
var wg sync.WaitGroup
ctx, cancel := context.WithCancel(context.Background())
cancel()
if acc.flush(ctx, chOut, &wg) {
t.Error("flush should return false when context is cancelled")
}
wg.Wait()
}
func TestSendTransformError_PlainError(t *testing.T) {
ch := make(chan custom_errors.JobError, 1)
sendTransformError(context.Background(), errors.New("something broke"), ch)
select {
case e := <-ch:
if !e.ShouldCancelJob {
t.Error("plain error should produce ShouldCancelJob=true")
}
default:
t.Error("expected a job error in the channel")
}
}
func TestSendTransformError_JobError_Passthrough(t *testing.T) {
ch := make(chan custom_errors.JobError, 1)
original := &custom_errors.JobError{ShouldCancelJob: false, Msg: "custom msg"}
sendTransformError(context.Background(), original, ch)
select {
case e := <-ch:
if e.ShouldCancelJob != false || e.Msg != "custom msg" {
t.Errorf("JobError should pass through unchanged, got %+v", e)
}
default:
t.Error("expected a job error in the channel")
}
}
func TestSendTransformError_ContextCancelled_Silent(t *testing.T) {
ch := make(chan custom_errors.JobError, 1)
ctx, cancel := context.WithCancel(context.Background())
cancel()
sendTransformError(ctx, context.Canceled, ch)
if len(ch) != 0 {
t.Error("context.Canceled should be silently dropped")
}
}
func TestSendTransformError_DeadlineExceeded_Silent(t *testing.T) {
ch := make(chan custom_errors.JobError, 1)
ctx, cancel := context.WithCancel(context.Background())
cancel()
sendTransformError(ctx, context.DeadlineExceeded, ch)
if len(ch) != 0 {
t.Error("context.DeadlineExceeded should be silently dropped")
}
}
func TestConsume_Passthrough_PreservesOriginalBatch(t *testing.T) {
tr := newTransformer()
chIn := make(chan models.Batch, 1)
chOut := make(chan models.Batch, 1)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
batch := makeBatch(3)
chIn <- batch
close(chIn)
done := runConsume(context.Background(), tr, nil, 0, chIn, chOut, chErr, &wg)
select {
case got := <-chOut:
wg.Done()
if got.Id != batch.Id {
t.Error("passthrough should preserve the original batch ID")
}
if len(got.Rows) != 3 {
t.Errorf("expected 3 rows, got %d", len(got.Rows))
}
case <-time.After(testTimeout):
t.Fatal("timeout waiting for output batch")
}
<-done
wg.Wait()
}
func TestConsume_Passthrough_WaitGroupBalanced(t *testing.T) {
tr := newTransformer()
chIn := make(chan models.Batch, 3)
chOut := make(chan models.Batch, 3)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
for range 3 {
chIn <- makeBatch(1)
}
close(chIn)
done := runConsume(context.Background(), tr, nil, 0, chIn, chOut, chErr, &wg)
<-done
batches := drainOut(chOut, &wg)
if len(batches) != 3 {
t.Errorf("expected 3 output batches, got %d", len(batches))
}
wg.Wait()
}
func TestConsume_Accumulation_FlushOnThreshold(t *testing.T) {
tr := newTransformer()
chIn := make(chan models.Batch, 3)
chOut := make(chan models.Batch, 2)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
for range 3 {
chIn <- makeBatch(1)
}
close(chIn)
done := runConsume(context.Background(), tr, nil, 3, chIn, chOut, chErr, &wg)
<-done
batches := drainOut(chOut, &wg)
if len(batches) != 1 {
t.Fatalf("expected 1 accumulated batch, got %d", len(batches))
}
if len(batches[0].Rows) != 3 {
t.Errorf("expected 3 rows in accumulated batch, got %d", len(batches[0].Rows))
}
wg.Wait()
}
func TestConsume_Accumulation_FlushOnClose(t *testing.T) {
tr := newTransformer()
chIn := make(chan models.Batch, 2)
chOut := make(chan models.Batch, 2)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
chIn <- makeBatch(1)
chIn <- makeBatch(1)
close(chIn)
done := runConsume(context.Background(), tr, nil, 10, chIn, chOut, chErr, &wg)
<-done
batches := drainOut(chOut, &wg)
if len(batches) != 1 {
t.Fatalf("expected 1 batch flushed on close, got %d", len(batches))
}
if len(batches[0].Rows) != 2 {
t.Errorf("expected 2 rows, got %d", len(batches[0].Rows))
}
wg.Wait()
}
func TestConsume_Accumulation_TracksAllParentBatches(t *testing.T) {
tr := newTransformer()
chIn := make(chan models.Batch, 2)
chOut := make(chan models.Batch, 2)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
b1 := makeBatch(1)
b2 := makeBatch(1)
chIn <- b1
chIn <- b2
close(chIn)
done := runConsume(context.Background(), tr, nil, 10, chIn, chOut, chErr, &wg)
<-done
batches := drainOut(chOut, &wg)
if len(batches) != 1 {
t.Fatalf("expected 1 output batch, got %d", len(batches))
}
parents := batches[0].ParentBatches
if len(parents) != 2 {
t.Fatalf("expected 2 parent refs, got %d", len(parents))
}
if parents[0].Id != b1.Id || parents[1].Id != b2.Id {
t.Error("parent IDs should match source batch IDs in order")
}
wg.Wait()
}
func TestConsume_Accumulation_MultipleFlushes(t *testing.T) {
tr := newTransformer()
chIn := make(chan models.Batch, 5)
chOut := make(chan models.Batch, 5)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
for range 5 {
chIn <- makeBatch(1)
}
close(chIn)
done := runConsume(context.Background(), tr, nil, 2, chIn, chOut, chErr, &wg)
<-done
batches := drainOut(chOut, &wg)
if len(batches) != 3 {
t.Fatalf("expected 3 output batches (2+2+1 rows), got %d", len(batches))
}
totalRows := 0
for _, b := range batches {
totalRows += len(b.Rows)
}
if totalRows != 5 {
t.Errorf("expected 5 total rows across all batches, got %d", totalRows)
}
wg.Wait()
}
func TestConsume_EmptyInput_NoOutput(t *testing.T) {
tr := newTransformer()
chIn := make(chan models.Batch)
chOut := make(chan models.Batch, 1)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
close(chIn)
done := runConsume(context.Background(), tr, nil, 5, chIn, chOut, chErr, &wg)
select {
case <-done:
case <-time.After(testTimeout):
t.Fatal("timeout: Consume did not exit after empty input channel was closed")
}
if len(chOut) != 0 {
t.Error("expected no output for empty input")
}
wg.Wait()
}
func TestConsume_TransformError_SendsJobError(t *testing.T) {
tr := newTransformer()
col := uuidColumn()
chIn := make(chan models.Batch, 1)
chOut := make(chan models.Batch, 1)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
batch := models.Batch{
Id: uuid.New(),
Rows: []models.UnknownRowValues{{[]byte{1, 2, 3}}},
}
chIn <- batch
done := runConsume(context.Background(), tr, []models.ColumnType{col}, 0, chIn, chOut, chErr, &wg)
select {
case err := <-chErr:
if !err.ShouldCancelJob {
t.Error("transform error should set ShouldCancelJob=true")
}
case <-time.After(testTimeout):
t.Fatal("timeout: expected a job error from transform failure")
}
<-done
wg.Wait()
}
func TestConsume_TransformError_NoOutputForwarded(t *testing.T) {
tr := newTransformer()
col := uuidColumn()
chIn := make(chan models.Batch, 1)
chOut := make(chan models.Batch, 1)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
batch := models.Batch{
Id: uuid.New(),
Rows: []models.UnknownRowValues{{[]byte{1, 2, 3}}},
}
chIn <- batch
done := runConsume(context.Background(), tr, []models.ColumnType{col}, 0, chIn, chOut, chErr, &wg)
<-done
if len(chOut) != 0 {
t.Error("no batch should be forwarded when transformation fails")
}
wg.Wait()
}
func TestConsume_ContextCancellation_Exits(t *testing.T) {
tr := newTransformer()
chIn := make(chan models.Batch)
chOut := make(chan models.Batch, 1)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
ctx, cancel := context.WithCancel(context.Background())
done := runConsume(ctx, tr, nil, 0, chIn, chOut, chErr, &wg)
cancel()
select {
case <-done:
case <-time.After(testTimeout):
t.Fatal("timeout: Consume did not exit after context cancellation")
}
wg.Wait()
}
func TestConsume_Transform_DatetimeConvertedToUTC(t *testing.T) {
tr := newTransformer()
col := models.NewColumnType("col_dt", false, false, "datetime", "datetime", "timestamp", false, 0, 0, 0)
chIn := make(chan models.Batch, 1)
chOut := make(chan models.Batch, 1)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
nonUTC := time.Date(2024, 1, 15, 12, 0, 0, 0, time.FixedZone("EST", -5*3600))
batch := models.Batch{
Id: uuid.New(),
Rows: []models.UnknownRowValues{{nonUTC}},
}
chIn <- batch
close(chIn)
done := runConsume(context.Background(), tr, []models.ColumnType{col}, 0, chIn, chOut, chErr, &wg)
<-done
select {
case got := <-chOut:
wg.Done()
result, ok := got.Rows[0][0].(time.Time)
if !ok {
t.Fatal("expected time.Time in output row")
}
if result.Location() != time.UTC {
t.Errorf("expected UTC location after transform, got %v", result.Location())
}
default:
t.Error("expected an output batch")
}
wg.Wait()
}
func TestConsume_Transform_NilValueSkipped(t *testing.T) {
tr := newTransformer()
col := uuidColumn()
chIn := make(chan models.Batch, 1)
chOut := make(chan models.Batch, 1)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
batch := models.Batch{
Id: uuid.New(),
Rows: []models.UnknownRowValues{{nil}},
}
chIn <- batch
close(chIn)
done := runConsume(context.Background(), tr, []models.ColumnType{col}, 0, chIn, chOut, chErr, &wg)
<-done
select {
case got := <-chOut:
wg.Done()
if got.Rows[0][0] != nil {
t.Error("nil value should pass through unchanged")
}
default:
t.Error("expected an output batch even when value is nil")
}
if len(chErr) != 0 {
t.Error("nil value should not produce an error")
}
wg.Wait()
}

View File

@@ -8,12 +8,16 @@ import (
type UnknownRowValues = []any
type BatchRef struct {
Id uuid.UUID
PartitionId uuid.UUID
}
type Batch struct {
Id uuid.UUID
PartitionId uuid.UUID
ParentBatchesId []uuid.UUID
Rows []UnknownRowValues
RetryCounter int
Id uuid.UUID
ParentBatches []BatchRef
Rows []UnknownRowValues
RetryCounter int
}
type PartitionRange struct {