17 Commits

Author SHA1 Message Date
f126d5bbd0 feat: implement exponential backoff strategy for error handling in extractor and loader processes; enhance retry configuration options 2026-04-12 20:35:29 -05:00
5633dc98d0 fix: enhance error handling in extractor and loader processes; ensure proper job error propagation and logging 2026-04-12 20:11:19 -05:00
01780b4b02 refactor: remove unused ColumnType and inspect-columns files; update migration job to use separate table analyzers for source and target databases 2026-04-12 19:16:14 -05:00
aded502ee4 feat: implement Postgres table analyzer with column type querying and metadata retrieval 2026-04-12 13:38:36 -05:00
4d3cd6e4cf feat: add MSSQL table analyzer and integrate partition range generation for improved data migration 2026-04-11 01:23:13 -05:00
7830ae862d refactor: rename batch-related variables and functions for consistency and clarity 2026-04-11 00:44:12 -05:00
955bc65ce9 refactor: rename Batch to Partition in error handling and processing functions for consistency 2026-04-11 00:32:50 -05:00
9eb9821daf refactor: rename Batch to Partition and update related types and channels for consistency 2026-04-11 00:09:28 -05:00
cd0e53b1d2 feat: implement MSSQL extractor, transformer, and Postgres loader for enhanced data migration 2026-04-10 23:39:37 -05:00
1be7018ba3 feat: refactor configuration to include source and target database types 2026-04-10 22:58:57 -05:00
a5b5a04feb feat: update extractor and transformer constructors to return Extractor interface 2026-04-10 20:42:16 -05:00
c1bae79f98 feat: implement Postgres loader and refactor migration job processing 2026-04-10 20:40:01 -05:00
053e6bd673 feat: add MSSQL extractor and transformer implementations for improved data migration 2026-04-10 19:59:44 -05:00
eb3c3bbfce feat: refactor error handling to use custom_errors.LoaderError for improved error management 2026-04-10 19:35:38 -05:00
9493a2d32f feat: refactor error handling to accept max retry attempts as a parameter for improved flexibility 2026-04-10 19:32:12 -05:00
d228a048b8 feat: update extractor error handling to use models.UnknownRowValues for improved type consistency 2026-04-10 19:29:07 -05:00
ca621352c9 feat: refactor models to improve type handling and enhance error management across migration processes 2026-04-10 19:27:27 -05:00
35 changed files with 1330 additions and 1566 deletions

View File

@@ -1,117 +0,0 @@
package main
import (
"context"
"database/sql"
"fmt"
"time"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"github.com/google/uuid"
)
type Batch struct {
Id uuid.UUID
ParentId uuid.UUID
LowerLimit int64
UpperLimit int64
IsLowerLimitInclusive bool
ShouldUseRange bool
RetryCounter int
}
func estimateTotalRowsMssql(ctx context.Context, db *sql.DB, tableInfo config.SourceTableInfo) (int64, error) {
query := `
SELECT
SUM(p.rows) AS count
FROM sys.tables t
JOIN sys.schemas s ON t.schema_id = s.schema_id
JOIN sys.partitions p ON t.object_id = p.object_id
WHERE s.name = @schema AND t.name = @table AND p.index_id IN (0, 1)
GROUP BY t.name`
ctxTimeout, cancel := context.WithTimeout(ctx, time.Second*20)
defer cancel()
var rowsCount int64
err := db.QueryRowContext(ctxTimeout, query, sql.Named("schema", tableInfo.Schema), sql.Named("table", tableInfo.Table)).Scan(&rowsCount)
if err != nil {
return 0, err
}
return rowsCount, nil
}
func calculateBatchesMssql(ctx context.Context, db *sql.DB, tableInfo config.SourceTableInfo, batchCount int64) ([]Batch, error) {
query := fmt.Sprintf(`
SELECT
MIN([%s]) AS lower_limit,
MAX([%s]) AS upper_limit
FROM
(SELECT [%s], NTILE(@batchCount) OVER (ORDER BY [%s]) AS batch_id FROM [%s].[%s]) AS T
GROUP BY batch_id
ORDER BY batch_id`,
tableInfo.PrimaryKey,
tableInfo.PrimaryKey,
tableInfo.PrimaryKey,
tableInfo.PrimaryKey,
tableInfo.Schema,
tableInfo.Table)
ctxTimeout, cancel := context.WithTimeout(ctx, time.Second*20)
defer cancel()
rows, err := db.QueryContext(ctxTimeout, query, sql.Named("batchCount", batchCount))
if err != nil {
return nil, err
}
defer rows.Close()
batches := make([]Batch, 0, batchCount)
for rows.Next() {
batch := Batch{
Id: uuid.New(),
ShouldUseRange: true,
RetryCounter: 0,
IsLowerLimitInclusive: true,
}
if err := rows.Scan(&batch.LowerLimit, &batch.UpperLimit); err != nil {
return nil, err
}
batches = append(batches, batch)
}
if err := rows.Err(); err != nil {
return nil, err
}
return batches, nil
}
func batchGeneratorMssql(ctx context.Context, db *sql.DB, tableInfo config.SourceTableInfo, rowsPerBatch int64) ([]Batch, error) {
rowsCount, err := estimateTotalRowsMssql(ctx, db, tableInfo)
if err != nil {
return nil, err
}
var batchCount int64 = 1
if rowsCount > rowsPerBatch {
batchCount = rowsCount / rowsPerBatch
} else {
return []Batch{{
Id: uuid.New(),
ShouldUseRange: false,
RetryCounter: 0,
}}, nil
}
batches, err := calculateBatchesMssql(ctx, db, tableInfo, batchCount)
if err != nil {
return nil, err
}
return batches, nil
}

View File

@@ -1,75 +0,0 @@
package main
import (
"fmt"
"strings"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
)
func buildExtractQueryMssql(sourceDbInfo config.SourceTableInfo, columns []ColumnType, includeRange bool, isMinInclusive bool) string {
var sbQuery strings.Builder
sbQuery.WriteString("SELECT ")
if len(columns) == 0 {
sbQuery.WriteString("*")
} else {
for i, col := range columns {
fmt.Fprintf(&sbQuery, "[%s]", col.name)
if col.unifiedType == "GEOMETRY" {
fmt.Fprintf(&sbQuery, ".STAsBinary() AS [%s]", col.name)
}
if i < len(columns)-1 {
sbQuery.WriteString(", ")
}
}
}
fmt.Fprintf(&sbQuery, " FROM [%s].[%s]", sourceDbInfo.Schema, sourceDbInfo.Table)
if includeRange {
fmt.Fprintf(&sbQuery, " WHERE [%s]", sourceDbInfo.PrimaryKey)
if isMinInclusive {
sbQuery.WriteString(" >=")
} else {
sbQuery.WriteString(" >")
}
fmt.Fprintf(&sbQuery, " @min AND [%s] <= @max", sourceDbInfo.PrimaryKey)
}
fmt.Fprintf(&sbQuery, " ORDER BY [%s] ASC", sourceDbInfo.PrimaryKey)
return sbQuery.String()
}
func buildExtractQueryPostgres(sourceDbInfo config.SourceTableInfo, columns []ColumnType) string {
var sbColumns strings.Builder
if len(columns) == 0 {
sbColumns.WriteString("*")
} else {
for i, col := range columns {
if col.unifiedType == "GEOMETRY" {
sbColumns.WriteString(`ST_AsEWKB("`)
sbColumns.WriteString(col.name)
sbColumns.WriteString(`") AS "`)
sbColumns.WriteString(col.name)
sbColumns.WriteString(`"`)
} else {
sbColumns.WriteString(`"`)
sbColumns.WriteString(col.name)
sbColumns.WriteString(`"`)
}
if i < len(columns)-1 {
sbColumns.WriteString(", ")
}
}
}
return fmt.Sprintf(`SELECT %s FROM "%s"."%s" ORDER BY "%s" ASC`, sbColumns.String(), sourceDbInfo.Schema, sourceDbInfo.Table, sourceDbInfo.PrimaryKey)
}

View File

@@ -1,44 +0,0 @@
package main
type ColumnType struct {
name string
hasMaxLength bool
hasPrecisionScale bool
userType string
systemType string
unifiedType string
nullable bool
maxLength int64
precision int64
scale int64
}
func (c *ColumnType) Name() string {
return c.name
}
func (c *ColumnType) UserType() string {
return c.userType
}
func (c *ColumnType) SystemType() string {
return c.systemType
}
func (c *ColumnType) Length() (length int64, ok bool) {
return c.maxLength, c.hasMaxLength
}
func (c *ColumnType) DecimalSize() (precision, scale int64, ok bool) {
return c.precision, c.scale, c.hasPrecisionScale
}
func (c *ColumnType) Nullable() bool {
return c.nullable
}
func (c *ColumnType) Type() string {
return c.unifiedType
}

View File

@@ -1,102 +0,0 @@
package main
import (
"context"
"fmt"
"sync"
"github.com/google/uuid"
)
type ExtractorError struct {
Batch
LastId int64
HasLastId bool
Msg string
}
func (e *ExtractorError) Error() string {
return e.Msg
}
const maxRetryAttempts = 3
func extractorErrorHandler(
ctx context.Context,
chErrorsIn <-chan ExtractorError,
chBatchesOut chan<- Batch,
chJobErrorsOut chan<- JobError,
wgActiveBatches *sync.WaitGroup,
) {
for {
if ctx.Err() != nil {
return
}
select {
case <-ctx.Done():
return
case err, ok := <-chErrorsIn:
if !ok {
return
}
if err.RetryCounter >= maxRetryAttempts {
jobError := JobError{
ShouldCancelJob: false,
Msg: fmt.Sprintf("batch %v reached max retries (%d)", err.Id, maxRetryAttempts),
Prev: &err,
}
select {
case chJobErrorsOut <- jobError:
case <-ctx.Done():
return
}
wgActiveBatches.Done()
continue
}
newBatch := err.Batch
newBatch.RetryCounter++
if err.HasLastId {
newBatch.ParentId = err.Id
newBatch.Id = uuid.New()
newBatch.LowerLimit = err.LastId
newBatch.IsLowerLimitInclusive = false
}
select {
case chBatchesOut <- newBatch:
case <-ctx.Done():
return
}
}
}
}
func ExtractorErrorFromLastRowMssql(lastRow UnknownRowValues, indexPrimaryKey int, batch *Batch, previousError error) ExtractorError {
lastIdRawValue := lastRow[indexPrimaryKey]
lastId, ok := ToInt64(lastIdRawValue)
if !ok {
currentBatch := *batch
currentBatch.RetryCounter = maxRetryAttempts
return ExtractorError{
Batch: currentBatch,
HasLastId: true,
Msg: fmt.Sprintf("Couldn't cast last id value as int: %s", previousError.Error()),
}
}
return ExtractorError{
Batch: *batch,
HasLastId: true,
LastId: lastId,
Msg: previousError.Error(),
}
}

View File

