diff --git a/cmd/go_migrate/main.go b/cmd/go_migrate/main.go index c17e5f1..aeb06a9 100644 --- a/cmd/go_migrate/main.go +++ b/cmd/go_migrate/main.go @@ -9,6 +9,7 @@ import ( "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/azure" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config" dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper" + "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl/extractors" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl/loaders" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl/table_analyzers" @@ -17,6 +18,13 @@ import ( "golang.org/x/sync/errgroup" ) +func newTableAnalyzer(db dbwrapper.DbWrapper) etl.TableAnalyzer { + if db.GetDialect() == "postgres" { + return table_analyzers.NewPostgresTableAnalyzer(db) + } + return table_analyzers.NewMssqlTableAnalyzer(db) +} + func main() { configureLog() checkExpiry() @@ -163,8 +171,8 @@ func processMigrationJobs( chJobs := make(chan config.Job, len(jobs)) var wgJobs sync.WaitGroup - sourceTableAnalyzer := table_analyzers.NewMssqlTableAnalyzer(sourceDb) - targetTableAnalyzer := table_analyzers.NewPostgresTableAnalyzer(targetDb) + sourceTableAnalyzer := newTableAnalyzer(sourceDb) + targetTableAnalyzer := newTableAnalyzer(targetDb) extractor := extractors.NewExtractor(sourceDb) loader := loaders.NewGenericLoader(targetDb) @@ -181,6 +189,7 @@ func processMigrationJobs( azureClient, loader, job, + sourceDb.GetDialect(), targetDb.GetDialect(), ) diff --git a/cmd/go_migrate/process.go b/cmd/go_migrate/process.go index 13c9275..9e7c814 100644 --- a/cmd/go_migrate/process.go +++ b/cmd/go_migrate/process.go @@ -46,9 +46,15 @@ func processMigrationJob( azureClient *azure.Client, loader loaders.GenericLoader, job config.Job, + sourceDbType string, targetDbType string, ) models.JobResult { - transformer := transformers.NewMssqlTransformer(job.ToStorage, job.SourceTable, azureClient) + var transformer etl.Transformer + if sourceDbType == "postgres" { + transformer = transformers.NewPostgresTransformer(job.SourceTable) + } else { + transformer = transformers.NewMssqlTransformer(job.ToStorage, job.SourceTable, azureClient) + } localCtx, cancel := context.WithCancel(ctx) defer cancel() diff --git a/internal/app/etl/table_analyzers/postgres.go b/internal/app/etl/table_analyzers/postgres.go index 194eae4..33f0517 100644 --- a/internal/app/etl/table_analyzers/postgres.go +++ b/internal/app/etl/table_analyzers/postgres.go @@ -2,6 +2,7 @@ package table_analyzers import ( "context" + "fmt" "strings" "time" @@ -9,6 +10,7 @@ import ( dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models" + "github.com/google/uuid" ) type PostgresTableAnalyzer struct { @@ -161,7 +163,30 @@ func (ta *PostgresTableAnalyzer) EstimateTotalRows( ctx context.Context, tableInfo config.TableInfo, ) (int64, error) { - return 0, nil + query := ` +SELECT reltuples::bigint +FROM pg_class +JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace +WHERE pg_namespace.nspname = $1 AND pg_class.relname = $2` + + ctxTimeout, cancel := context.WithTimeout(ctx, 1*time.Minute) + defer cancel() + + var estimate int64 + err := ta.db.QueryRow(ctxTimeout, query, tableInfo.Schema, tableInfo.Table).Scan(&estimate) + if err != nil { + return 0, err + } + + if estimate < 0 { + countQuery := fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."%s"`, tableInfo.Schema, tableInfo.Table) + err = ta.db.QueryRow(ctxTimeout, countQuery).Scan(&estimate) + if err != nil { + return 0, err + } + } + + return estimate, nil } func (ta *PostgresTableAnalyzer) QueryMaxMinFromColumn( @@ -169,7 +194,19 @@ func (ta *PostgresTableAnalyzer) QueryMaxMinFromColumn( tableInfo config.TableInfo, columnName string, ) (etl.MaxMinColumnResult, error) { - return etl.MaxMinColumnResult{}, nil + query := fmt.Sprintf(`SELECT MIN("%s"), MAX("%s") FROM "%s"."%s"`, + columnName, columnName, tableInfo.Schema, tableInfo.Table) + + ctxTimeout, cancel := context.WithTimeout(ctx, 1*time.Minute) + defer cancel() + + result := etl.MaxMinColumnResult{} + err := ta.db.QueryRow(ctxTimeout, query).Scan(&result.Min, &result.Max) + if err != nil { + return etl.MaxMinColumnResult{}, err + } + + return result, nil } func (ta *PostgresTableAnalyzer) CalculatePartitionRanges( @@ -179,5 +216,78 @@ func (ta *PostgresTableAnalyzer) CalculatePartitionRanges( maxPartitions int64, rangeConstraint config.RangeConfig, ) ([]models.Partition, error) { - return []models.Partition{}, nil + whereClause := "" + args := []any{maxPartitions} + + if rangeConstraint.Min != nil || rangeConstraint.Max != nil { + var conditions []string + if rangeConstraint.Min != nil { + minOp := ">" + if rangeConstraint.IsMinInclusive { + minOp = ">=" + } + args = append(args, *rangeConstraint.Min) + conditions = append(conditions, fmt.Sprintf(`"%s" %s $%d`, partitionColumn, minOp, len(args))) + } + if rangeConstraint.Max != nil { + maxOp := "<" + if rangeConstraint.IsMaxInclusive { + maxOp = "<=" + } + args = append(args, *rangeConstraint.Max) + conditions = append(conditions, fmt.Sprintf(`"%s" %s $%d`, partitionColumn, maxOp, len(args))) + } + whereClause = "WHERE " + strings.Join(conditions, " AND ") + } + + query := fmt.Sprintf(` +SELECT MIN("%s") AS lower_limit, MAX("%s") AS upper_limit +FROM ( + SELECT "%s", NTILE($1) OVER (ORDER BY "%s") AS batch_id + FROM "%s"."%s" %s +) AS t +GROUP BY batch_id +ORDER BY batch_id`, + partitionColumn, + partitionColumn, + partitionColumn, + partitionColumn, + tableInfo.Schema, + tableInfo.Table, + whereClause) + + ctxTimeout, cancel := context.WithTimeout(ctx, 1*time.Minute) + defer cancel() + + rows, err := ta.db.Query(ctxTimeout, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + partitions := make([]models.Partition, 0, maxPartitions) + + for rows.Next() { + partition := models.Partition{ + Id: uuid.New(), + HasRange: true, + RetryCounter: 0, + Range: models.PartitionRange{ + IsMinInclusive: true, + IsMaxInclusive: true, + }, + } + + if err := rows.Scan(&partition.Range.Min, &partition.Range.Max); err != nil { + return nil, err + } + + partitions = append(partitions, partition) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return partitions, nil } diff --git a/internal/app/etl/transformers/plan.go b/internal/app/etl/transformers/plan.go index 2cea12b..f03f286 100644 --- a/internal/app/etl/transformers/plan.go +++ b/internal/app/etl/transformers/plan.go @@ -59,6 +59,49 @@ func computeTransformationPlan(columns []models.ColumnType) []etl.ColumnTransfor return plan } +func computePostgresTransformationPlan(columns []models.ColumnType) []etl.ColumnTransformPlan { + var plan []etl.ColumnTransformPlan + + for i, col := range columns { + switch col.SystemType() { + case "uuid": + plan = append(plan, etl.ColumnTransformPlan{ + Index: i, + Fn: func(v any) (any, error) { + if b, ok := v.([]byte); ok && b != nil { + return bigEndianToMssqlUuid(b) + } + return v, nil + }, + }) + + case "geometry": + plan = append(plan, etl.ColumnTransformPlan{ + Index: i, + Fn: func(v any) (any, error) { + if b, ok := v.([]byte); ok && b != nil { + return ewkbToMssqlGeo(b, false) + } + return v, nil + }, + }) + + case "geography": + plan = append(plan, etl.ColumnTransformPlan{ + Index: i, + Fn: func(v any) (any, error) { + if b, ok := v.([]byte); ok && b != nil { + return ewkbToMssqlGeo(b, true) + } + return v, nil + }, + }) + } + } + + return plan +} + func computeStorageTransformationPlan( ctx context.Context, azureClient *azure.Client, diff --git a/internal/app/etl/transformers/postgres.go b/internal/app/etl/transformers/postgres.go new file mode 100644 index 0000000..6e9c5c8 --- /dev/null +++ b/internal/app/etl/transformers/postgres.go @@ -0,0 +1,72 @@ +package transformers + +import ( + "context" + "sync" + + "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config" + "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors" + "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl" + "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models" +) + +type PostgresTransformer struct { + sourceTable config.SourceTableInfo +} + +func NewPostgresTransformer(sourceTable config.SourceTableInfo) etl.Transformer { + return &PostgresTransformer{sourceTable: sourceTable} +} + +func (pgTr *PostgresTransformer) Consume( + ctx context.Context, + columns []models.ColumnType, + retryConfig config.RetryConfig, + batchSize int, + chBatchesIn <-chan models.Batch, + chBatchesOut chan<- models.Batch, + chJobErrorsOut chan<- custom_errors.JobError, + wgActiveBatches *sync.WaitGroup, +) { + transformationPlan := computePostgresTransformationPlan(columns) + + acc := &batchAccumulator{batchSize: batchSize} + + for { + select { + case <-ctx.Done(): + return + + case batch, ok := <-chBatchesIn: + if !ok { + acc.flush(ctx, chBatchesOut, wgActiveBatches) + return + } + + if len(transformationPlan) > 0 { + if err := ProcessBatchWithRetries(ctx, &batch, transformationPlan, retryConfig); err != nil { + sendTransformError(ctx, err, chJobErrorsOut) + return + } + } + + if batchSize <= 0 { + wgActiveBatches.Add(1) + select { + case chBatchesOut <- batch: + case <-ctx.Done(): + wgActiveBatches.Done() + return + } + continue + } + + acc.add(batch) + if acc.ready() { + if !acc.flush(ctx, chBatchesOut, wgActiveBatches) { + return + } + } + } + } +} diff --git a/internal/app/etl/transformers/utils.go b/internal/app/etl/transformers/utils.go index 00b3939..24f4360 100644 --- a/internal/app/etl/transformers/utils.go +++ b/internal/app/etl/transformers/utils.go @@ -4,6 +4,8 @@ import ( "encoding/binary" "errors" "time" + + mssqlclrgeo "github.com/gaspardle/go-mssqlclrgeo" ) func mssqlUuidToBigEndian(mssqlUuid []byte) ([]byte, error) { @@ -62,6 +64,51 @@ func ensureUTC(t time.Time) time.Time { return time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), time.UTC) } +func bigEndianToMssqlUuid(pgUuid []byte) ([]byte, error) { + if len(pgUuid) != 16 { + return nil, errors.New("Invalid uuid") + } + + mssqlUuid := make([]byte, 16) + mssqlUuid[0], mssqlUuid[1], mssqlUuid[2], mssqlUuid[3] = pgUuid[3], pgUuid[2], pgUuid[1], pgUuid[0] + mssqlUuid[4], mssqlUuid[5] = pgUuid[5], pgUuid[4] + mssqlUuid[6], mssqlUuid[7] = pgUuid[7], pgUuid[6] + copy(mssqlUuid[8:], pgUuid[8:]) + + return mssqlUuid, nil +} + +func ewkbToMssqlGeo(ewkb []byte, isGeography bool) ([]byte, error) { + if len(ewkb) < 5 { + return nil, errors.New("Invalid ewkb") + } + + var byteOrder binary.ByteOrder + if ewkb[0] == 0 { + byteOrder = binary.BigEndian + } else { + byteOrder = binary.LittleEndian + } + + wkbType := byteOrder.Uint32(ewkb[1:5]) + + var wkb []byte + if wkbType&sridFlag != 0 { + if len(ewkb) < 9 { + return nil, errors.New("Invalid ewkb: SRID flag set but data too short") + } + clearType := wkbType &^ uint32(sridFlag) + wkb = make([]byte, len(ewkb)-4) + wkb[0] = ewkb[0] + byteOrder.PutUint32(wkb[1:5], clearType) + copy(wkb[5:], ewkb[9:]) + } else { + wkb = ewkb + } + + return mssqlclrgeo.WkbToUdtGeo(wkb, isGeography) +} + func ToInt64(v any) (int64, bool) { switch t := v.(type) { case int: diff --git a/openspec/changes/bidirectional-transforms/plan.md b/openspec/changes/bidirectional-transforms/plan.md new file mode 100644 index 0000000..82bc729 --- /dev/null +++ b/openspec/changes/bidirectional-transforms/plan.md @@ -0,0 +1,137 @@ +# Plan: Bidirectional Transformation Support + +## Goal + +Make the transformation pipeline direction-aware. Currently hardcoded to MSSQL → PG; add support for PG → MSSQL by applying inverse transformations when `SourceDbType == "postgres"`. + +Excluded: `to_storage` Azure blob upload (not reversible). + +--- + +## Hardcoded wiring to fix + +| File | Line | Change | +|---|---|---| +| `cmd/go_migrate/process.go` | 51 | Branch on `SourceDbType`: `"sqlserver"` → `NewMssqlTransformer`, `"postgres"` → `NewPostgresTransformer` | +| `cmd/go_migrate/main.go` | 166–167 | Branch on source/target type for both `TableAnalyzer` selections | + +--- + +## Transformations + +### Forward (MSSQL → PG) — unchanged + +| Column type | Function | File | +|---|---|---| +| `uniqueidentifier` | `mssqlUuidToBigEndian` | `utils.go:9` | +| `geometry`/`geography` | `wkbToEwkbWithSrid` | `utils.go:25` | +| `datetime`/`datetime2` | `ensureUTC` | `utils.go:57` | + +### Inverse (PG → MSSQL) — new + +| PG system type | Action | +|---|---| +| `uuid` | `bigEndianToMssqlUuid`: re-swap bytes [0-3], [4-5], [6-7] | +| `geometry` | `ewkbToMssqlGeo(v, false)`: strip SRID → WKB → `WkbToUdtGeo` | +| `geography` | `ewkbToMssqlGeo(v, true)`: strip SRID → WKB → `WkbToUdtGeo` | +| `timestamp`/`timestamptz` | no-op | + +**Geometry note**: MSSQL rejects plain WKB via bulk protocol. Must use `mssqlclrgeo.WkbToUdtGeo(wkb, isGeography)` (already in go.mod). PG extractor already emits EWKB via `ST_AsEWKB()`. + +--- + +## New utility functions (`transformers/utils.go`) + +### `bigEndianToMssqlUuid(v []byte) []byte` +``` +out[0..3] = v[3,2,1,0] +out[4..5] = v[5,4] +out[6..7] = v[7,6] +out[8..15] = v[8..15] +``` + +### `ewkbToMssqlGeo(ewkb []byte, isGeography bool) ([]byte, error)` +1. Read byte-order flag from `ewkb[0]` +2. Read geometry type word bytes [1..4] +3. If SRID flag (`0x20000000`) is set: strip bytes [5..8], clear flag in type word +4. Call `mssqlclrgeo.WkbToUdtGeo(wkb, isGeography)` + +--- + +## New files + +### `transformers/postgres.go` +```go +func NewPostgresTransformer(...) *Transformer { + // same signature as NewMssqlTransformer + // calls computePostgresTransformationPlan instead + // does NOT call computeStorageTransformationPlan +} +``` + +### `computePostgresTransformationPlan` in `transformers/plan.go` +Iterates `sourceColTypes` (from PG analyzer), applies inverse closures by system type. + +--- + +## PostgreSQL table analyzer stubs to implement (`table_analyzers/postgres.go`) + +Required for PG-as-source partitioned extraction: + +### `EstimateTotalRows` +```sql +SELECT reltuples::bigint FROM pg_class + JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace + WHERE pg_namespace.nspname = $schema AND pg_class.relname = $table +``` +Fallback to `COUNT(*)` if `reltuples < 0`. + +### `QueryMaxMinFromColumn` +```sql +SELECT MIN("col"), MAX("col") FROM "schema"."table" +``` + +### `CalculatePartitionRanges` +Use min/max from above + `rowsPerPartition` to compute boundaries. Mirror the logic from `MssqlTableAnalyzer.CalculatePartitionRanges`. + +--- + +## Test cases + +### TC-1: `bigEndianToMssqlUuid` — round-trip +- Input: run `mssqlUuidToBigEndian` on a known 16-byte MSSQL UUID → produces PG UUID +- Assert: `bigEndianToMssqlUuid(pgUUID)` == original MSSQL UUID bytes +- Also assert nil input → nil output (no panic) + +### TC-2: `bigEndianToMssqlUuid` — known vector +- Input: `[0x6b,0xa7,0xb8,0x10, 0x9d,0xad, 0x11,0xd1, 0x80,0xb4,0x00,0xc0,0x4f,0xd4,0x30,0xc8]` (RFC 4122 nil UUID variant) +- Assert: bytes [0-3] are reversed, [4-5] reversed, [6-7] reversed, [8-15] identical + +### TC-3: `ewkbToMssqlGeo` — geometry round-trip +- Input: generate a polygon via `go-geom` + `wkb.Marshal` → plain WKB +- Forward: run `wkbToEwkbWithSrid` → EWKB +- Inverse: run `ewkbToMssqlGeo(ewkb, false)` → CLR/UDT bytes +- Assert: no error, output is non-empty `[]byte` + +### TC-4: `ewkbToMssqlGeo` — nil input +- Input: nil +- Assert: returns nil, nil (no panic) + +### TC-5: `ewkbToMssqlGeo` — EWKB without SRID flag +- Input: plain WKB (no SRID flag set) +- Assert: function still calls `WkbToUdtGeo` and returns without error + +### TC-6: Transformer factory selection +- Given `SourceDbType == "postgres"` → `NewPostgresTransformer` is selected +- Given `SourceDbType == "sqlserver"` → `NewMssqlTransformer` is selected + +--- + +## Files changed (summary) + +1. `cmd/go_migrate/process.go` — transformer factory branch +2. `cmd/go_migrate/main.go` — analyzer selection branch +3. `internal/app/etl/transformers/utils.go` — 2 new functions +4. `internal/app/etl/transformers/plan.go` — `computePostgresTransformationPlan` +5. `internal/app/etl/transformers/postgres.go` *(new)* +6. `internal/app/etl/table_analyzers/postgres.go` — 3 stub implementations