Files
go-migrate/cmd/go_migrate/process.go

260 lines
6.8 KiB
Go

package main
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/azure"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl/extractors"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl/loaders"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl/table_analyzers"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl/transformers"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
)
// jobErrorsChannelSize is the buffer capacity of the per-job error channel
// created in processMigrationJob.
const jobErrorsChannelSize int = 100
// buildTruncateQuery returns the SQL statement used to empty a target table.
//
// A truncateMethod of exactly "DELETE" produces a DELETE FROM statement; any
// other value produces TRUNCATE TABLE. A targetDbType of exactly "postgres"
// wraps the identifiers in double quotes; any other value wraps them in
// square brackets.
//
// schema and table are interpolated verbatim: no escaping of quote or bracket
// characters is performed.
func buildTruncateQuery(targetDbType, schema, table, truncateMethod string) string {
	verb := "TRUNCATE TABLE"
	if truncateMethod == "DELETE" {
		verb = "DELETE FROM"
	}

	switch targetDbType {
	case "postgres":
		return fmt.Sprintf(`%s "%s"."%s"`, verb, schema, table)
	default:
		return fmt.Sprintf(`%s [%s].[%s]`, verb, schema, table)
	}
}
// processMigrationJob runs a single migration job end to end and reports the
// outcome as a models.JobResult. Failures are reported through the Error
// field of the result; the function has no separate error return.
//
// Steps, in order:
//  1. Query the column types of the source and target tables concurrently.
//  2. Execute the target table's PreSQL statements on targetDbWrapper,
//     preceded by a truncate/delete statement when job.TruncateTarget is set.
//  3. Compute the source partitions and stream them through the extractor,
//     transformer and loader worker pools.
//  4. After the pipeline has drained, compare rows read with rows loaded and,
//     if the job is still error-free, execute the target table's PostSQL
//     statements.
//
// sourceDbType selects the transformer: "postgres" uses the Postgres
// transformer, any other value uses the MSSQL transformer (the only consumer
// of azureClient). targetDbType is only used to build the truncate statement.
//
// PostSQL runs on the parent ctx, because the job-local context is already
// cancelled once the pipeline has finished, and it is skipped when the job
// failed or was cancelled.
func processMigrationJob(
	ctx context.Context,
	targetDbWrapper dbwrapper.DbWrapper,
	sourceTableAnalyzer etl.TableAnalyzer,
	targetTableAnalyzer etl.TableAnalyzer,
	extractor extractors.GenericExtractor,
	azureClient *azure.Client,
	loader loaders.GenericLoader,
	job config.Job,
	sourceDbType string,
	targetDbType string,
) models.JobResult {
	var transformer etl.Transformer
	if sourceDbType == "postgres" {
		transformer = transformers.NewPostgresTransformer(job.SourceTable)
	} else {
		transformer = transformers.NewMssqlTransformer(job.ToStorage, job.SourceTable, azureClient)
	}

	// localCtx scopes every goroutine started for this job; it is cancelled
	// both on failure and on normal completion.
	localCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	result := models.JobResult{
		JobName:   job.Name,
		StartTime: time.Now(),
	}

	// Source and target column metadata are independent, so fetch them in
	// parallel.
	var wgQueryColumnTypes errgroup.Group
	var sourceColTypes, targetColTypes []models.ColumnType
	wgQueryColumnTypes.Go(func() error {
		var err error
		sourceColTypes, err = sourceTableAnalyzer.QueryColumnTypes(localCtx, job.SourceTable.TableInfo)
		return err
	})
	wgQueryColumnTypes.Go(func() error {
		var err error
		targetColTypes, err = targetTableAnalyzer.QueryColumnTypes(localCtx, job.TargetTable.TableInfo)
		return err
	})
	if err := wgQueryColumnTypes.Wait(); err != nil {
		result.Error = err
		return result
	}

	// The truncate statement, when requested, runs before any user PreSQL.
	preSqlQueries := job.TargetTable.PreSQL
	if job.TruncateTarget {
		truncateQuery := buildTruncateQuery(targetDbType, job.TargetTable.Schema, job.TargetTable.Table, job.TruncateMethod)
		preSqlQueries = append([]string{truncateQuery}, job.TargetTable.PreSQL...)
	}
	for _, query := range preSqlQueries {
		if _, err := targetDbWrapper.Exec(localCtx, query); err != nil {
			result.Error = err
			return result
		}
	}

	partitions, err := table_analyzers.PartitionRangeGenerator(
		localCtx,
		sourceTableAnalyzer,
		job.SourceTable.TableInfo,
		job.SourceTable.PrimaryKey,
		job.PartitionCalculationStrategy,
		job.RowsPerPartition,
		job.Range,
	)
	if err != nil {
		// Without partitions no extractor is started and the job would
		// otherwise finish "successfully" with zero rows, so fail here.
		result.Error = fmt.Errorf("calculating partition ranges: %w", err)
		return result
	}

	chJobErrors := make(chan custom_errors.JobError, jobErrorsChannelSize)
	chPartitions := make(chan models.Partition)
	chBatchesRaw := make(chan models.Batch, job.ExtractorQueueSize)
	chBatchesTransformed := make(chan models.Batch, job.TransformerQueueSize)

	var wgActivePartitions, wgActiveBatches, wgExtractors, wgTransformers, wgLoaders sync.WaitGroup
	// NOTE(review): rowsFailed is never handed to any worker in this function,
	// so result.RowsFailed is always 0. failedPartitionsCount and
	// failedBatchesLoadCount are updated by the workers but not reported.
	var rowsRead, rowsLoaded, rowsFailed int64
	var failedPartitionsCount, failedBatchesLoadCount int32

	// chFatal carries at most one fatal error from the error handler. The
	// value is sent before cancel(), so whoever wakes up on localCtx.Done()
	// is guaranteed to observe it without racing on result.
	chFatal := make(chan error, 1)
	go func() {
		if err := custom_errors.JobErrorHandler(localCtx, chJobErrors); err != nil {
			log.Error("Fatal error received from JobErrorHandler, canceling job... - ", err)
			chFatal <- err
			cancel()
		}
	}()

	// Register every partition before any worker can mark one as done.
	wgActivePartitions.Add(len(partitions))

	maxExtractors := min(job.MaxExtractors, len(partitions))
	log.Infof("Starting %d extractor(s)... (%v)", maxExtractors, job.Name)
	for range maxExtractors {
		wgExtractors.Go(func() {
			extractor.Consume(
				localCtx,
				job.SourceTable,
				sourceColTypes,
				job.ExtractorBatchSize,
				job.Retry,
				chPartitions,
				chBatchesRaw,
				chJobErrors,
				&wgActivePartitions,
				&rowsRead,
				&failedPartitionsCount,
				job.SourceTable.FromJsonColumns,
			)
		})
	}

	// Feed partitions to the extractors. The send is abandoned when the job
	// is cancelled so this goroutine cannot block forever on the unbuffered
	// channel after the extractors have stopped reading.
	go func() {
		for _, partition := range partitions {
			select {
			case chPartitions <- partition:
			case <-localCtx.Done():
				return
			}
		}
	}()

	// NOTE(review): the transformer pool is sized with maxExtractors; confirm
	// that this is intended rather than a dedicated transformer setting.
	log.Infof("Starting %d transformer(s)... (%v)", maxExtractors, job.Name)
	for range maxExtractors {
		wgTransformers.Go(func() {
			transformer.Consume(
				localCtx,
				sourceColTypes,
				job.Retry,
				job.TransformerBatchSize,
				chBatchesRaw,
				chBatchesTransformed,
				chJobErrors,
				&wgActiveBatches,
			)
		})
	}

	log.Infof("Starting %d loader(s)... (%v)", job.MaxLoaders, job.Name)
	for range job.MaxLoaders {
		wgLoaders.Go(func() {
			loader.Consume(
				localCtx,
				job.TargetTable,
				targetColTypes,
				job.Retry,
				job.LoaderBatchSize,
				chBatchesTransformed,
				chJobErrors,
				&wgActiveBatches,
				&rowsLoaded,
				&failedBatchesLoadCount,
			)
		})
	}

	// Shut the pipeline down stage by stage: each channel is closed only
	// after everything that writes to it has finished, and the final cancel()
	// signals completion to the wait below.
	// NOTE(review): on early cancellation this goroutine exits only if the
	// Consume implementations still release their WaitGroup slots — confirm.
	go func() {
		wgActivePartitions.Wait()
		close(chPartitions)
		wgExtractors.Wait()
		close(chBatchesRaw)
		wgTransformers.Wait()
		close(chBatchesTransformed)
		wgActiveBatches.Wait()
		wgLoaders.Wait()
		cancel()
	}()

	// Block until the pipeline has drained, a fatal error cancelled the job,
	// or the parent context was cancelled.
	<-localCtx.Done()

	// Determine the job error. The first cause wins so that a later, derived
	// symptom (such as a row-count mismatch) does not hide the root cause.
	var jobErr error
	select {
	case fatalErr := <-chFatal:
		jobErr = fatalErr
	default:
	}
	if jobErr == nil {
		jobErr = ctx.Err()
	}

	result.RowsRead = atomic.LoadInt64(&rowsRead)
	result.RowsLoaded = atomic.LoadInt64(&rowsLoaded)
	result.RowsFailed = atomic.LoadInt64(&rowsFailed)
	if jobErr == nil && result.RowsRead != result.RowsLoaded {
		jobErr = fmt.Errorf("Row count mismatch: extracted %d rows but loaded %d rows (failed: %d)", result.RowsRead, result.RowsLoaded, result.RowsFailed)
	}
	if result.RowsRead == 0 {
		log.Warnf("No rows extracted from (%v)", job.Name)
	}

	// PostSQL must see the fully loaded table, so it runs only after the
	// pipeline has finished and only for a successful load. localCtx is
	// already cancelled at this point, hence the parent ctx.
	if jobErr == nil {
		for _, query := range job.TargetTable.PostSQL {
			if _, err := targetDbWrapper.Exec(ctx, query); err != nil {
				jobErr = err
				break
			}
		}
	}

	result.Error = jobErr
	result.Duration = time.Since(result.StartTime)
	return result
}