@@ -1,251 +0,0 @@
package main
import (
"context"
"database/sql"
"errors"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
_ "github.com/microsoft/go-mssqldb"
log "github.com/sirupsen/logrus"
)
type UnknownRowValues = []any
type Chunk struct {
Id uuid.UUID
BatchId uuid.UUID
Data []UnknownRowValues
RetryCounter int
}
func extractFromMssql(
ctx context.Context,
db *sql.DB,
tableInfo config.SourceTableInfo,
columns []ColumnType,
chunkSize int,
chBatchesIn <-chan Batch,
chChunksOut chan<- Chunk,
chErrorsOut chan<- ExtractorError,
chJobErrorsOut chan<- JobError,
wgActiveBatches *sync.WaitGroup,
rowsRead *int64,
) {
indexPrimaryKey := slices.IndexFunc(columns, func(col ColumnType) bool {
return strings.EqualFold(col.name, tableInfo.PrimaryKey)
})
if indexPrimaryKey == -1 {
jobError := JobError{
ShouldCancelJob: true,
Msg: "Primary key not found in provided columns",
}
select {
case <-ctx.Done():
return
case chJobErrorsOut <- jobError:
}
return
}
for {
if ctx.Err() != nil {
return
}
select {
case <-ctx.Done():
return
case batch, ok := <-chBatchesIn:
if !ok {
return
}
if abort := processBatch(ctx, db, tableInfo, columns, chunkSize, batch, indexPrimaryKey, chChunksOut, chErrorsOut, wgActiveBatches, rowsRead); abort {
return
}
}
}
}
func processBatch(
ctx context.Context,
db *sql.DB,
tableInfo config.SourceTableInfo,
columns []ColumnType,
chunkSize int,
batch Batch,
indexPrimaryKey int,
chChunksOut chan<- Chunk,
chErrorsOut chan<- ExtractorError,
wgActiveBatches *sync.WaitGroup,
rowsRead *int64,
) (abort bool) {
query := buildExtractQueryMssql(tableInfo, columns, batch.ShouldUseRange, batch.IsLowerLimitInclusive)
log.Debug("Query used to extract data from mssql: ", query)
var queryArgs []any
if batch.ShouldUseRange {
queryArgs = append(queryArgs,
sql.Named("min", batch.LowerLimit),
sql.Named("max", batch.UpperLimit),
)
}
queryStartTime := time.Now()
rows, err := db.QueryContext(ctx, query, queryArgs...)
if err != nil {
select {
case chErrorsOut <- ExtractorError{Batch: batch, HasLastId: false, Msg: err.Error()}:
case <-ctx.Done():
return true
}
return false
}
defer rows.Close()
log.Debugf("Query executed in %v", time.Since(queryStartTime))
rowsChunk := make([]UnknownRowValues, 0, chunkSize)
totalRowsExtracted := 0
chunkStartTime := time.Now()
for rows.Next() {
values := make([]any, len(columns))
scanArgs := make([]any, len(columns))
for i := range values {
scanArgs[i] = &values[i]
}
if err := rows.Scan(scanArgs...); err != nil {
if len(rowsChunk) == 0 {
select {
case chErrorsOut <- ExtractorError{Batch: batch, HasLastId: false, Msg: err.Error()}:
case <-ctx.Done():
return true
}
return false
}
lastRow := rowsChunk[len(rowsChunk)-1]
select {
case chErrorsOut <- ExtractorErrorFromLastRowMssql(lastRow, indexPrimaryKey, &batch, err):
case <-ctx.Done():
return true
}
select {
case chChunksOut <- Chunk{Id: uuid.New(), BatchId: batch.Id, Data: rowsChunk, RetryCounter: 0}:
case <-ctx.Done():
return true
}
atomic.AddInt64(rowsRead, int64(len(rowsChunk)))
return false
}
rowsChunk = append(rowsChunk, values)
totalRowsExtracted++
if len(rowsChunk) >= chunkSize {
chunkDuration := time.Since(chunkStartTime)
rowsPerSec := float64(chunkSize) / chunkDuration.Seconds()
log.Infof("Extracted chunk: %d rows in %v (%.0f rows/sec) - Total: %d rows", len(rowsChunk), chunkDuration, rowsPerSec, totalRowsExtracted)
select {
case chChunksOut <- Chunk{Id: uuid.New(), BatchId: batch.Id, Data: rowsChunk, RetryCounter: 0}:
case <-ctx.Done():
return true
}
atomic.AddInt64(rowsRead, int64(len(rowsChunk)))
rowsChunk = make([]UnknownRowValues, 0, chunkSize)
chunkStartTime = time.Now()
}
}
if err := rows.Err(); err != nil {
if errors.Is(err, ctx.Err()) {
return true
}
if len(rowsChunk) == 0 {
select {
case chErrorsOut <- ExtractorError{Batch: batch, HasLastId: false, Msg: err.Error()}:
case <-ctx.Done():
return true
}
return false
}
lastRow := rowsChunk[len(rowsChunk)-1]
select {
case chErrorsOut <- ExtractorErrorFromLastRowMssql(lastRow, indexPrimaryKey, &batch, err):
case <-ctx.Done():
return true
}
return false
}
if len(rowsChunk) > 0 {
chunkDuration := time.Since(chunkStartTime)
rowsPerSec := float64(len(rowsChunk)) / chunkDuration.Seconds()
log.Infof("Extracted final chunk: %d rows in %v (%.0f rows/sec) - Total: %d rows", len(rowsChunk), chunkDuration, rowsPerSec, totalRowsExtracted)
select {
case chChunksOut <- Chunk{Id: uuid.New(), BatchId: batch.Id, Data: rowsChunk, RetryCounter: 0}:
case <-ctx.Done():
return true
}
atomic.AddInt64(rowsRead, int64(len(rowsChunk)))
}
wgActiveBatches.Done()
return false
}
func extractFromPostgres(ctx context.Context, tableInfo config.SourceTableInfo, columns []ColumnType, chunkSize int, db *pgxpool.Pool, out chan<- []UnknownRowValues) error {
query := buildExtractQueryPostgres(tableInfo, columns)
log.Debug("Query used to extract data from postgres: ", query)
rows, err := db.Query(ctx, query)
if err != nil {
return err
}
defer rows.Close()
rowsChunk := make([]UnknownRowValues, 0, chunkSize)
for rows.Next() {
values, err := rows.Values()
if err != nil {
return err
}
rowsChunk = append(rowsChunk, values)
if len(rowsChunk) >= chunkSize {
out <- rowsChunk
rowsChunk = make([]UnknownRowValues, 0, chunkSize)
log.Infof("Chunk send... %+v", tableInfo)
}
}
if len(rowsChunk) > 0 {
out <- rowsChunk
log.Infof("Chunk send... %+v", tableInfo)
}
return nil
}

View File

@@ -1,289 +0,0 @@
package main
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"sync"
"time"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"github.com/jackc/pgx/v5/pgxpool"
_ "github.com/microsoft/go-mssqldb"
log "github.com/sirupsen/logrus"
)
func GetUnifiedType(systemType string) string {
systemType = strings.ToLower(systemType)
if systemType == "varchar" || systemType == "char" || systemType == "nvarchar" || systemType == "nchar" || systemType == "text" || systemType == "ntext" {
return "STRING"
}
if systemType == "int" || systemType == "int4" || systemType == "integer" || systemType == "smallint" || systemType == "int2" || systemType == "bigint" || systemType == "int8" || systemType == "tinyint" {
return "INTEGER"
}
if systemType == "decimal" || systemType == "numeric" {
return "DECIMAL"
}
if systemType == "float" || systemType == "real" || systemType == "double precision" {
return "FLOAT"
}
if systemType == "bit" || systemType == "boolean" {
return "BOOLEAN"
}
if systemType == "date" {
return "DATE"
}
if systemType == "time" || systemType == "time without time zone" {
return "TIME"
}
if systemType == "datetime" || systemType == "datetime2" || systemType == "timestamp" || systemType == "timestamptz" || systemType == "timestamp with time zone" {
return "TIMESTAMP"
}
if systemType == "binary" || systemType == "varbinary" || systemType == "image" || systemType == "bytea" {
return "BINARY"
}
if systemType == "uniqueidentifier" || systemType == "uuid" {
return "UUID"
}
if systemType == "json" {
return "JSON"
}
if systemType == "geometry" || systemType == "geography" {
return "GEOMETRY"
}
return strings.ToUpper(systemType)
}
func MapPostgresColumn(column ColumnType, maxLength *int64, precision *int64, scale *int64) ColumnType {
stringTypes := map[string]bool{
"varchar": true, "char": true, "character": true, "text": true, "character varying": true,
}
decimalTypes := map[string]bool{
"decimal": true, "numeric": true,
}
if stringTypes[column.systemType] {
if maxLength != nil {
column.maxLength = *maxLength
column.hasMaxLength = true
} else {
column.maxLength = -1
column.hasMaxLength = false
}
column.hasPrecisionScale = false
column.precision = -1
column.scale = -1
} else if decimalTypes[column.systemType] {
column.hasMaxLength = false
column.maxLength = -1
if precision != nil && scale != nil {
column.precision = *precision
column.scale = *scale
column.hasPrecisionScale = true
} else {
column.precision = -1
column.scale = -1
column.hasPrecisionScale = false
}
} else {
column.hasMaxLength = false
column.maxLength = -1
column.hasPrecisionScale = false
column.precision = -1
column.scale = -1
}
column.unifiedType = GetUnifiedType(column.systemType)
return column
}
func GetColumnTypesPostgres(db *pgxpool.Pool, tableInfo config.TargetTableInfo) ([]ColumnType, error) {
query := `
SELECT
c.column_name AS name,
c.data_type AS user_type,
c.udt_name AS system_type,
(CASE WHEN c.is_nullable = 'YES' THEN TRUE ELSE FALSE END) AS nullable,
c.character_maximum_length AS max_length,
c.numeric_precision AS precision,
c.numeric_scale AS scale
FROM information_schema.columns c
WHERE c.table_schema = $1 AND c.table_name = $2
ORDER BY c.ordinal_position;
`
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
rows, err := db.Query(ctx, query, tableInfo.Schema, tableInfo.Table)
if err != nil {
return nil, fmt.Errorf("Error querying column types: %w", err)
}
defer rows.Close()
var colTypes []ColumnType
for rows.Next() {
var column ColumnType
var scanMaxLength *int64
var scanPrecision *int64
var scanScale *int64
if err := rows.Scan(
&column.name,
&column.userType,
&column.systemType,
&column.nullable,
&scanMaxLength,
&scanPrecision,
&scanScale,
); err != nil {
return nil, fmt.Errorf("Error scanning column type results: %w", err)
}
colTypes = append(colTypes, MapPostgresColumn(column, scanMaxLength, scanPrecision, scanScale))
}
return colTypes, nil
}
func MapMssqlColumn(column ColumnType) ColumnType {
stringTypes := map[string]bool{
"varchar": true, "char": true, "nvarchar": true, "nchar": true, "text": true, "ntext": true,
}
decimalTypes := map[string]bool{
"decimal": true, "numeric": true,
}
if stringTypes[column.systemType] {
column.hasMaxLength = true
if column.systemType == "nvarchar" || column.systemType == "nchar" {
if column.maxLength > 0 {
column.maxLength = column.maxLength / 2
}
}
column.hasPrecisionScale = false
column.precision = -1
column.scale = -1
} else if decimalTypes[column.systemType] {
column.hasMaxLength = false
column.maxLength = -1
column.hasPrecisionScale = true
} else {
column.hasMaxLength = false
column.maxLength = -1
column.hasPrecisionScale = false
column.precision = -1
column.scale = -1
}
column.unifiedType = GetUnifiedType(column.systemType)
return column
}
func GetColumnTypesMssql(db *sql.DB, tableInfo config.SourceTableInfo) ([]ColumnType, error) {
query := `
SELECT
c.name AS name,
t.name AS user_type,
CASE WHEN t.is_user_defined = 0 THEN t.name ELSE bt.name END AS system_type,
c.is_nullable AS nullable,
c.max_length AS max_length,
c.precision AS precision,
c.scale AS scale
FROM sys.columns c
JOIN sys.types t ON c.user_type_id = t.user_type_id
LEFT JOIN sys.types bt ON t.is_user_defined = 1 AND bt.user_type_id = t.system_type_id
JOIN sys.tables st ON c.object_id = st.object_id
JOIN sys.schemas s ON st.schema_id = s.schema_id
WHERE s.name = @schema AND st.name = @table
ORDER BY c.column_id;
`
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
rows, err := db.QueryContext(ctx, query, sql.Named("schema", tableInfo.Schema), sql.Named("table", tableInfo.Table))
if err != nil {
return nil, fmt.Errorf("Error querying column types: %w", err)
}
defer rows.Close()
var colTypes []ColumnType
for rows.Next() {
var column ColumnType
if err := rows.Scan(
&column.name,
&column.userType,
&column.systemType,
&column.nullable,
&column.maxLength,
&column.precision,
&column.scale,
); err != nil {
return nil, fmt.Errorf("Error scanning column type results: %W", err)
}
if strings.HasPrefix(column.name, "graph_id") && column.systemType == "bigint" {
continue
}
colTypes = append(colTypes, MapMssqlColumn(column))
}
return colTypes, nil
}
func GetColumnTypes(
sourceDb *sql.DB,
targetDb *pgxpool.Pool,
sourceTable config.SourceTableInfo,
targetTable config.TargetTableInfo,
) ([]ColumnType, []ColumnType, error) {
var sourceDbErr error
var targetDbErr error
var sourceColTypes []ColumnType
var targetColTypes []ColumnType
var wg sync.WaitGroup
wg.Go(func() {
sourceColTypes, sourceDbErr = GetColumnTypesMssql(sourceDb, sourceTable)
if sourceDbErr != nil {
log.Error("Error (sourceDb): ", sourceDbErr)
}
})
wg.Go(func() {
targetColTypes, targetDbErr = GetColumnTypesPostgres(targetDb, targetTable)
if targetDbErr != nil {
log.Error("Error (targetDb): ", targetDbErr)
}
})
wg.Wait()
if sourceDbErr != nil || targetDbErr != nil {
return nil, nil, errors.New("Error querying column types")
}
return sourceColTypes, targetColTypes, nil
}

