refactor: implement bidirectional transformation support with PostgreSQL integration

This commit is contained in:
2026-05-29 12:17:16 -05:00
parent 537b7fbd28
commit 4ba26092a9
7 changed files with 430 additions and 6 deletions

View File

@@ -2,6 +2,7 @@ package table_analyzers
import (
"context"
"fmt"
"strings"
"time"
@@ -9,6 +10,7 @@ import (
dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
"github.com/google/uuid"
)
type PostgresTableAnalyzer struct {
@@ -161,7 +163,30 @@ func (ta *PostgresTableAnalyzer) EstimateTotalRows(
ctx context.Context,
tableInfo config.TableInfo,
) (int64, error) {
return 0, nil
query := `
SELECT reltuples::bigint
FROM pg_class
JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace
WHERE pg_namespace.nspname = $1 AND pg_class.relname = $2`
ctxTimeout, cancel := context.WithTimeout(ctx, 1*time.Minute)
defer cancel()
var estimate int64
err := ta.db.QueryRow(ctxTimeout, query, tableInfo.Schema, tableInfo.Table).Scan(&estimate)
if err != nil {
return 0, err
}
if estimate < 0 {
countQuery := fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."%s"`, tableInfo.Schema, tableInfo.Table)
err = ta.db.QueryRow(ctxTimeout, countQuery).Scan(&estimate)
if err != nil {
return 0, err
}
}
return estimate, nil
}
func (ta *PostgresTableAnalyzer) QueryMaxMinFromColumn(
@@ -169,7 +194,19 @@ func (ta *PostgresTableAnalyzer) QueryMaxMinFromColumn(
tableInfo config.TableInfo,
columnName string,
) (etl.MaxMinColumnResult, error) {
return etl.MaxMinColumnResult{}, nil
query := fmt.Sprintf(`SELECT MIN("%s"), MAX("%s") FROM "%s"."%s"`,
columnName, columnName, tableInfo.Schema, tableInfo.Table)
ctxTimeout, cancel := context.WithTimeout(ctx, 1*time.Minute)
defer cancel()
result := etl.MaxMinColumnResult{}
err := ta.db.QueryRow(ctxTimeout, query).Scan(&result.Min, &result.Max)
if err != nil {
return etl.MaxMinColumnResult{}, err
}
return result, nil
}
func (ta *PostgresTableAnalyzer) CalculatePartitionRanges(
@@ -179,5 +216,78 @@ func (ta *PostgresTableAnalyzer) CalculatePartitionRanges(
maxPartitions int64,
rangeConstraint config.RangeConfig,
) ([]models.Partition, error) {
return []models.Partition{}, nil
whereClause := ""
args := []any{maxPartitions}
if rangeConstraint.Min != nil || rangeConstraint.Max != nil {
var conditions []string
if rangeConstraint.Min != nil {
minOp := ">"
if rangeConstraint.IsMinInclusive {
minOp = ">="
}
args = append(args, *rangeConstraint.Min)
conditions = append(conditions, fmt.Sprintf(`"%s" %s $%d`, partitionColumn, minOp, len(args)))
}
if rangeConstraint.Max != nil {
maxOp := "<"
if rangeConstraint.IsMaxInclusive {
maxOp = "<="
}
args = append(args, *rangeConstraint.Max)
conditions = append(conditions, fmt.Sprintf(`"%s" %s $%d`, partitionColumn, maxOp, len(args)))
}
whereClause = "WHERE " + strings.Join(conditions, " AND ")
}
query := fmt.Sprintf(`
SELECT MIN("%s") AS lower_limit, MAX("%s") AS upper_limit
FROM (
SELECT "%s", NTILE($1) OVER (ORDER BY "%s") AS batch_id
FROM "%s"."%s" %s
) AS t
GROUP BY batch_id
ORDER BY batch_id`,
partitionColumn,
partitionColumn,
partitionColumn,
partitionColumn,
tableInfo.Schema,
tableInfo.Table,
whereClause)
ctxTimeout, cancel := context.WithTimeout(ctx, 1*time.Minute)
defer cancel()
rows, err := ta.db.Query(ctxTimeout, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
partitions := make([]models.Partition, 0, maxPartitions)
for rows.Next() {
partition := models.Partition{
Id: uuid.New(),
HasRange: true,
RetryCounter: 0,
Range: models.PartitionRange{
IsMinInclusive: true,
IsMaxInclusive: true,
},
}
if err := rows.Scan(&partition.Range.Min, &partition.Range.Max); err != nil {
return nil, err
}
partitions = append(partitions, partition)
}
if err := rows.Err(); err != nil {
return nil, err
}
return partitions, nil
}

View File

@@ -59,6 +59,49 @@ func computeTransformationPlan(columns []models.ColumnType) []etl.ColumnTransfor
return plan
}
func computePostgresTransformationPlan(columns []models.ColumnType) []etl.ColumnTransformPlan {
var plan []etl.ColumnTransformPlan
for i, col := range columns {
switch col.SystemType() {
case "uuid":
plan = append(plan, etl.ColumnTransformPlan{
Index: i,
Fn: func(v any) (any, error) {
if b, ok := v.([]byte); ok && b != nil {
return bigEndianToMssqlUuid(b)
}
return v, nil
},
})
case "geometry":
plan = append(plan, etl.ColumnTransformPlan{
Index: i,
Fn: func(v any) (any, error) {
if b, ok := v.([]byte); ok && b != nil {
return ewkbToMssqlGeo(b, false)
}
return v, nil
},
})
case "geography":
plan = append(plan, etl.ColumnTransformPlan{
Index: i,
Fn: func(v any) (any, error) {
if b, ok := v.([]byte); ok && b != nil {
return ewkbToMssqlGeo(b, true)
}
return v, nil
},
})
}
}
return plan
}
func computeStorageTransformationPlan(
ctx context.Context,
azureClient *azure.Client,

View File

@@ -0,0 +1,72 @@
package transformers
import (
"context"
"sync"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
)
type PostgresTransformer struct {
sourceTable config.SourceTableInfo
}
func NewPostgresTransformer(sourceTable config.SourceTableInfo) etl.Transformer {
return &PostgresTransformer{sourceTable: sourceTable}
}
func (pgTr *PostgresTransformer) Consume(
ctx context.Context,
columns []models.ColumnType,
retryConfig config.RetryConfig,
batchSize int,
chBatchesIn <-chan models.Batch,
chBatchesOut chan<- models.Batch,
chJobErrorsOut chan<- custom_errors.JobError,
wgActiveBatches *sync.WaitGroup,
) {
transformationPlan := computePostgresTransformationPlan(columns)
acc := &batchAccumulator{batchSize: batchSize}
for {
select {
case <-ctx.Done():
return
case batch, ok := <-chBatchesIn:
if !ok {
acc.flush(ctx, chBatchesOut, wgActiveBatches)
return
}
if len(transformationPlan) > 0 {
if err := ProcessBatchWithRetries(ctx, &batch, transformationPlan, retryConfig); err != nil {
sendTransformError(ctx, err, chJobErrorsOut)
return
}
}
if batchSize <= 0 {
wgActiveBatches.Add(1)
select {
case chBatchesOut <- batch:
case <-ctx.Done():
wgActiveBatches.Done()
return
}
continue
}
acc.add(batch)
if acc.ready() {
if !acc.flush(ctx, chBatchesOut, wgActiveBatches) {
return
}
}
}
}
}

View File

@@ -4,6 +4,8 @@ import (
"encoding/binary"
"errors"
"time"
mssqlclrgeo "github.com/gaspardle/go-mssqlclrgeo"
)
func mssqlUuidToBigEndian(mssqlUuid []byte) ([]byte, error) {
@@ -62,6 +64,51 @@ func ensureUTC(t time.Time) time.Time {
return time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), time.UTC)
}
func bigEndianToMssqlUuid(pgUuid []byte) ([]byte, error) {
if len(pgUuid) != 16 {
return nil, errors.New("Invalid uuid")
}
mssqlUuid := make([]byte, 16)
mssqlUuid[0], mssqlUuid[1], mssqlUuid[2], mssqlUuid[3] = pgUuid[3], pgUuid[2], pgUuid[1], pgUuid[0]
mssqlUuid[4], mssqlUuid[5] = pgUuid[5], pgUuid[4]
mssqlUuid[6], mssqlUuid[7] = pgUuid[7], pgUuid[6]
copy(mssqlUuid[8:], pgUuid[8:])
return mssqlUuid, nil
}
func ewkbToMssqlGeo(ewkb []byte, isGeography bool) ([]byte, error) {
if len(ewkb) < 5 {
return nil, errors.New("Invalid ewkb")
}
var byteOrder binary.ByteOrder
if ewkb[0] == 0 {
byteOrder = binary.BigEndian
} else {
byteOrder = binary.LittleEndian
}
wkbType := byteOrder.Uint32(ewkb[1:5])
var wkb []byte
if wkbType&sridFlag != 0 {
if len(ewkb) < 9 {
return nil, errors.New("Invalid ewkb: SRID flag set but data too short")
}
clearType := wkbType &^ uint32(sridFlag)
wkb = make([]byte, len(ewkb)-4)
wkb[0] = ewkb[0]
byteOrder.PutUint32(wkb[1:5], clearType)
copy(wkb[5:], ewkb[9:])
} else {
wkb = ewkb
}
return mssqlclrgeo.WkbToUdtGeo(wkb, isGeography)
}
func ToInt64(v any) (int64, bool) {
switch t := v.(type) {
case int: