refactor: add unit tests for loaderAccumulator and consume functions; enhance error handling and batch processing logic
This commit is contained in:
603
internal/app/etl/loaders/consume_test.go
Normal file
603
internal/app/etl/loaders/consume_test.go
Normal file
@@ -0,0 +1,603 @@
|
||||
package loaders
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
|
||||
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
|
||||
dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper"
|
||||
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const testTimeout = 2 * time.Second
|
||||
|
||||
type mockResult struct {
|
||||
err error
|
||||
}
|
||||
|
||||
type mockDbWrapper struct {
|
||||
mu sync.Mutex
|
||||
callCount int
|
||||
results []mockResult
|
||||
}
|
||||
|
||||
func newMockDb(results ...mockResult) *mockDbWrapper {
|
||||
return &mockDbWrapper{results: results}
|
||||
}
|
||||
|
||||
func (m *mockDbWrapper) SaveMassive(_ context.Context, _ string, _ string, _ []string, rows [][]any) (int64, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
idx := m.callCount
|
||||
m.callCount++
|
||||
if idx < len(m.results) && m.results[idx].err != nil {
|
||||
return 0, m.results[idx].err
|
||||
}
|
||||
return int64(len(rows)), nil
|
||||
}
|
||||
|
||||
func (m *mockDbWrapper) Close() error { return nil }
|
||||
func (m *mockDbWrapper) Connect(_ context.Context, _ string) error { return nil }
|
||||
func (m *mockDbWrapper) Exec(_ context.Context, _ string, _ ...any) (dbwrapper.ExecResult, error) {
|
||||
return dbwrapper.ExecResult{}, nil
|
||||
}
|
||||
func (m *mockDbWrapper) GetDialect() string { return "" }
|
||||
func (m *mockDbWrapper) Query(_ context.Context, _ string, _ ...any) (dbwrapper.RowsResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockDbWrapper) QueryRow(_ context.Context, _ string, _ ...any) dbwrapper.RowResult {
|
||||
return nil
|
||||
}
|
||||
func (m *mockDbWrapper) QueryFromObject(_ context.Context, _ dbwrapper.ExtractionQuery) (dbwrapper.RowsResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func makeBatch(numRows int) models.Batch {
|
||||
rows := make([]models.UnknownRowValues, numRows)
|
||||
for i := range rows {
|
||||
rows[i] = models.UnknownRowValues{i}
|
||||
}
|
||||
return models.Batch{Id: uuid.New(), Rows: rows}
|
||||
}
|
||||
|
||||
func newLoader(db *mockDbWrapper) GenericLoader {
|
||||
return GenericLoader{db: db}
|
||||
}
|
||||
|
||||
func rc(maxFailed int) config.RetryConfig {
|
||||
return config.RetryConfig{Attempts: 1, MaxFailedBatchesLoad: maxFailed}
|
||||
}
|
||||
|
||||
func sendBatch(chIn chan<- models.Batch, batch models.Batch, wg *sync.WaitGroup) {
|
||||
wg.Add(1)
|
||||
chIn <- batch
|
||||
}
|
||||
|
||||
func runConsume(
|
||||
ctx context.Context,
|
||||
gl GenericLoader,
|
||||
retryConfig config.RetryConfig,
|
||||
batchSize int,
|
||||
chIn <-chan models.Batch,
|
||||
chErr chan<- custom_errors.JobError,
|
||||
wg *sync.WaitGroup,
|
||||
rowsLoaded *int64,
|
||||
failedCount *int32,
|
||||
) <-chan struct{} {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
gl.Consume(ctx, config.TargetTableInfo{}, nil, retryConfig, batchSize,
|
||||
chIn, chErr, wg, rowsLoaded, failedCount)
|
||||
close(done)
|
||||
}()
|
||||
return done
|
||||
}
|
||||
|
||||
func waitWg(wg *sync.WaitGroup) <-chan struct{} {
|
||||
done := make(chan struct{})
|
||||
go func() { wg.Wait(); close(done) }()
|
||||
return done
|
||||
}
|
||||
|
||||
func dbError() error { return errors.New("connection reset by peer") }
|
||||
|
||||
func TestLoaderAccumulator_Add(t *testing.T) {
|
||||
acc := &loaderAccumulator{batchSize: 5}
|
||||
b1 := makeBatch(2)
|
||||
b2 := makeBatch(3)
|
||||
|
||||
acc.add(b1)
|
||||
acc.add(b2)
|
||||
|
||||
if len(acc.rows) != 5 {
|
||||
t.Errorf("expected 5 rows, got %d", len(acc.rows))
|
||||
}
|
||||
if len(acc.parents) != 2 {
|
||||
t.Fatalf("expected 2 parents, got %d", len(acc.parents))
|
||||
}
|
||||
if acc.parents[0].Id != b1.Id || acc.parents[1].Id != b2.Id {
|
||||
t.Error("parent IDs do not match source batch IDs in order")
|
||||
}
|
||||
if acc.pendingDone != 2 {
|
||||
t.Errorf("expected pendingDone=2, got %d", acc.pendingDone)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoaderAccumulator_Ready(t *testing.T) {
|
||||
acc := &loaderAccumulator{batchSize: 3}
|
||||
acc.add(makeBatch(2))
|
||||
if acc.ready() {
|
||||
t.Error("should not be ready with 2 rows and batchSize=3")
|
||||
}
|
||||
acc.add(makeBatch(1))
|
||||
if !acc.ready() {
|
||||
t.Error("should be ready with 3 rows and batchSize=3")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoaderAccumulator_DrainPending_ReleasesWg(t *testing.T) {
|
||||
acc := &loaderAccumulator{batchSize: 5, pendingDone: 3}
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(3)
|
||||
|
||||
acc.drainPending(&wg)
|
||||
|
||||
select {
|
||||
case <-waitWg(&wg):
|
||||
case <-time.After(testTimeout):
|
||||
t.Fatal("wg.Wait() timed out: drainPending did not call Done() enough times")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoaderAccumulator_DrainPending_ZeroPending(t *testing.T) {
|
||||
acc := &loaderAccumulator{batchSize: 5, pendingDone: 0}
|
||||
var wg sync.WaitGroup
|
||||
|
||||
acc.drainPending(&wg)
|
||||
|
||||
select {
|
||||
case <-waitWg(&wg):
|
||||
case <-time.After(testTimeout):
|
||||
t.Fatal("wg.Wait() timed out")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendLoadError_PlainError_WrappedAsNonFatal(t *testing.T) {
|
||||
ch := make(chan custom_errors.JobError, 2)
|
||||
var failedCount int32
|
||||
|
||||
result := sendLoadError(context.Background(), errors.New("db error"), rc(10), &failedCount, ch)
|
||||
|
||||
if !result {
|
||||
t.Error("expected true (below threshold)")
|
||||
}
|
||||
if atomic.LoadInt32(&failedCount) != 1 {
|
||||
t.Errorf("expected failedCount=1, got %d", failedCount)
|
||||
}
|
||||
select {
|
||||
case e := <-ch:
|
||||
if e.ShouldCancelJob {
|
||||
t.Error("plain error should be wrapped as ShouldCancelJob=false")
|
||||
}
|
||||
default:
|
||||
t.Error("expected an error in the channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendLoadError_JobError_PassesThrough(t *testing.T) {
|
||||
ch := make(chan custom_errors.JobError, 2)
|
||||
var failedCount int32
|
||||
original := &custom_errors.JobError{ShouldCancelJob: false, Msg: "custom msg"}
|
||||
|
||||
sendLoadError(context.Background(), original, rc(10), &failedCount, ch)
|
||||
|
||||
select {
|
||||
case e := <-ch:
|
||||
if e.Msg != "custom msg" || e.ShouldCancelJob {
|
||||
t.Errorf("JobError should pass through unchanged, got %+v", e)
|
||||
}
|
||||
default:
|
||||
t.Error("expected an error in the channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendLoadError_FatalJobError_BelowThreshold_ReturnsTrue(t *testing.T) {
|
||||
ch := make(chan custom_errors.JobError, 2)
|
||||
var failedCount int32
|
||||
fatal := &custom_errors.JobError{ShouldCancelJob: true, Msg: "unique constraint"}
|
||||
|
||||
result := sendLoadError(context.Background(), fatal, rc(10), &failedCount, ch)
|
||||
|
||||
if !result {
|
||||
t.Error("below-threshold fatal error should return true (external cancel expected from JobErrorHandler)")
|
||||
}
|
||||
select {
|
||||
case e := <-ch:
|
||||
if !e.ShouldCancelJob {
|
||||
t.Error("fatal JobError should be forwarded with ShouldCancelJob=true")
|
||||
}
|
||||
default:
|
||||
t.Error("expected the fatal error in the channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendLoadError_ThresholdExceeded_ReturnsFalse(t *testing.T) {
|
||||
ch := make(chan custom_errors.JobError, 2)
|
||||
var failedCount int32
|
||||
|
||||
result := sendLoadError(context.Background(), errors.New("db error"), rc(0), &failedCount, ch)
|
||||
|
||||
if result {
|
||||
t.Error("expected false when threshold exceeded")
|
||||
}
|
||||
if len(ch) != 2 {
|
||||
t.Fatalf("expected 2 errors (batch error + fatal threshold error), got %d", len(ch))
|
||||
}
|
||||
<-ch // batch error
|
||||
threshold := <-ch
|
||||
if !threshold.ShouldCancelJob {
|
||||
t.Error("second error should be the fatal threshold error (ShouldCancelJob=true)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendLoadError_AtThresholdBoundary(t *testing.T) {
|
||||
ch := make(chan custom_errors.JobError, 6)
|
||||
var failedCount int32
|
||||
|
||||
if !sendLoadError(context.Background(), errors.New("err"), rc(2), &failedCount, ch) {
|
||||
t.Error("first failure: expected true (below threshold)")
|
||||
}
|
||||
if !sendLoadError(context.Background(), errors.New("err"), rc(2), &failedCount, ch) {
|
||||
t.Error("second failure: expected true (at threshold, not exceeded)")
|
||||
}
|
||||
if sendLoadError(context.Background(), errors.New("err"), rc(2), &failedCount, ch) {
|
||||
t.Error("third failure: expected false (threshold exceeded)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendLoadError_ContextCancelled_ReturnsFalse(t *testing.T) {
|
||||
ch := make(chan custom_errors.JobError)
|
||||
var failedCount int32
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
result := sendLoadError(ctx, errors.New("db error"), rc(10), &failedCount, ch)
|
||||
|
||||
if result {
|
||||
t.Error("expected false when context is cancelled")
|
||||
}
|
||||
if len(ch) != 0 {
|
||||
t.Error("no error should be sent when context is cancelled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsume_Passthrough_RowsLoaded(t *testing.T) {
|
||||
db := newMockDb()
|
||||
gl := newLoader(db)
|
||||
chIn := make(chan models.Batch, 1)
|
||||
chErr := make(chan custom_errors.JobError, 1)
|
||||
var wg sync.WaitGroup
|
||||
var rowsLoaded int64
|
||||
var failedCount int32
|
||||
|
||||
sendBatch(chIn, makeBatch(5), &wg)
|
||||
close(chIn)
|
||||
|
||||
<-runConsume(context.Background(), gl, rc(0), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
|
||||
wg.Wait()
|
||||
|
||||
if rowsLoaded != 5 {
|
||||
t.Errorf("expected rowsLoaded=5, got %d", rowsLoaded)
|
||||
}
|
||||
if db.callCount != 1 {
|
||||
t.Errorf("expected 1 SaveMassive call, got %d", db.callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsume_Passthrough_MultipleBatches_RowsAccumulate(t *testing.T) {
|
||||
db := newMockDb()
|
||||
gl := newLoader(db)
|
||||
chIn := make(chan models.Batch, 3)
|
||||
chErr := make(chan custom_errors.JobError, 1)
|
||||
var wg sync.WaitGroup
|
||||
var rowsLoaded int64
|
||||
var failedCount int32
|
||||
|
||||
sendBatch(chIn, makeBatch(3), &wg)
|
||||
sendBatch(chIn, makeBatch(2), &wg)
|
||||
sendBatch(chIn, makeBatch(4), &wg)
|
||||
close(chIn)
|
||||
|
||||
<-runConsume(context.Background(), gl, rc(10), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
|
||||
wg.Wait()
|
||||
|
||||
if rowsLoaded != 9 {
|
||||
t.Errorf("expected rowsLoaded=9, got %d", rowsLoaded)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsume_Passthrough_WgDoneBeforeErrorHandling(t *testing.T) {
|
||||
db := newMockDb(mockResult{err: dbError()})
|
||||
gl := newLoader(db)
|
||||
chIn := make(chan models.Batch, 1)
|
||||
chErr := make(chan custom_errors.JobError, 2)
|
||||
var wg sync.WaitGroup
|
||||
var rowsLoaded int64
|
||||
var failedCount int32
|
||||
|
||||
sendBatch(chIn, makeBatch(2), &wg)
|
||||
close(chIn)
|
||||
|
||||
<-runConsume(context.Background(), gl, rc(10), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
|
||||
|
||||
select {
|
||||
case <-waitWg(&wg):
|
||||
case <-time.After(testTimeout):
|
||||
t.Fatal("wg.Wait() timed out: Done() was not called even though processing failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsume_Passthrough_NonFatalError_Continues(t *testing.T) {
|
||||
db := newMockDb(mockResult{err: dbError()})
|
||||
gl := newLoader(db)
|
||||
chIn := make(chan models.Batch, 2)
|
||||
chErr := make(chan custom_errors.JobError, 3)
|
||||
var wg sync.WaitGroup
|
||||
var rowsLoaded int64
|
||||
var failedCount int32
|
||||
|
||||
sendBatch(chIn, makeBatch(2), &wg)
|
||||
sendBatch(chIn, makeBatch(3), &wg)
|
||||
close(chIn)
|
||||
|
||||
<-runConsume(context.Background(), gl, rc(10), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
|
||||
wg.Wait()
|
||||
|
||||
if rowsLoaded != 3 {
|
||||
t.Errorf("expected rowsLoaded=3 (only second batch succeeded), got %d", rowsLoaded)
|
||||
}
|
||||
if atomic.LoadInt32(&failedCount) != 1 {
|
||||
t.Errorf("expected failedCount=1, got %d", failedCount)
|
||||
}
|
||||
if len(chErr) == 0 {
|
||||
t.Error("expected at least one error in chErr for the failed batch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsume_Passthrough_ThresholdExceeded_Exits(t *testing.T) {
|
||||
db := newMockDb(mockResult{err: dbError()})
|
||||
gl := newLoader(db)
|
||||
chIn := make(chan models.Batch, 1)
|
||||
chErr := make(chan custom_errors.JobError, 3)
|
||||
var wg sync.WaitGroup
|
||||
var rowsLoaded int64
|
||||
var failedCount int32
|
||||
|
||||
sendBatch(chIn, makeBatch(1), &wg)
|
||||
|
||||
done := runConsume(context.Background(), gl, rc(0), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(testTimeout):
|
||||
t.Fatal("Consume did not exit after threshold exceeded")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-waitWg(&wg):
|
||||
case <-time.After(testTimeout):
|
||||
t.Fatal("wg.Wait() timed out after threshold exit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsume_Accumulation_FlushOnThreshold(t *testing.T) {
|
||||
db := newMockDb()
|
||||
gl := newLoader(db)
|
||||
chIn := make(chan models.Batch, 3)
|
||||
chErr := make(chan custom_errors.JobError, 1)
|
||||
var wg sync.WaitGroup
|
||||
var rowsLoaded int64
|
||||
var failedCount int32
|
||||
|
||||
sendBatch(chIn, makeBatch(1), &wg)
|
||||
sendBatch(chIn, makeBatch(1), &wg)
|
||||
sendBatch(chIn, makeBatch(1), &wg)
|
||||
close(chIn)
|
||||
|
||||
<-runConsume(context.Background(), gl, rc(0), 3, chIn, chErr, &wg, &rowsLoaded, &failedCount)
|
||||
wg.Wait()
|
||||
|
||||
if rowsLoaded != 3 {
|
||||
t.Errorf("expected rowsLoaded=3, got %d", rowsLoaded)
|
||||
}
|
||||
if db.callCount != 1 {
|
||||
t.Errorf("expected 1 SaveMassive call, got %d", db.callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsume_Accumulation_FlushOnClose(t *testing.T) {
|
||||
db := newMockDb()
|
||||
gl := newLoader(db)
|
||||
chIn := make(chan models.Batch, 2)
|
||||
chErr := make(chan custom_errors.JobError, 1)
|
||||
var wg sync.WaitGroup
|
||||
var rowsLoaded int64
|
||||
var failedCount int32
|
||||
|
||||
sendBatch(chIn, makeBatch(2), &wg)
|
||||
sendBatch(chIn, makeBatch(3), &wg)
|
||||
close(chIn)
|
||||
|
||||
<-runConsume(context.Background(), gl, rc(0), 10, chIn, chErr, &wg, &rowsLoaded, &failedCount)
|
||||
wg.Wait()
|
||||
|
||||
if rowsLoaded != 5 {
|
||||
t.Errorf("expected rowsLoaded=5, got %d", rowsLoaded)
|
||||
}
|
||||
if db.callCount != 1 {
|
||||
t.Errorf("expected exactly 1 SaveMassive call (single flush on close), got %d", db.callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsume_Accumulation_RowsLoadedCorrect(t *testing.T) {
|
||||
db := newMockDb()
|
||||
gl := newLoader(db)
|
||||
chIn := make(chan models.Batch, 5)
|
||||
chErr := make(chan custom_errors.JobError, 1)
|
||||
var wg sync.WaitGroup
|
||||
var rowsLoaded int64
|
||||
var failedCount int32
|
||||
|
||||
for range 5 {
|
||||
sendBatch(chIn, makeBatch(2), &wg)
|
||||
}
|
||||
close(chIn)
|
||||
|
||||
<-runConsume(context.Background(), gl, rc(0), 4, chIn, chErr, &wg, &rowsLoaded, &failedCount)
|
||||
wg.Wait()
|
||||
|
||||
if rowsLoaded != 10 {
|
||||
t.Errorf("expected rowsLoaded=10, got %d", rowsLoaded)
|
||||
}
|
||||
if db.callCount != 3 {
|
||||
t.Errorf("expected 3 SaveMassive calls (2 threshold flushes + 1 on close), got %d", db.callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsume_Accumulation_WgBalanced_OnContextCancel(t *testing.T) {
|
||||
db := newMockDb()
|
||||
gl := newLoader(db)
|
||||
chIn := make(chan models.Batch)
|
||||
chErr := make(chan custom_errors.JobError, 1)
|
||||
var wg sync.WaitGroup
|
||||
var rowsLoaded int64
|
||||
var failedCount int32
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := runConsume(ctx, gl, rc(0), 10, chIn, chErr, &wg, &rowsLoaded, &failedCount)
|
||||
|
||||
sendBatch(chIn, makeBatch(1), &wg)
|
||||
sendBatch(chIn, makeBatch(1), &wg)
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(testTimeout):
|
||||
t.Fatal("Consume did not exit after context cancellation")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-waitWg(&wg):
|
||||
case <-time.After(testTimeout):
|
||||
t.Fatal("wg.Wait() timed out: drainPending did not release accumulated batches on cancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsume_Accumulation_ErrorInFlush_WgStillBalanced(t *testing.T) {
|
||||
db := newMockDb(mockResult{err: dbError()})
|
||||
gl := newLoader(db)
|
||||
chIn := make(chan models.Batch, 2)
|
||||
chErr := make(chan custom_errors.JobError, 3)
|
||||
var wg sync.WaitGroup
|
||||
var rowsLoaded int64
|
||||
var failedCount int32
|
||||
|
||||
sendBatch(chIn, makeBatch(1), &wg)
|
||||
sendBatch(chIn, makeBatch(1), &wg)
|
||||
close(chIn)
|
||||
|
||||
<-runConsume(context.Background(), gl, rc(10), 2, chIn, chErr, &wg, &rowsLoaded, &failedCount)
|
||||
|
||||
select {
|
||||
case <-waitWg(&wg):
|
||||
case <-time.After(testTimeout):
|
||||
t.Fatal("wg.Wait() timed out: wg.Done() not called after flush error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsume_Accumulation_MultipleFlushes_NonFatalErrors(t *testing.T) {
|
||||
db := newMockDb(mockResult{err: dbError()}, mockResult{err: dbError()})
|
||||
gl := newLoader(db)
|
||||
chIn := make(chan models.Batch, 4)
|
||||
chErr := make(chan custom_errors.JobError, 6)
|
||||
var wg sync.WaitGroup
|
||||
var rowsLoaded int64
|
||||
var failedCount int32
|
||||
|
||||
for range 4 {
|
||||
sendBatch(chIn, makeBatch(1), &wg)
|
||||
}
|
||||
close(chIn)
|
||||
|
||||
<-runConsume(context.Background(), gl, rc(10), 2, chIn, chErr, &wg, &rowsLoaded, &failedCount)
|
||||
|
||||
select {
|
||||
case <-waitWg(&wg):
|
||||
case <-time.After(testTimeout):
|
||||
t.Fatal("wg.Wait() timed out")
|
||||
}
|
||||
|
||||
if atomic.LoadInt32(&failedCount) != 2 {
|
||||
t.Errorf("expected failedCount=2, got %d", failedCount)
|
||||
}
|
||||
if rowsLoaded != 0 {
|
||||
t.Errorf("expected rowsLoaded=0 (all batches failed), got %d", rowsLoaded)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsume_EmptyInput_NoProcessing(t *testing.T) {
|
||||
db := newMockDb()
|
||||
gl := newLoader(db)
|
||||
chIn := make(chan models.Batch)
|
||||
chErr := make(chan custom_errors.JobError, 1)
|
||||
var wg sync.WaitGroup
|
||||
var rowsLoaded int64
|
||||
var failedCount int32
|
||||
|
||||
close(chIn)
|
||||
|
||||
done := runConsume(context.Background(), gl, rc(0), 5, chIn, chErr, &wg, &rowsLoaded, &failedCount)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(testTimeout):
|
||||
t.Fatal("Consume did not exit after empty input channel was closed")
|
||||
}
|
||||
|
||||
if db.callCount != 0 {
|
||||
t.Errorf("expected no SaveMassive calls, got %d", db.callCount)
|
||||
}
|
||||
if rowsLoaded != 0 {
|
||||
t.Errorf("expected rowsLoaded=0, got %d", rowsLoaded)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestConsume_ContextCancellation_Exits(t *testing.T) {
|
||||
db := newMockDb()
|
||||
gl := newLoader(db)
|
||||
chIn := make(chan models.Batch)
|
||||
chErr := make(chan custom_errors.JobError, 1)
|
||||
var wg sync.WaitGroup
|
||||
var rowsLoaded int64
|
||||
var failedCount int32
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := runConsume(ctx, gl, rc(0), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
|
||||
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(testTimeout):
|
||||
t.Fatal("Consume did not exit after context cancellation")
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
Reference in New Issue
Block a user