View File

@@ -1,47 +0,0 @@
package main
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
)
type JobError struct {
ShouldCancelJob bool
Msg string
Prev error
}
func (e *JobError) Error() string {
if e.Prev != nil {
return fmt.Sprintf("%s: %v", e.Msg, e.Prev)
}
return e.Msg
}
func jobErrorHandler(ctx context.Context, chErrorsIn <-chan JobError) error {
for {
if ctx.Err() != nil {
return nil
}
select {
case <-ctx.Done():
return nil
case err, ok := <-chErrorsIn:
if !ok {
return nil
}
if err.ShouldCancelJob {
log.Error(err.Msg, " - ", err.Prev)
return &err
}
log.Error(err.Msg, " - ", err.Prev)
}
}
}

View File

@@ -1,65 +0,0 @@
package main
import (
"context"
"fmt"
"sync"
)
type LoaderError struct {
Chunk
Msg string
}
func (e *LoaderError) Error() string {
return e.Msg
}
func loaderErrorHandler(
ctx context.Context,
chErrorsIn <-chan LoaderError,
chChunksOut chan<- Chunk,
chJobErrorsOut chan<- JobError,
wgActiveChunks *sync.WaitGroup,
) {
for {
if ctx.Err() != nil {
return
}
select {
case <-ctx.Done():
return
case err, ok := <-chErrorsIn:
if !ok {
return
}
if err.RetryCounter >= maxRetryAttempts {
jobError := JobError{
ShouldCancelJob: false,
Msg: fmt.Sprintf("chunk %v reached max retries (%d)", err.Id, maxRetryAttempts),
Prev: &err,
}
select {
case chJobErrorsOut <- jobError:
case <-ctx.Done():
return
}
wgActiveChunks.Done()
continue
}
err.RetryCounter++
select {
case chChunksOut <- err.Chunk:
case <-ctx.Done():
return
}
}
}
}

View File

@@ -1,195 +0,0 @@
package main
import (
"context"
"database/sql"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
mssql "github.com/microsoft/go-mssqldb"
log "github.com/sirupsen/logrus"
)
func loadRowsPostgres(
ctx context.Context,
db *pgxpool.Pool,
tableInfo config.TargetTableInfo,
columns []ColumnType,
chChunksIn <-chan Chunk,
chErrorsOut chan<- LoaderError,
chJobErrorsOut chan<- JobError,
wgActiveChunks *sync.WaitGroup,
rowsLoaded *int64,
) {
tableId := pgx.Identifier{tableInfo.Schema, tableInfo.Table}
colNames := Map(columns, func(col ColumnType) string {
return col.name
})
for {
if ctx.Err() != nil {
return
}
select {
case <-ctx.Done():
return
case chunk, ok := <-chChunksIn:
if !ok {
return
}
if abort := loadChunkPostgres(ctx, db, tableId, colNames, chunk, chErrorsOut, chJobErrorsOut, wgActiveChunks, rowsLoaded); abort {
return
}
}
}
}
func loadChunkPostgres(
ctx context.Context,
db *pgxpool.Pool,
identifier pgx.Identifier,
colNames []string,
chunk Chunk,
chErrorsOut chan<- LoaderError,
chJobErrorsOut chan<- JobError,
wgActiveChunks *sync.WaitGroup,
rowsLoaded *int64,
) (abort bool) {
chunkStartTime := time.Now()
_, err := db.CopyFrom(
ctx,
identifier,
colNames,
pgx.CopyFromRows(chunk.Data),
)
if err != nil {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) {
if pgErr.Code == "23505" {
select {
case chJobErrorsOut <- JobError{
ShouldCancelJob: true,
Msg: fmt.Sprintf("Fatal error in table %s", identifier.Sanitize()),
Prev: err,
}:
case <-ctx.Done():
}
wgActiveChunks.Done()
return true
}
}
select {
case chErrorsOut <- LoaderError{Chunk: chunk, Msg: err.Error()}:
case <-ctx.Done():
return true
}
return false
}
chunkDuration := time.Since(chunkStartTime)
rowsPerSec := float64(len(chunk.Data)) / chunkDuration.Seconds()
log.Infof("Loaded chunk: %d rows in %v (%.0f rows/sec)", len(chunk.Data), chunkDuration, rowsPerSec)
atomic.AddInt64(rowsLoaded, int64(len(chunk.Data)))
wgActiveChunks.Done()
return false
}
func loadRowsMssql(ctx context.Context, tableInfo config.TargetTableInfo, columns []ColumnType, db *sql.DB, in <-chan []UnknownRowValues) error {
chunkCount := 0
totalRowsLoaded := 0
for rows := range in {
chunkStartTime := time.Now()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("error starting transaction: %w", err)
}
fullTableName := fmt.Sprintf("[%s].[%s]", tableInfo.Schema, tableInfo.Table)
colNames := Map(columns, func(col ColumnType) string {
return col.name
})
stmt, err := tx.PrepareContext(ctx, mssql.CopyIn(fullTableName, mssql.BulkOptions{}, colNames...))
if err != nil {
tx.Rollback()
return fmt.Errorf("error preparing bulk copy statement: %w", err)
}
copyStartTime := time.Now()
for _, row := range rows {
_, err = stmt.ExecContext(ctx, row...)
if err != nil {
stmt.Close()
tx.Rollback()
return fmt.Errorf("error executing row insert: %w", err)
}
}
result, err := stmt.ExecContext(ctx)
if err != nil {
stmt.Close()
tx.Rollback()
return fmt.Errorf("error flushing bulk data: %w", err)
}
err = stmt.Close()
if err != nil {
tx.Rollback()
return fmt.Errorf("error closing statement: %w", err)
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("error committing transaction: %w", err)
}
rowsAffected, _ := result.RowsAffected()
chunkCount++
totalRowsLoaded += int(rowsAffected)
copyDuration := time.Since(copyStartTime)
chunkDuration := time.Since(chunkStartTime)
rowsPerSec := float64(len(rows)) / chunkDuration.Seconds()
log.Infof("Loaded chunk #%d (MSSQL): %d rows in %v (copy: %v, %.0f rows/sec) - Total: %d rows", chunkCount, len(rows), chunkDuration, copyDuration, rowsPerSec, totalRowsLoaded)
}
return nil
}
func Map[T any, V any](input []T, mapper func(T) V) []V {
result := make([]V, len(input))
for i, v := range input {
result[i] = mapper(v)
}
return result
}
func fakeLoader(tableInfo config.TargetTableInfo, columns []ColumnType, in <-chan [][]any) {
for rows := range in {
log.Debugf("Chunk received, loading data into...")
for i, rowValues := range rows {
if i%100 == 0 {
logSampleRow(tableInfo.Schema, tableInfo.Table, columns, rowValues, fmt.Sprintf("row %d", i))
}
}
}
}

View File

@@ -7,6 +7,10 @@ import (
"time"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl/extractors"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl/loaders"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl/table_analyzers"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl/transformers"
"github.com/jackc/pgx/v5/pgxpool"
log "github.com/sirupsen/logrus"
)
@@ -87,11 +91,26 @@ func processMigrationJobs(
chJobs := make(chan config.Job, len(jobs))
var wgJobs sync.WaitGroup
sourceTableAnalyzer := table_analyzers.NewMssqlTableAnalyzer(sourceDb)
targetTableAnalyzer := table_analyzers.NewPostgresTableAnalyzer(targetDb)
extractor := extractors.NewMssqlExtractor(sourceDb)
transformer := transformers.NewMssqlTransformer()
loader := loaders.NewPostgresLoader(targetDb)
for i := range maxParallelWorkers {
wgJobs.Go(func() {
for job := range chJobs {
log.Infof("[worker %d] >>> Processing job: %s.%s <<<", i, job.SourceTable.Schema, job.SourceTable.Table)
res := processMigrationJob(ctx, sourceDb, targetDb, job)
res := processMigrationJob(
ctx,
sourceTableAnalyzer,
targetTableAnalyzer,
extractor,
transformer,
loader,
job,
)
chJobResults <- res
}
})

View File

@@ -2,24 +2,31 @@ package main
import (
"context"
"database/sql"
"sync"
"sync/atomic"
"time"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"github.com/jackc/pgx/v5/pgxpool"
_ "github.com/microsoft/go-mssqldb"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl/table_analyzers"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
)
func processMigrationJob(
ctx context.Context,
sourceDb *sql.DB,
targetDb *pgxpool.Pool,
sourceTableAnalyzer etl.TableAnalyzer,
targetTableAnalyzer etl.TableAnalyzer,
extractor etl.Extractor,
transformer etl.Transformer,
loader etl.Loader,
job config.Job,
) JobResult {
jobCtx, cancel := context.WithCancel(ctx)
defer cancel()
result := JobResult{
JobName: job.Name,
StartTime: time.Now(),
@@ -27,59 +34,108 @@ func processMigrationJob(
var rowsRead, rowsLoaded, rowsFailed int64
sourceColTypes, targetColTypes, err := GetColumnTypes(sourceDb, targetDb, job.SourceTable, job.TargetTable)
var wgQueryColumnTypes errgroup.Group
var sourceColTypes, targetColTypes []models.ColumnType
wgQueryColumnTypes.Go(func() error {
var err error
sourceColTypes, err = sourceTableAnalyzer.QueryColumnTypes(jobCtx, job.SourceTable.TableInfo)
if err != nil {
return err
}
return nil
})
wgQueryColumnTypes.Go(func() error {
var err error
targetColTypes, err = targetTableAnalyzer.QueryColumnTypes(jobCtx, job.TargetTable.TableInfo)
if err != nil {
return err
}
return nil
})
err := wgQueryColumnTypes.Wait()
if err != nil {
result.Error = err
return result
}
logColumnTypes(sourceColTypes, "Source col types")
logColumnTypes(targetColTypes, "Target col types")
jobCtx, cancel := context.WithCancel(ctx)
defer cancel()
batches, err := batchGeneratorMssql(jobCtx, sourceDb, job.SourceTable, job.RowsPerBatch)
partitions, err := table_analyzers.PartitionRangeGenerator(
jobCtx,
sourceTableAnalyzer,
job.SourceTable.TableInfo,
job.SourceTable.PrimaryKey,
job.RowsPerPartition,
)
if err != nil {
log.Error("Unexpected error calculating batch ranges: ", err)
}
chJobErrors := make(chan JobError, job.QueueSize)
chBatches := make(chan Batch, job.QueueSize)
chExtractorErrors := make(chan ExtractorError, job.QueueSize)
chChunksRaw := make(chan Chunk, job.QueueSize)
chChunksTransformed := make(chan Chunk, job.QueueSize)
chLoadersErrors := make(chan LoaderError, job.QueueSize)
chJobErrors := make(chan custom_errors.JobError, job.QueueSize)
chExtractorErrors := make(chan custom_errors.ExtractorError, job.QueueSize)
chLoadersErrors := make(chan custom_errors.LoaderError, job.QueueSize)
chPartitions := make(chan models.Partition, job.QueueSize)
chBatchesRaw := make(chan models.Batch, job.QueueSize)
chBatchesTransformed := make(chan models.Batch, job.QueueSize)
var wgActivePartitions sync.WaitGroup
var wgActiveBatches sync.WaitGroup
var wgActiveChunks sync.WaitGroup
var wgExtractors sync.WaitGroup
var wgTransformers sync.WaitGroup
var wgLoaders sync.WaitGroup
go func() {
if err := jobErrorHandler(jobCtx, chJobErrors); err != nil {
if err := custom_errors.JobErrorHandler(jobCtx, chJobErrors); err != nil {
log.Error("Fatal error received from JobErrorHandler, canceling job... - ", err)
cancel()
result.Error = err
}
}()
go extractorErrorHandler(jobCtx, chExtractorErrors, chBatches, chJobErrors, &wgActiveBatches)
go loaderErrorHandler(jobCtx, chLoadersErrors, chChunksTransformed, chJobErrors, &wgActiveChunks)
go custom_errors.ExtractorErrorHandler(
jobCtx,
job.Retry,
chExtractorErrors,
chPartitions,
chJobErrors,
&wgActivePartitions,
)
go custom_errors.LoaderErrorHandler(
jobCtx,
job.Retry,
chLoadersErrors,
chBatchesTransformed,
chJobErrors,
&wgActiveBatches,
)
maxExtractors := min(job.MaxExtractors, len(batches))
maxExtractors := min(job.MaxExtractors, len(partitions))
log.Infof("Starting %d extractor(s)...", maxExtractors)
for range maxExtractors {
wgExtractors.Go(func() {
extractFromMssql(jobCtx, sourceDb, job.SourceTable, sourceColTypes, job.ChunkSize, chBatches, chChunksRaw, chExtractorErrors, chJobErrors, &wgActiveBatches, &rowsRead)
extractor.Exec(
jobCtx,
job.SourceTable,
sourceColTypes,
job.BatchSize,
chPartitions,
chBatchesRaw,
chExtractorErrors,
chJobErrors,
&wgActivePartitions,
&rowsRead,
)
})
}
wgActiveBatches.Add(len(batches))
wgActivePartitions.Add(len(partitions))
go func() {
for _, batch := range batches {
chBatches <- batch
for _, batch := range partitions {
chPartitions <- batch
}
}()
@@ -87,7 +143,14 @@ func processMigrationJob(
for range maxExtractors {
wgTransformers.Go(func() {
transformRowsMssql(jobCtx, sourceColTypes, chChunksRaw, chChunksTransformed, chJobErrors, &wgActiveChunks)
transformer.Exec(
jobCtx,
sourceColTypes,
chBatchesRaw,
chBatchesTransformed,
chJobErrors,
&wgActiveBatches,
)
})
}
@@ -95,30 +158,53 @@ func processMigrationJob(
for range job.MaxLoaders {
wgLoaders.Go(func() {
loadRowsPostgres(jobCtx, targetDb, job.TargetTable, targetColTypes, chChunksTransformed, chLoadersErrors, chJobErrors, &wgActiveChunks, &rowsLoaded)
loader.Exec(
jobCtx,
job.TargetTable,
targetColTypes,
chBatchesTransformed,
chLoadersErrors,
chJobErrors,
&wgActiveBatches,
&rowsLoaded,
)
})
}
go func() {
wgActiveBatches.Wait()
close(chBatches)
log.Debugf("Waiting for goroutines (%v)", job.Name)
wgActivePartitions.Wait()
log.Debugf("wgActivePartitions is empty (%v)", job.Name)
close(chPartitions)
log.Debugf("chPartitions is closed (%v)", job.Name)
close(chExtractorErrors)
log.Debugf("chExtractorErrors is closed (%v)", job.Name)
wgExtractors.Wait()
close(chChunksRaw)
log.Debugf("wgExtractors is empty (%v)", job.Name)
close(chBatchesRaw)
log.Debugf("chBatchesRaw is closed (%v)", job.Name)
wgTransformers.Wait()
log.Debugf("wgTransformers is empty (%v)", job.Name)
wgActiveChunks.Wait()
close(chChunksTransformed)
wgActiveBatches.Wait()
log.Debugf("wgActiveBatches is empty (%v)", job.Name)
close(chBatchesTransformed)
log.Debugf("chBatchesTransformed is empty (%v)", job.Name)
close(chLoadersErrors)
log.Debugf("chLoadersErrors is empty (%v)", job.Name)
wgLoaders.Wait()
log.Debugf("wgLoaders is empty (%v)", job.Name)
cancel()
}()
log.Debugf("waiting for local context to be done (%v)", job.Name)
<-jobCtx.Done()
log.Debugf("local context done (%v)", job.Name)
if ctx.Err() != nil {
result.Error = ctx.Err()
@@ -131,24 +217,3 @@ func processMigrationJob(
return result
}
func logColumnTypes(columnTypes []ColumnType, label string) {
log.Debug(label)
for _, col := range columnTypes {
log.Debugf("%+v", col)
}
}
func logSampleRow(
schema string,
table string,
columns []ColumnType,
rowValues UnknownRowValues,
tag string,
) {
log.Infof("[%s.%s] Sample row: (%s)", schema, table, tag)
for i, col := range columns {
log.Infof("%s (%T): %v", col.Name(), rowValues[i], rowValues[i])
}
}

View File

@@ -1,149 +0,0 @@
package main
import (
"context"
"errors"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
type transformerFunc func(any) (any, error)
type columnTransformPlan struct {
index int
fn transformerFunc
}
func transformRowsMssql(
ctx context.Context,
columns []ColumnType,
chChunksIn <-chan Chunk,
chChunksOut chan<- Chunk,
chJobErrorsOut chan<- JobError,
wgActiveChunks *sync.WaitGroup,
) {
transformationPlan := computeTransformationPlan(columns)
for {
if ctx.Err() != nil {
return
}
select {
case <-ctx.Done():
return
case chunk, ok := <-chChunksIn:
if !ok {
return
}
if len(transformationPlan) == 0 {
select {
case chChunksOut <- chunk:
wgActiveChunks.Add(1)
continue
case <-ctx.Done():
return
}
}
chunkStartTime := time.Now()
err := processChunk(ctx, &chunk, transformationPlan)
if err != nil {
if errors.Is(err, ctx.Err()) {
return
}
select {
case chJobErrorsOut <- JobError{ShouldCancelJob: true, Msg: "Transformation failed", Prev: err}:
case <-ctx.Done():
}
return
}
log.Infof("Transformed chunk %s: %d rows in %v", chunk.Id, len(chunk.Data), time.Since(chunkStartTime))
select {
case chChunksOut <- chunk:
case <-ctx.Done():
return
}
wgActiveChunks.Add(1)
}
}
}
func computeTransformationPlan(columns []ColumnType) []columnTransformPlan {
var plan []columnTransformPlan
for i, col := range columns {
switch col.SystemType() {
case "uniqueidentifier":
plan = append(plan, columnTransformPlan{
index: i,
fn: func(v any) (any, error) {
if b, ok := v.([]byte); ok && b != nil {
return mssqlUuidToBigEndian(b)
}
return v, nil
},
})
case "geometry", "geography":
plan = append(plan, columnTransformPlan{
index: i,
fn: func(v any) (any, error) {
if b, ok := v.([]byte); ok && b != nil {
return wkbToEwkbWithSrid(b, 4326)
}
return v, nil
},
})
case "datetime", "datetime2":
plan = append(plan, columnTransformPlan{
index: i,
fn: func(v any) (any, error) {
if t, ok := v.(time.Time); ok {
return ensureUTC(t), nil
}
return v, nil
},
})
}
}
return plan
}
const processChunkCtxCheck = 4096
func processChunk(ctx context.Context, chunk *Chunk, transformationPlan []columnTransformPlan) error {
for i, rowValues := range chunk.Data {
if i%processChunkCtxCheck == 0 {
if err := ctx.Err(); err != nil {
return err
}
}
for _, task := range transformationPlan {
val := rowValues[task.index]
if val == nil {
continue
}
transformed, err := task.fn(val)
if err != nil {
return err
}
rowValues[task.index] = transformed
}
}
return nil
}

View File

@@ -1,15 +1,20 @@
max_parallel_workers: 4
source_db_type: sqlserver
target_db_type: postgres
defaults:
max_extractors: 2
max_loaders: 4
queue_size: 8
chunk_size: 25000
chunks_per_batch: 8
batch_size: 25000
batches_per_partition: 8
truncate_target: true
truncate_method: TRUNCATE # TRUNCATE | DELETE
retry:
attempts: 3
base_delay_ms: 500
max_delay_ms: 10000
max_jitter_ms: 500
jobs:
- name: cartografia_manzana

2
go.mod
View File

@@ -10,6 +10,7 @@ require (
github.com/microsoft/go-mssqldb v1.9.8
github.com/sirupsen/logrus v1.9.4
github.com/twpayne/go-geom v1.6.1
golang.org/x/sync v0.19.0
gopkg.in/yaml.v3 v3.0.1
)
@@ -23,7 +24,6 @@ require (
github.com/rogpeppe/go-internal v1.14.1 // indirect
github.com/shopspring/decimal v1.4.0 // indirect
golang.org/x/crypto v0.48.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect
)

View File

@@ -9,9 +9,7 @@ import (
type appConfig struct {
SourceDbUrl string
SourceDbType string
TargetDbUrl string
TargetDbType string
}
func loadEnv() {
@@ -29,26 +27,14 @@ func getAppConfig() appConfig {
log.Fatal("SOURCE_DB_URL environment variable not set")
}
sourceDbType := os.Getenv("SOURCE_DB_TYPE")
if sourceDbType == "" {
log.Fatal("SOURCE_DB_TYPE environment variable not set")
}
targetDbUrl := os.Getenv("TARGET_DB_URL")
if targetDbUrl == "" {
log.Fatal("TARGET_DB_URL environment variable not set")
}
targetDbType := os.Getenv("TARGET_DB_TYPE")
if targetDbType == "" {
log.Fatal("TARGET_DB_TYPE environment variable not set")
}
return appConfig{
SourceDbUrl: sourceDbUrl,
SourceDbType: sourceDbType,
TargetDbUrl: targetDbUrl,
TargetDbType: targetDbType,
}
}

View File

@@ -9,28 +9,34 @@ import (
type RetryConfig struct {
Attempts int `yaml:"attempts"`
BaseDelayMs int `yaml:"base_delay_ms"`
MaxDelayMs int `yaml:"max_delay_ms"`
MaxJitterMs int `yaml:"max_jitter_ms"`
}
type JobConfig struct {
MaxExtractors int `yaml:"max_extractors"`
MaxLoaders int `yaml:"max_loaders"`
QueueSize int `yaml:"queue_size"`
ChunkSize int `yaml:"chunk_size"`
ChunksPerBatch int `yaml:"chunks_per_batch"`
RowsPerBatch int64
BatchSize int `yaml:"batch_size"`
BatchesPerPartition int `yaml:"batches_per_partition"`
TruncateTarget bool `yaml:"truncate_target"`
TruncateMethod string `yaml:"truncate_method"`
Retry RetryConfig `yaml:"retry"`
RowsPerPartition int64
}
type TableInfo struct {
Schema string `yaml:"schema"`
Table string `yaml:"table"`
}
type TargetTableInfo struct {
Schema string `yaml:"schema"`
Table string `yaml:"table"`
TableInfo `yaml:",inline"`
}
type SourceTableInfo struct {
Schema string `yaml:"schema"`
Table string `yaml:"table"`
TableInfo `yaml:",inline"`
PrimaryKey string `yaml:"primary_key"`
}
@@ -46,12 +52,16 @@ type Job struct {
type MigrationConfig struct {
MaxParallelWorkers int `yaml:"max_parallel_workers"`
SourceDbType string `yaml:"source_db_type"`
TargetDbType string `yaml:"target_db_type"`
Defaults JobConfig `yaml:"defaults"`
Jobs []Job `yaml:"jobs"`
}
type rawConfig struct {
MaxParallelWorkers int `yaml:"max_parallel_workers"`
SourceDbType string `yaml:"source_db_type"`
TargetDbType string `yaml:"target_db_type"`
Defaults JobConfig `yaml:"defaults"`
Jobs []yaml.Node `yaml:"jobs"`
}
@@ -64,7 +74,7 @@ func (c *MigrationConfig) UnmarshalYAML(value *yaml.Node) error {
c.MaxParallelWorkers = raw.MaxParallelWorkers
c.Defaults = raw.Defaults
c.Defaults.RowsPerBatch = int64(raw.Defaults.ChunkSize * raw.Defaults.ChunksPerBatch)
c.Defaults.RowsPerPartition = int64(raw.Defaults.BatchSize * raw.Defaults.BatchesPerPartition)
for _, node := range raw.Jobs {
job := Job{
@@ -75,7 +85,7 @@ func (c *MigrationConfig) UnmarshalYAML(value *yaml.Node) error {
return err
}
job.RowsPerBatch = int64(job.ChunkSize * job.ChunksPerBatch)
job.RowsPerPartition = int64(job.BatchSize * job.BatchesPerPartition)
c.Jobs = append(c.Jobs, job)
}

View File

@@ -0,0 +1,61 @@
package custom_errors
import (
"context"
"math/rand"
"time"
)
func computeBackoffDelay(retryCounter int, baseDelayMs int, maxDelayMs int, maxJitterMs int) time.Duration {
if retryCounter < 0 {
retryCounter = 0
}
delay := max(time.Duration(baseDelayMs)*time.Millisecond, 0)
maxDelay := time.Duration(maxDelayMs) * time.Millisecond
for i := 0; i < retryCounter; i++ {
if maxDelayMs > 0 && delay >= maxDelay {
delay = maxDelay
break
}
if delay == 0 {
break
}
delay *= 2
}
if maxDelayMs > 0 && delay > maxDelay {
delay = maxDelay
}
if maxJitterMs > 0 {
jitter := time.Duration(rand.Intn(maxJitterMs+1)) * time.Millisecond
delay += jitter
}
if delay < 0 {
delay = 0
}
return delay
}
func requeueWithBackoff(ctx context.Context, delay time.Duration, enqueue func()) {
if delay <= 0 {
enqueue()
return
}
go func() {
timer := time.NewTimer(delay)
defer timer.Stop()
select {
case <-ctx.Done():
return
case <-timer.C:
enqueue()
}
}()
}

View File

@@ -5,12 +5,13 @@ import (
"fmt"
"sync"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
"github.com/google/uuid"
)
type ExtractorError struct {
Batch models.Batch
Partition models.Partition
LastId int64
HasLastId bool
Msg string
@@ -20,14 +21,13 @@ func (e *ExtractorError) Error() string {
return e.Msg
}
const maxRetryAttempts = 3
func ExtractorErrorHandler(
ctx context.Context,
retryConfig config.RetryConfig,
chErrorsIn <-chan ExtractorError,
chBatchesOut chan<- models.Batch,
chPartitionsOut chan<- models.Partition,
chJobErrorsOut chan<- JobError,
wgActiveBatches *sync.WaitGroup,
wgActivePartitions *sync.WaitGroup,
) {
for {
if ctx.Err() != nil {
@@ -43,10 +43,11 @@ func ExtractorErrorHandler(
return
}
if err.Batch.RetryCounter >= maxRetryAttempts {
if err.Partition.RetryCounter >= retryConfig.Attempts {
wgActivePartitions.Done()
jobError := JobError{
ShouldCancelJob: false,
Msg: fmt.Sprintf("batch %v reached max retries (%d)", err.Batch.Id, maxRetryAttempts),
Msg: fmt.Sprintf("Partition %v reached max retries (%d)", err.Partition.Id, retryConfig.Attempts),
Prev: &err,
}
@@ -56,25 +57,45 @@ func ExtractorErrorHandler(
return
}
wgActiveBatches.Done()
continue
}
newBatch := err.Batch
newBatch.RetryCounter++
if err.HasLastId {
newBatch.ParentId = err.Batch.Id
newBatch.Id = uuid.New()
newBatch.LowerLimit = err.LastId
newBatch.IsLowerLimitInclusive = false
} else {
jobError := JobError{
ShouldCancelJob: false,
Msg: fmt.Sprintf("Temporal error in partition %v (retries: %d)", err.Partition.Id, err.Partition.RetryCounter),
Prev: &err,
}
select {
case chBatchesOut <- newBatch:
case chJobErrorsOut <- jobError:
case <-ctx.Done():
return
}
}
newPartition := err.Partition
newPartition.RetryCounter++
delay := computeBackoffDelay(
newPartition.RetryCounter,
retryConfig.BaseDelayMs,
retryConfig.MaxDelayMs,
retryConfig.MaxJitterMs,
)
if err.HasLastId {
newPartition.ParentId = err.Partition.Id
newPartition.Id = uuid.New()
newPartition.LowerLimit = err.LastId
newPartition.IsLowerLimitInclusive = false
}
requeueWithBackoff(ctx, delay, func() {
select {
case chPartitionsOut <- newPartition:
case <-ctx.Done():
return
}
})
}
}
}

View File

@@ -37,11 +37,11 @@ func JobErrorHandler(ctx context.Context, chErrorsIn <-chan JobError) error {
}
if err.ShouldCancelJob {
log.Error(err.Msg, " - ", err.Prev)
log.Errorf("(Fatal job error) - %v - %v", err.Msg, err.Prev)
return &err
}
log.Error(err.Msg, " - ", err.Prev)
log.Errorf("%v - %v", err.Msg, err.Prev)
}
}
}

View File

@@ -0,0 +1,89 @@
package custom_errors
import (
"context"
"fmt"
"sync"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
)
type LoaderError struct {
Batch models.Batch
Msg string
}
func (e *LoaderError) Error() string {
return e.Msg
}
func LoaderErrorHandler(
ctx context.Context,
retryConfig config.RetryConfig,
chErrorsIn <-chan LoaderError,
chBatchesOut chan<- models.Batch,
chJobErrorsOut chan<- JobError,
wgActiveBatches *sync.WaitGroup,
) {
for {
if ctx.Err() != nil {
return
}
select {
case <-ctx.Done():
return
case err, ok := <-chErrorsIn:
if !ok {
return
}
if err.Batch.RetryCounter >= retryConfig.Attempts {
wgActiveBatches.Done()
jobError := JobError{
ShouldCancelJob: false,
Msg: fmt.Sprintf("Batch %v reached max retries (%d)", err.Batch.Id, retryConfig.Attempts),
Prev: &err,
}
select {
case chJobErrorsOut <- jobError:
case <-ctx.Done():
return
}
continue
} else {
jobError := JobError{
ShouldCancelJob: false,
Msg: fmt.Sprintf("Temporal error in batch %v (retries: %d)", err.Batch.Id, err.Batch.RetryCounter),
Prev: &err,
}
select {
case chJobErrorsOut <- jobError:
case <-ctx.Done():
return
}
}
err.Batch.RetryCounter++
delay := computeBackoffDelay(
err.Batch.RetryCounter,
retryConfig.BaseDelayMs,
retryConfig.MaxDelayMs,
retryConfig.MaxJitterMs,
)
requeueWithBackoff(ctx, delay, func() {
select {
case chBatchesOut <- err.Batch:
case <-ctx.Done():
return
}
})
}
}
}

View File

@@ -1,36 +0,0 @@
package extractor
import (
"context"
"sync"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
)
type Extractor interface {
ProcessBatch(
ctx context.Context,
tableInfo config.SourceTableInfo,
columns []models.ColumnType,
chunkSize int,
batch models.Batch,
indexPrimaryKey int,
chChunksOut chan<- models.Chunk,
rowsRead *int64,
) error
Exec(
ctx context.Context,
tableInfo config.SourceTableInfo,
columns []models.ColumnType,
chunkSize int,
chBatchesIn <-chan models.Batch,
chChunksOut chan<- models.Chunk,
chErrorsOut chan<- custom_errors.ExtractorError,
chJobErrorsOut chan<- custom_errors.JobError,
wgActiveBatches *sync.WaitGroup,
rowsRead *int64,
)
}

View File

@@ -1,4 +1,4 @@
package extractor
package extractors
import (
"context"
@@ -13,6 +13,7 @@ import (
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/convert"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
"github.com/google/uuid"
)
@@ -21,7 +22,7 @@ type MssqlExtractor struct {
db *sql.DB
}
func NewMssqlExtractor(db *sql.DB) *MssqlExtractor {
func NewMssqlExtractor(db *sql.DB) etl.Extractor {
return &MssqlExtractor{db: db}
}
@@ -69,20 +70,20 @@ func buildExtractQueryMssql(
return sbQuery.String()
}
func extractorErrorFromLastRowMssql(
func errorFromLastRow(
lastRow models.UnknownRowValues,
indexPrimaryKey int,
batch *models.Batch,
partition *models.Partition,
previousError error,
) *custom_errors.ExtractorError {
lastIdRawValue := lastRow[indexPrimaryKey]
lastId, ok := convert.ToInt64(lastIdRawValue)
if !ok {
currentBatch := *batch
currentBatch.RetryCounter = 3
currentPartition := *partition
currentPartition.RetryCounter = 3
return &custom_errors.ExtractorError{
Batch: currentBatch,
Partition: currentPartition,
HasLastId: true,
Msg: fmt.Sprintf("Couldn't cast last id value as int: %s", previousError.Error()),
}
@@ -90,78 +91,78 @@ func extractorErrorFromLastRowMssql(
}
return &custom_errors.ExtractorError{
Batch: *batch,
Partition: *partition,
HasLastId: true,
LastId: lastId,
Msg: previousError.Error(),
}
}
func (mssqlEx *MssqlExtractor) ProcessBatch(
func (mssqlEx *MssqlExtractor) ProcessPartition(
ctx context.Context,
tableInfo config.SourceTableInfo,
columns []models.ColumnType,
chunkSize int,
batch models.Batch,
batchSize int,
partition models.Partition,
indexPrimaryKey int,
chChunksOut chan<- models.Chunk,
chBatchesOut chan<- models.Batch,
rowsRead *int64,
) error {
query := buildExtractQueryMssql(tableInfo, columns, batch.ShouldUseRange, batch.IsLowerLimitInclusive)
query := buildExtractQueryMssql(tableInfo, columns, partition.ShouldUseRange, partition.IsLowerLimitInclusive)
var queryArgs []any
if batch.ShouldUseRange {
if partition.ShouldUseRange {
queryArgs = append(queryArgs,
sql.Named("min", batch.LowerLimit),
sql.Named("max", batch.UpperLimit),
sql.Named("min", partition.LowerLimit),
sql.Named("max", partition.UpperLimit),
)
}
rows, err := mssqlEx.db.QueryContext(ctx, query, queryArgs...)
if err != nil {
return &custom_errors.ExtractorError{Batch: batch, HasLastId: false, Msg: err.Error()}
return &custom_errors.ExtractorError{Partition: partition, HasLastId: false, Msg: err.Error()}
}
defer rows.Close()
rowsChunk := make([]models.UnknownRowValues, 0, chunkSize)
batchRows := make([]models.UnknownRowValues, 0, batchSize)
for rows.Next() {
values := make([]any, len(columns))
rowValues := make([]any, len(columns))
scanArgs := make([]any, len(columns))
for i := range values {
scanArgs[i] = &values[i]
for i := range rowValues {
scanArgs[i] = &rowValues[i]
}
if err := rows.Scan(scanArgs...); err != nil {
if len(rowsChunk) == 0 {
return &custom_errors.ExtractorError{Batch: batch, HasLastId: false, Msg: err.Error()}
if len(batchRows) == 0 {
return &custom_errors.ExtractorError{Partition: partition, HasLastId: false, Msg: err.Error()}
}
lastRow := rowsChunk[len(rowsChunk)-1]
lastRow := batchRows[len(batchRows)-1]
select {
case chChunksOut <- models.Chunk{Id: uuid.New(), BatchId: batch.Id, Data: rowsChunk, RetryCounter: 0}:
case chBatchesOut <- models.Batch{Id: uuid.New(), PartitionId: partition.Id, Rows: batchRows, RetryCounter: 0}:
case <-ctx.Done():
return nil
}
atomic.AddInt64(rowsRead, int64(len(rowsChunk)))
atomic.AddInt64(rowsRead, int64(len(batchRows)))
return extractorErrorFromLastRowMssql(lastRow, indexPrimaryKey, &batch, err)
return errorFromLastRow(lastRow, indexPrimaryKey, &partition, err)
}
rowsChunk = append(rowsChunk, values)
batchRows = append(batchRows, rowValues)
if len(rowsChunk) >= chunkSize {
if len(batchRows) >= batchSize {
select {
case chChunksOut <- models.Chunk{Id: uuid.New(), BatchId: batch.Id, Data: rowsChunk, RetryCounter: 0}:
case chBatchesOut <- models.Batch{Id: uuid.New(), PartitionId: partition.Id, Rows: batchRows, RetryCounter: 0}:
case <-ctx.Done():
return nil
}
atomic.AddInt64(rowsRead, int64(len(rowsChunk)))
rowsChunk = make([]models.UnknownRowValues, 0, chunkSize)
atomic.AddInt64(rowsRead, int64(len(batchRows)))
batchRows = make([]models.UnknownRowValues, 0, batchSize)
}
}
@@ -170,22 +171,22 @@ func (mssqlEx *MssqlExtractor) ProcessBatch(
return ctx.Err()
}
if len(rowsChunk) == 0 {
return &custom_errors.ExtractorError{Batch: batch, HasLastId: false, Msg: err.Error()}
if len(batchRows) == 0 {
return &custom_errors.ExtractorError{Partition: partition, HasLastId: false, Msg: err.Error()}
}
lastRow := rowsChunk[len(rowsChunk)-1]
return extractorErrorFromLastRowMssql(lastRow, indexPrimaryKey, &batch, err)
lastRow := batchRows[len(batchRows)-1]
return errorFromLastRow(lastRow, indexPrimaryKey, &partition, err)
}
if len(rowsChunk) > 0 {
if len(batchRows) > 0 {
select {
case chChunksOut <- models.Chunk{Id: uuid.New(), BatchId: batch.Id, Data: rowsChunk, RetryCounter: 0}:
case chBatchesOut <- models.Batch{Id: uuid.New(), PartitionId: partition.Id, Rows: batchRows, RetryCounter: 0}:
case <-ctx.Done():
return nil
}
atomic.AddInt64(rowsRead, int64(len(rowsChunk)))
atomic.AddInt64(rowsRead, int64(len(batchRows)))
}
return nil
@@ -195,12 +196,12 @@ func (mssqlEx *MssqlExtractor) Exec(
ctx context.Context,
tableInfo config.SourceTableInfo,
columns []models.ColumnType,
chunkSize int,
chBatchesIn <-chan models.Batch,
chChunksOut chan<- models.Chunk,
batchSize int,
chPartitionsIn <-chan models.Partition,
chBatchesOut chan<- models.Batch,
chErrorsOut chan<- custom_errors.ExtractorError,
chJobErrorsOut chan<- custom_errors.JobError,
wgActiveBatches *sync.WaitGroup,
wgActivePartitions *sync.WaitGroup,
rowsRead *int64,
) {
indexPrimaryKey := slices.IndexFunc(columns, func(col models.ColumnType) bool {
@@ -228,42 +229,49 @@ func (mssqlEx *MssqlExtractor) Exec(
select {
case <-ctx.Done():
return
case batch, ok := <-chBatchesIn:
case partition, ok := <-chPartitionsIn:
if !ok {
return
}
err := mssqlEx.ProcessBatch(
err := mssqlEx.ProcessPartition(
ctx,
tableInfo,
columns,
chunkSize,
batch,
batchSize,
partition,
indexPrimaryKey,
chChunksOut,
chBatchesOut,
rowsRead,
)
if err != nil {
var exError *custom_errors.ExtractorError
var jobError *custom_errors.JobError
if errors.As(err, &exError) {
select {
case <-ctx.Done():
return
case chErrorsOut <- *exError:
}
}
} else if errors.As(err, &jobError) {
select {
case <-ctx.Done():
return
case chJobErrorsOut <- custom_errors.JobError{ShouldCancelJob: false, Prev: err}:
case chJobErrorsOut <- *jobError:
}
} else {
select {
case <-ctx.Done():
return
case chErrorsOut <- custom_errors.ExtractorError{Partition: partition, Msg: err.Error()}:
}
}
wgActiveBatches.Done()
continue
}
wgActivePartitions.Done()
}
}
}

View File

@@ -1,4 +1,4 @@
package extractor
package extractors
import (
"context"
@@ -10,6 +10,7 @@ import (
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
@@ -19,7 +20,7 @@ type PostgresExtractor struct {
db *pgxpool.Pool
}
func NewPostgresExtractor(pool *pgxpool.Pool) *PostgresExtractor {
func NewPostgresExtractor(pool *pgxpool.Pool) etl.Extractor {
return &PostgresExtractor{db: pool}
}
@@ -51,29 +52,29 @@ func buildExtractQueryPostgres(sourceDbInfo config.SourceTableInfo, columns []mo
return fmt.Sprintf(`SELECT %s FROM "%s"."%s" ORDER BY "%s" ASC`, sbColumns.String(), sourceDbInfo.Schema, sourceDbInfo.Table, sourceDbInfo.PrimaryKey)
}
func (postgresEx *PostgresExtractor) ProcessBatch(
func (postgresEx *PostgresExtractor) ProcessPartition(
ctx context.Context,
tableInfo config.SourceTableInfo,
columns []models.ColumnType,
chunkSize int,
batch models.Batch,
batchSize int,
partition models.Partition,
indexPrimaryKey int,
chChunksOut chan<- models.Chunk,
chBatchesOut chan<- models.Batch,
rowsRead *int64,
) error {
query := buildExtractQueryPostgres(tableInfo, columns)
if batch.ShouldUseRange {
if partition.ShouldUseRange {
return errors.New("Batch config not yet supported")
}
rows, err := postgresEx.db.Query(ctx, query)
if err != nil {
return &custom_errors.ExtractorError{Batch: batch, HasLastId: false, Msg: err.Error()}
return &custom_errors.ExtractorError{Partition: partition, HasLastId: false, Msg: err.Error()}
}
defer rows.Close()
rowsChunk := make([]models.UnknownRowValues, 0, chunkSize)
batchRows := make([]models.UnknownRowValues, 0, batchSize)
for rows.Next() {
values, err := rows.Values()
@@ -81,17 +82,17 @@ func (postgresEx *PostgresExtractor) ProcessBatch(
return errors.New("Unexpected error reading rows from source")
}
rowsChunk = append(rowsChunk, values)
batchRows = append(batchRows, values)
if len(rowsChunk) >= chunkSize {
if len(batchRows) >= batchSize {
select {
case chChunksOut <- models.Chunk{Id: uuid.New(), BatchId: batch.Id, Data: rowsChunk, RetryCounter: 0}:
case chBatchesOut <- models.Batch{Id: uuid.New(), PartitionId: partition.Id, Rows: batchRows, RetryCounter: 0}:
case <-ctx.Done():
return nil
}
atomic.AddInt64(rowsRead, int64(len(rowsChunk)))
rowsChunk = make([]models.UnknownRowValues, 0, chunkSize)
atomic.AddInt64(rowsRead, int64(len(batchRows)))
batchRows = make([]models.UnknownRowValues, 0, batchSize)
}
}
@@ -99,14 +100,14 @@ func (postgresEx *PostgresExtractor) ProcessBatch(
return errors.New("Unexpected error reading rows from source")
}
if len(rowsChunk) > 0 {
if len(batchRows) > 0 {
select {
case chChunksOut <- models.Chunk{Id: uuid.New(), BatchId: batch.Id, Data: rowsChunk, RetryCounter: 0}:
case chBatchesOut <- models.Batch{Id: uuid.New(), PartitionId: partition.Id, Rows: batchRows, RetryCounter: 0}:
case <-ctx.Done():
return nil
}
atomic.AddInt64(rowsRead, int64(len(rowsChunk)))
atomic.AddInt64(rowsRead, int64(len(batchRows)))
}
return nil
@@ -116,12 +117,12 @@ func (postgresEx *PostgresExtractor) Exec(
ctx context.Context,
tableInfo config.SourceTableInfo,
columns []models.ColumnType,
chunkSize int,
chBatchesIn <-chan models.Batch,
chChunksOut chan<- models.Chunk,
batchSize int,
chPartitionsIn <-chan models.Partition,
chBatchesOut chan<- models.Batch,
chErrorsOut chan<- custom_errors.ExtractorError,
chJobErrorsOut chan<- custom_errors.JobError,
wgActiveBatches *sync.WaitGroup,
wgActivePartitions *sync.WaitGroup,
rowsRead *int64,
) {
}

View File

@@ -0,0 +1 @@
package extractors

View File

@@ -0,0 +1,128 @@
package loaders
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
)
type PostgresLoader struct {
db *pgxpool.Pool
}
func NewPostgresLoader(pool *pgxpool.Pool) etl.Loader {
return &PostgresLoader{db: pool}
}
func mapSlice[T any, V any](input []T, mapper func(T) V) []V {
result := make([]V, len(input))
for i, v := range input {
result[i] = mapper(v)
}
return result
}
func (postgresLd *PostgresLoader) ProcessBatch(
ctx context.Context,
tableInfo config.TargetTableInfo,
colNames []string,
batch models.Batch,
) (int, error) {
tableId := pgx.Identifier{tableInfo.Schema, tableInfo.Table}
_, err := postgresLd.db.CopyFrom(
ctx,
tableId,
colNames,
pgx.CopyFromRows(batch.Rows),
)
if err != nil {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) {
if pgErr.Code == "23505" {
return 0, &custom_errors.JobError{
ShouldCancelJob: true,
Msg: fmt.Sprintf("Fatal error in table %s", tableId.Sanitize()),
Prev: err,
}
}
}
return 0, &custom_errors.LoaderError{Batch: batch, Msg: err.Error()}
}
return len(batch.Rows), nil
}
func (postgresLd *PostgresLoader) Exec(
ctx context.Context,
tableInfo config.TargetTableInfo,
columns []models.ColumnType,
chBatchesIn <-chan models.Batch,
chErrorsOut chan<- custom_errors.LoaderError,
chJobErrorsOut chan<- custom_errors.JobError,
wgActiveBatches *sync.WaitGroup,
rowsLoaded *int64,
) {
colNames := mapSlice(columns, func(col models.ColumnType) string {
return col.Name()
})
for {
if ctx.Err() != nil {
return
}
select {
case <-ctx.Done():
return
case batch, ok := <-chBatchesIn:
if !ok {
return
}
processedRows, err := postgresLd.ProcessBatch(ctx, tableInfo, colNames, batch)
if err != nil {
var ldError *custom_errors.LoaderError
var jobError *custom_errors.JobError
if errors.As(err, &ldError) {
select {
case <-ctx.Done():
return
case chErrorsOut <- *ldError:
}
} else if errors.As(err, &jobError) {
select {
case <-ctx.Done():
return
case chJobErrorsOut <- *jobError:
}
} else {
select {
case <-ctx.Done():
return
case chErrorsOut <- custom_errors.LoaderError{Batch: batch, Msg: err.Error()}:
}
}
continue
}
wgActiveBatches.Done()
atomic.AddInt64(rowsLoaded, int64(processedRows))
}
}
}

View File

@@ -0,0 +1 @@
package loaders

View File

@@ -0,0 +1,40 @@
package table_analyzers
import (
"context"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
"github.com/google/uuid"
)
func PartitionRangeGenerator(
ctx context.Context,
tableAnalyzer etl.TableAnalyzer,
tableInfo config.TableInfo,
partitionColumn string,
rowsPerPartition int64,
) ([]models.Partition, error) {
rowsCount, err := tableAnalyzer.EstimateTotalRows(ctx, tableInfo)
if err != nil {
return nil, err
}
if rowsCount <= rowsPerPartition {
return []models.Partition{{
Id: uuid.New(),
ShouldUseRange: false,
RetryCounter: 0,
}}, nil
}
partitionsCount := rowsCount / rowsPerPartition
partitions, err := tableAnalyzer.CalculatePartitionRanges(ctx, tableInfo, partitionColumn, partitionsCount)
if err != nil {
return nil, err
}
return partitions, nil
}

View File

@@ -0,0 +1,249 @@
package table_analyzers
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
"github.com/google/uuid"
)
type MssqlTableAnalyzer struct {
db *sql.DB
}
func NewMssqlTableAnalyzer(db *sql.DB) etl.TableAnalyzer {
return &MssqlTableAnalyzer{db: db}
}
const mssqlColumnMetadataQuery string = `
SELECT
c.name AS name,
t.name AS user_type,
CASE WHEN t.is_user_defined = 0 THEN t.name ELSE bt.name END AS system_type,
c.is_nullable AS nullable,
c.max_length AS max_length,
c.precision AS precision,
c.scale AS scale
FROM sys.columns c
JOIN sys.types t ON c.user_type_id = t.user_type_id
LEFT JOIN sys.types bt ON t.is_user_defined = 1 AND bt.user_type_id = t.system_type_id
JOIN sys.tables st ON c.object_id = st.object_id
JOIN sys.schemas s ON st.schema_id = s.schema_id
WHERE s.name = @schema AND st.name = @table AND c.name NOT LIKE 'graph_id%'
ORDER BY c.column_id;`
type rawColumnMssql struct {
name string
userType string
systemType string
nullable bool
maxLength int64
precision int64
scale int64
}
func (ta *MssqlTableAnalyzer) systemTypeToUnifiedType(systemType string) string {
systemType = strings.ToLower(systemType)
if systemType == "varchar" || systemType == "char" || systemType == "nvarchar" || systemType == "nchar" || systemType == "text" || systemType == "ntext" {
return "STRING"
}
if systemType == "int" || systemType == "int4" || systemType == "integer" || systemType == "smallint" || systemType == "int2" || systemType == "bigint" || systemType == "int8" || systemType == "tinyint" {
return "INTEGER"
}
if systemType == "decimal" || systemType == "numeric" {
return "DECIMAL"
}
if systemType == "float" || systemType == "real" || systemType == "double precision" {
return "FLOAT"
}
if systemType == "bit" || systemType == "boolean" {
return "BOOLEAN"
}
if systemType == "date" {
return "DATE"
}
if systemType == "time" || systemType == "time without time zone" {
return "TIME"
}
if systemType == "datetime" || systemType == "datetime2" || systemType == "timestamp" || systemType == "timestamptz" || systemType == "timestamp with time zone" {
return "TIMESTAMP"
}
if systemType == "binary" || systemType == "varbinary" || systemType == "image" || systemType == "bytea" {
return "BINARY"
}
if systemType == "uniqueidentifier" || systemType == "uuid" {
return "UUID"
}
if systemType == "json" {
return "JSON"
}
if systemType == "geometry" || systemType == "geography" {
return "GEOMETRY"
}
return strings.ToUpper(systemType)
}
func (ta *MssqlTableAnalyzer) rawColumnToColumnType(rawColumn rawColumnMssql) models.ColumnType {
const nullValue int64 = -1
stringTypes := map[string]bool{"varchar": true, "char": true, "nvarchar": true, "nchar": true, "text": true, "ntext": true}
decimalTypes := map[string]bool{"decimal": true, "numeric": true}
if stringTypes[rawColumn.systemType] {
if rawColumn.systemType == "nvarchar" || rawColumn.systemType == "nchar" {
if rawColumn.maxLength > 0 {
rawColumn.maxLength = rawColumn.maxLength / 2
}
}
rawColumn.precision, rawColumn.scale = nullValue, nullValue
} else if decimalTypes[rawColumn.systemType] {
rawColumn.maxLength = nullValue
} else {
rawColumn.maxLength, rawColumn.precision, rawColumn.scale = nullValue, nullValue, nullValue
}
columnType := models.NewColumnType(
rawColumn.name,
rawColumn.maxLength != nullValue,
rawColumn.precision != nullValue || rawColumn.scale != nullValue,
rawColumn.userType,
rawColumn.systemType,
ta.systemTypeToUnifiedType(rawColumn.systemType),
rawColumn.nullable,
rawColumn.maxLength,
rawColumn.precision,
rawColumn.scale,
)
return columnType
}
func (ta *MssqlTableAnalyzer) QueryColumnTypes(
ctx context.Context,
tableInfo config.TableInfo,
) ([]models.ColumnType, error) {
localCtx, cancel := context.WithTimeout(ctx, 20*time.Second)
defer cancel()
rows, err := ta.db.QueryContext(localCtx, mssqlColumnMetadataQuery, sql.Named("schema", tableInfo.Schema), sql.Named("table", tableInfo.Table))
if err != nil {
return nil, err
}
defer rows.Close()
var columnTypes []models.ColumnType
for rows.Next() {
var rawColumn rawColumnMssql
if err := rows.Scan(
&rawColumn.name,
&rawColumn.userType,
&rawColumn.systemType,
&rawColumn.nullable,
&rawColumn.maxLength,
&rawColumn.precision,
&rawColumn.scale,
); err != nil {
return nil, err
}
columnTypes = append(columnTypes, ta.rawColumnToColumnType(rawColumn))
}
return columnTypes, nil
}
func (ta *MssqlTableAnalyzer) EstimateTotalRows(
ctx context.Context,
tableInfo config.TableInfo,
) (int64, error) {
query := `
SELECT SUM(p.rows) AS count
FROM sys.tables t
JOIN sys.schemas s ON t.schema_id = s.schema_id
JOIN sys.partitions p ON t.object_id = p.object_id
WHERE s.name = @schema AND t.name = @table AND p.index_id IN (0, 1)
GROUP BY t.name`
ctxTimeout, cancel := context.WithTimeout(ctx, time.Second*20)
defer cancel()
var rowsCount int64
err := ta.db.QueryRowContext(ctxTimeout, query, sql.Named("schema", tableInfo.Schema), sql.Named("table", tableInfo.Table)).Scan(&rowsCount)
if err != nil {
return 0, err
}
return rowsCount, nil
}
func (ta *MssqlTableAnalyzer) CalculatePartitionRanges(
ctx context.Context,
tableInfo config.TableInfo,
partitionColumn string,
maxPartitions int64,
) ([]models.Partition, error) {
query := fmt.Sprintf(`
SELECT
MIN([%s]) AS lower_limit,
MAX([%s]) AS upper_limit
FROM (SELECT [%s], NTILE(@maxPartitions) OVER (ORDER BY [%s]) AS batch_id FROM [%s].[%s]) AS T
GROUP BY batch_id
ORDER BY batch_id`,
partitionColumn,
partitionColumn,
partitionColumn,
partitionColumn,
tableInfo.Schema,
tableInfo.Table)
ctxTimeout, cancel := context.WithTimeout(ctx, time.Second*20)
defer cancel()
rows, err := ta.db.QueryContext(ctxTimeout, query, sql.Named("maxPartitions", maxPartitions))
if err != nil {
return nil, err
}
defer rows.Close()
partitions := make([]models.Partition, 0, maxPartitions)
for rows.Next() {
partition := models.Partition{
Id: uuid.New(),
ShouldUseRange: true,
RetryCounter: 0,
IsLowerLimitInclusive: true,
}
if err := rows.Scan(&partition.LowerLimit, &partition.UpperLimit); err != nil {
return nil, err
}
partitions = append(partitions, partition)
}
if err := rows.Err(); err != nil {
return nil, err
}
return partitions, nil
}

View File

@@ -0,0 +1,174 @@
package table_analyzers
import (
"context"
"strings"
"time"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
"github.com/jackc/pgx/v5/pgxpool"
)
type PostgresTableAnalyzer struct {
db *pgxpool.Pool
}
func NewPostgresTableAnalyzer(db *pgxpool.Pool) etl.TableAnalyzer {
return &PostgresTableAnalyzer{db: db}
}
const postgresColumnMetadataQuery string = `
SELECT
c.column_name AS name,
c.data_type AS user_type,
c.udt_name AS system_type,
(CASE WHEN c.is_nullable = 'YES' THEN TRUE ELSE FALSE END) AS nullable,
COALESCE(c.character_maximum_length, -1) AS max_length,
COALESCE(c.numeric_precision, -1) AS precision,
COALESCE(c.numeric_scale, -1) AS scale
FROM information_schema.columns c
WHERE c.table_schema = $1 AND c.table_name = $2
ORDER BY c.ordinal_position;`
type rawColumnPostgres struct {
name string
userType string
systemType string
nullable bool
maxLength int64
precision int64
scale int64
}
func (ta *PostgresTableAnalyzer) systemTypeToUnifiedType(systemType string) string {
systemType = strings.ToLower(systemType)
if systemType == "varchar" || systemType == "char" || systemType == "nvarchar" || systemType == "nchar" || systemType == "text" || systemType == "ntext" {
return "STRING"
}
if systemType == "int" || systemType == "int4" || systemType == "integer" || systemType == "smallint" || systemType == "int2" || systemType == "bigint" || systemType == "int8" || systemType == "tinyint" {
return "INTEGER"
}
if systemType == "decimal" || systemType == "numeric" {
return "DECIMAL"
}
if systemType == "float" || systemType == "real" || systemType == "double precision" {
return "FLOAT"
}
if systemType == "bit" || systemType == "boolean" {
return "BOOLEAN"
}
if systemType == "date" {
return "DATE"
}
if systemType == "time" || systemType == "time without time zone" {
return "TIME"
}
if systemType == "datetime" || systemType == "datetime2" || systemType == "timestamp" || systemType == "timestamptz" || systemType == "timestamp with time zone" {
return "TIMESTAMP"
}
if systemType == "binary" || systemType == "varbinary" || systemType == "image" || systemType == "bytea" {
return "BINARY"
}
if systemType == "uniqueidentifier" || systemType == "uuid" {
return "UUID"
}
if systemType == "json" {
return "JSON"
}
if systemType == "geometry" || systemType == "geography" {
return "GEOMETRY"
}
return strings.ToUpper(systemType)
}
func (ta *PostgresTableAnalyzer) rawColumnToColumnType(rawColumn rawColumnPostgres) models.ColumnType {
const nullValue int64 = -1
stringTypes := map[string]bool{"varchar": true, "char": true, "text": true}
decimalTypes := map[string]bool{"decimal": true, "numeric": true}
if stringTypes[rawColumn.systemType] {
rawColumn.precision, rawColumn.scale = nullValue, nullValue
} else if decimalTypes[rawColumn.systemType] {
rawColumn.maxLength = nullValue
} else {
rawColumn.maxLength, rawColumn.precision, rawColumn.scale = nullValue, nullValue, nullValue
}
return models.NewColumnType(
rawColumn.name,
rawColumn.maxLength != nullValue,
rawColumn.precision != nullValue || rawColumn.scale != nullValue,
rawColumn.userType,
rawColumn.systemType,
ta.systemTypeToUnifiedType(rawColumn.systemType),
rawColumn.nullable,
rawColumn.maxLength,
rawColumn.precision,
rawColumn.scale,
)
}
func (ta *PostgresTableAnalyzer) QueryColumnTypes(
ctx context.Context,
tableInfo config.TableInfo,
) ([]models.ColumnType, error) {
localCtx, cancel := context.WithTimeout(ctx, 20*time.Second)
defer cancel()
rows, err := ta.db.Query(localCtx, postgresColumnMetadataQuery, tableInfo.Schema, tableInfo.Table)
if err != nil {
return nil, err
}
defer rows.Close()
var colTypes []models.ColumnType
for rows.Next() {
var column rawColumnPostgres
if err := rows.Scan(
&column.name,
&column.userType,
&column.systemType,
&column.nullable,
&column.maxLength,
&column.precision,
&column.scale,
); err != nil {
return nil, err
}
colTypes = append(colTypes, ta.rawColumnToColumnType(column))
}
return colTypes, nil
}
func (ta *PostgresTableAnalyzer) EstimateTotalRows(
ctx context.Context,
tableInfo config.TableInfo,
) (int64, error) {
return 0, nil
}
func (ta *PostgresTableAnalyzer) CalculatePartitionRanges(
ctx context.Context,
tableInfo config.TableInfo,
partitionColumn string,
maxPartitions int64,
) ([]models.Partition, error) {
return []models.Partition{}, nil
}

View File

@@ -0,0 +1,150 @@
package transformers
import (
"context"
"errors"
"sync"
"time"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
)
type MssqlTransformer struct{}
func NewMssqlTransformer() etl.Transformer {
return &MssqlTransformer{}
}
func computeTransformationPlan(columns []models.ColumnType) []etl.ColumnTransformPlan {
var plan []etl.ColumnTransformPlan
for i, col := range columns {
switch col.SystemType() {
case "uniqueidentifier":
plan = append(plan, etl.ColumnTransformPlan{
Index: i,
Fn: func(v any) (any, error) {
if b, ok := v.([]byte); ok && b != nil {
return mssqlUuidToBigEndian(b)
}
return v, nil
},
})
case "geometry", "geography":
plan = append(plan, etl.ColumnTransformPlan{
Index: i,
Fn: func(v any) (any, error) {
if b, ok := v.([]byte); ok && b != nil {
return wkbToEwkbWithSrid(b, 4326)
}
return v, nil
},
})
case "datetime", "datetime2":
plan = append(plan, etl.ColumnTransformPlan{
Index: i,
Fn: func(v any) (any, error) {
if t, ok := v.(time.Time); ok {
return ensureUTC(t), nil
}
return v, nil
},
})
}
}
return plan
}
const processBatchCtxCheck = 4096
func (mssqlTr *MssqlTransformer) ProcessBatch(
ctx context.Context,
batch *models.Batch,
transformationPlan []etl.ColumnTransformPlan,
) error {
for i, rowValues := range batch.Rows {
if i%processBatchCtxCheck == 0 {
if err := ctx.Err(); err != nil {
return err
}
}
for _, task := range transformationPlan {
val := rowValues[task.Index]
if val == nil {
continue
}
transformed, err := task.Fn(val)
if err != nil {
return err
}
rowValues[task.Index] = transformed
}
}
return nil
}
func (mssqlTr *MssqlTransformer) Exec(
ctx context.Context,
columns []models.ColumnType,
chBatchesIn <-chan models.Batch,
chBatchesOut chan<- models.Batch,
chJobErrorsOut chan<- custom_errors.JobError,
wgActiveBatches *sync.WaitGroup,
) {
transformationPlan := computeTransformationPlan(columns)
for {
if ctx.Err() != nil {
return
}
select {
case <-ctx.Done():
return
case batch, ok := <-chBatchesIn:
if !ok {
return
}
if len(transformationPlan) == 0 {
select {
case chBatchesOut <- batch:
wgActiveBatches.Add(1)
continue
case <-ctx.Done():
return
}
}
err := mssqlTr.ProcessBatch(ctx, &batch, transformationPlan)
if err != nil {
if errors.Is(err, ctx.Err()) {
return
}
select {
case chJobErrorsOut <- custom_errors.JobError{ShouldCancelJob: true, Msg: "Transformation failed", Prev: err}:
case <-ctx.Done():
}
return
}
select {
case chBatchesOut <- batch:
case <-ctx.Done():
return
}
wgActiveBatches.Add(1)
}
}
}

View File

@@ -0,0 +1 @@
package transformers

View File

@@ -1,4 +1,4 @@
package main
package transformers
import (
"encoding/binary"

99
internal/app/etl/types.go Normal file
View File

@@ -0,0 +1,99 @@
package etl
import (
"context"
"sync"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
)
type Extractor interface {
ProcessPartition(
ctx context.Context,
tableInfo config.SourceTableInfo,
columns []models.ColumnType,
batchSize int,
partition models.Partition,
indexPrimaryKey int,
chBatchesOut chan<- models.Batch,
rowsRead *int64,
) error
Exec(
ctx context.Context,
tableInfo config.SourceTableInfo,
columns []models.ColumnType,
batchSize int,
chPartitionsIn <-chan models.Partition,
chBatchesOut chan<- models.Batch,
chErrorsOut chan<- custom_errors.ExtractorError,
chJobErrorsOut chan<- custom_errors.JobError,
wgActivePartitions *sync.WaitGroup,
rowsRead *int64,
)
}
type TransformerFunc func(any) (any, error)
type ColumnTransformPlan struct {
Index int
Fn TransformerFunc
}
type Transformer interface {
ProcessBatch(
ctx context.Context,
batch *models.Batch,
transformationPlan []ColumnTransformPlan,
) error
Exec(
ctx context.Context,
columns []models.ColumnType,
chBatchesIn <-chan models.Batch,
chBactchesOut chan<- models.Batch,
chJobErrorsOut chan<- custom_errors.JobError,
wgActiveBatches *sync.WaitGroup,
)
}
type Loader interface {
ProcessBatch(
ctx context.Context,
tableInfo config.TargetTableInfo,
colNames []string,
batch models.Batch,
) (int, error)
Exec(
ctx context.Context,
tableInfo config.TargetTableInfo,
columns []models.ColumnType,
chBatchesIn <-chan models.Batch,
chErrorsOut chan<- custom_errors.LoaderError,
chJobErrorsOut chan<- custom_errors.JobError,
wgActiveBatches *sync.WaitGroup,
rowsLoaded *int64,
)
}
type TableAnalyzer interface {
QueryColumnTypes(
ctx context.Context,
tableInfo config.TableInfo,
) ([]models.ColumnType, error)
EstimateTotalRows(
ctx context.Context,
tableInfo config.TableInfo,
) (int64, error)
CalculatePartitionRanges(
ctx context.Context,
tableInfo config.TableInfo,
partitionColumn string,
maxPartitions int64,
) ([]models.Partition, error)
}

View File

@@ -42,3 +42,29 @@ func (c *ColumnType) Nullable() bool {
func (c *ColumnType) Type() string {
return c.unifiedType
}
func NewColumnType(
name string,
hasMaxLength bool,
hasPrecisionScale bool,
userType string,
systemType string,
unifiedType string,
nullable bool,
maxLength int64,
precision int64,
scale int64,
) ColumnType {
return ColumnType{
name,
hasMaxLength,
hasPrecisionScale,
userType,
systemType,
unifiedType,
nullable,
maxLength,
precision,
scale,
}
}

View File

@@ -4,14 +4,14 @@ import "github.com/google/uuid"
type UnknownRowValues = []any
type Chunk struct {
type Batch struct {
Id uuid.UUID
BatchId uuid.UUID
Data []UnknownRowValues
PartitionId uuid.UUID
Rows []UnknownRowValues
RetryCounter int
}
type Batch struct {
type Partition struct {
Id uuid.UUID
ParentId uuid.UUID
LowerLimit int64