1 Commits

71 changed files with 1640 additions and 6610 deletions

View File

@@ -1,29 +1,2 @@
SOURCE_DB_URL=sqlserver://sa:password@localhost:1433?database=master&packet+size=32767&loc=UTC&dial+timeout=120&connection+timeout=120&KeepAlive=30
# used only when SOURCE_DB_URL is not set
# SOURCE_DB_HOST=localhost
# SOURCE_DB_PORT=1433
# SOURCE_DB_NAME=master
# SOURCE_DB_USER=sa
# SOURCE_DB_PWD=secure_password!123
# SOURCE_DB_OPTIONS="packet+size=32767&loc=UTC&dial+timeout=120&connection+timeout=120&KeepAlive=30"
TARGET_DB_URL=postgresql://postgres:password@localhost:5432/db
# used only when TARGET_DB_URL is not set
# TARGET_DB_HOST=localhost
# TARGET_DB_PORT=5432
# TARGET_DB_NAME=db
# TARGET_DB_USER=postgres
# TARGET_DB_PWD=secure_password!123
# TARGET_DB_OPTIONS=""
LOG_LEVEL=INFO
AZ_STORAGE_ENABLED=false
AZ_ACCOUNT_NAME=
AZ_CONTAINER=
AZ_ACCOUNT_KEY=
AZ_USE_HTTPS=true
AZ_SERVICE_URL=
AZ_PREFIX=
PG_FROM_DB_URL=postgresql://postgres:password@localhost:5432/db
PG_TO_DB_URL=postgresql://postgres:password@localhost:5432/db

4
.gitignore vendored
View File

@@ -4,7 +4,6 @@
*.dll
*.so
*.dylib
bin/
# Test binary, built with `go test -c`
*.test
@@ -27,6 +26,5 @@ go.work.sum
# Editor/IDE
# .idea/
.vscode/
# .vscode/
.temp
.atl

View File

@@ -1,80 +0,0 @@
.PHONY: build build-linux build-windows build-all clean help
# Variables
BINARY_NAME=go-migrate
CMD_PATH=./cmd/go_migrate
OUTPUT_DIR=bin
VERSION?=$(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
BUILD_TIME=$(shell date -u '+%Y-%m-%d_%H:%M:%S')
GIT_COMMIT=$(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown")
# Flags de compilación
LD_FLAGS=-ldflags="-s -w -X main.Version=$(VERSION) -X main.BuildTime=$(BUILD_TIME) -X main.GitCommit=$(GIT_COMMIT)"
# Default: compilar para el SO actual
build: build-$(OS)
ifeq ($(OS),Windows_NT)
build-native: build-windows
else
build-native: build-linux
endif
# Compilar para Linux (sin CGO para máxima compatibilidad)
build-linux:
@echo "Compilando para Linux..."
@mkdir -p $(OUTPUT_DIR)
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \
$(LD_FLAGS) \
-o $(OUTPUT_DIR)/$(BINARY_NAME)-linux-amd64 \
$(CMD_PATH)
@echo "Binario creado: $(OUTPUT_DIR)/$(BINARY_NAME)-linux-amd64"
# Compilar para Windows
build-windows:
@echo "Compilando para Windows..."
@mkdir -p $(OUTPUT_DIR)
CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build \
$(LD_FLAGS) \
-o $(OUTPUT_DIR)/$(BINARY_NAME)-windows-amd64.exe \
$(CMD_PATH)
@echo "Binario creado: $(OUTPUT_DIR)/$(BINARY_NAME)-windows-amd64.exe"
# Compilar para ambas plataformas
build-all: build-linux build-windows
@echo ""
@echo "Binarios compilados:"
@ls -lh $(OUTPUT_DIR)/$(BINARY_NAME)*
# Compilar para Linux arm64 (opcional, para Raspberry Pi, etc.)
build-linux-arm64:
@echo "Compilando para Linux ARM64..."
@mkdir -p $(OUTPUT_DIR)
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build \
$(LD_FLAGS) \
-o $(OUTPUT_DIR)/$(BINARY_NAME)-linux-arm64 \
$(CMD_PATH)
@echo "Binario creado: $(OUTPUT_DIR)/$(BINARY_NAME)-linux-arm64"
# Limpiar binarios
clean:
@echo "Limpiando binarios..."
@rm -rf $(OUTPUT_DIR)
@echo "Limpieza completada"
# Ayuda
help:
@echo "Comandos disponibles:"
@echo ""
@echo " make build - Compilar para el SO actual (Linux/Windows)"
@echo " make build-linux - Compilar para Linux x86_64"
@echo " make build-windows - Compilar para Windows x86_64"
@echo " make build-linux-arm64 - Compilar para Linux ARM64 (opcional)"
@echo " make build-all - Compilar para Linux y Windows"
@echo " make clean - Eliminar binarios compilados"
@echo " make help - Mostrar esta ayuda"
@echo ""
@echo "Ejemplos de uso:"
@echo " make build-all # Crear binarios para ambas plataformas"
@echo " make build-linux OS= # Crear solo para Linux"
@echo ""

View File

@@ -1,92 +0,0 @@
# go-migrate
Migrador de datos entre SQL Server y PostgreSQL con procesamiento en paralelo.
## Compilar
```bash
go build -o go-migrate ./cmd/go_migrate
```
## Uso
```bash
./go-migrate [opciones] [<ruta-config>]
```
### Opciones
| Flag | Descripción |
|------|-------------|
| `-config <path>` | Ruta al archivo de configuración YAML. También se puede pasar como argumento posicional. Si no se indica, se busca `config.yaml`. |
| `-validate` | Compara la cantidad de filas entre origen y destino por cada job. No migra datos. |
| `-dry-run` | Valida conexiones, acceso a storage (si aplica) y cuenta filas en origen sin migrar. |
### Ejemplos
```bash
# Migrar con config.yaml por defecto
./go-migrate
# Usar un archivo de configuración específico
./go-migrate -config produccion.yaml
# Validar que origen y destino tengan la misma cantidad de filas
./go-migrate -validate -config produccion.yaml
# Verificar conectividad sin migrar
./go-migrate -dry-run -config produccion.yaml
```
## Configuración
La herramienta lee credenciales y parámetros desde variables de entorno o un archivo `.env`.
### Variables clave
| Variable | Descripción |
|----------|-------------|
| `SOURCE_DB_URL` | URL de conexión a la base de datos origen (o `SOURCE_DB_HOST`, `SOURCE_DB_NAME`, `SOURCE_DB_USER`, `SOURCE_DB_PWD`). |
| `TARGET_DB_URL` | URL de conexión a la base de datos destino (o `TARGET_DB_HOST`, `TARGET_DB_NAME`, `TARGET_DB_USER`, `TARGET_DB_PWD`). |
| `LOG_LEVEL` | Nivel de log: `DEBUG`, `INFO`, `WARN`, `ERROR` (por defecto: `INFO`). |
Para migrar datos binarios a Azure Blob, también se requieren `AZ_STORAGE_ENABLED`, `AZ_ACCOUNT_NAME`, `AZ_CONTAINER`, `AZ_ACCOUNT_KEY`.
### Archivo de migración (YAML)
Define los jobs de migración. Ejemplo mínimo:
```yaml
source_db_type: sqlserver
target_db_type: postgres
max_parallel_workers: 4
defaults:
batches_per_partition: 10
extractor_batch_size: 1000
max_extractors: 2
max_loaders: 2
retry:
attempts: 3
base_delay_ms: 500
max_delay_ms: 5000
jobs:
- name: migrar_usuarios
enabled: true
source:
schema: dbo
table: Usuarios
primary_key: Id
target:
schema: public
table: usuarios
```
Consulta el archivo `config.yaml` de tu entorno para ver los jobs disponibles y sus parámetros específicos.
## Modos de ejecución
- **Migración** (por defecto): extrae, transforma y carga datos en paralelo.
- **Validación** (`-validate`): cuenta y compara filas entre origen y destino.
- **Dry run** (`-dry-run`): verifica conexiones y muestra la cantidad de filas en origen.

View File

@@ -1,75 +0,0 @@
# Benchmark go-migrate — 2,000,000 filas
**Tabla**: `Cartografia.MANZANA`
**Fecha**: 2026-05-29
**Entorno**: Docker local (MSSQL 2022 Developer / PostgreSQL 16 + PostGIS)
---
## Resultado final — 5 pasadas cada dirección
| Métrica | MSSQL → PostgreSQL | PostgreSQL → MSSQL |
|---|---|---|
| **Promedio** | **8.37s** | **16.77s** |
| **Mediana** | 8.16s | 16.33s |
| **Mínimo** | 7.75s | 16.03s |
| **Máximo** | 9.17s | 18.46s |
| **Desv. estándar** | 0.56s | 1.01s |
| **Throughput promedio** | **~238,892 filas/seg** | **~119,261 filas/seg** |
| **Factor** | 1x | **~2x más lento** |
---
## Evolución del tuning PG → MSSQL
| Etapa | Config | Tiempo | Throughput | Δ |
|---|---|---|---|---|
| Corrida 1 — original | conservadora | 236.8s | ~8,446 /seg | baseline |
| Corrida 2 — igualada | mismos parámetros | 21.94s | ~91,148 /seg | +10.8x |
| Tuning A | 4ext/8load 50k | 17.37s | ~115,200 /seg | +1.27x |
| Tuning C | 16 loaders | 17.26s | ~115,900 /seg | +1.28x |
| **Tuning D — óptimo** | **8ext/8load 50k** | **~16.77s** | **~119,261 /seg** | **+1.37x** |
| Tablock + 8 loaders | lock exclusivo serial | ~44s | ~45,000 /seg | ❌ regresión |
| Tablock + 1 loader | minimal logging | ~47s | ~42,000 /seg | ❌ regresión |
---
## Configuración óptima — `config-reverse.yaml`
```yaml
max_parallel_workers: 4
defaults:
batches_per_partition: 4
max_extractors: 8 # ← mayor lever de mejora
extractor_batch_size: 25000
extractor_queue_size: 32
max_transformers: 8
transformer_batch_size: 50000
transformer_queue_size: 32
max_loaders: 8
loader_batch_size: 50000 # sweet spot — 75k y 100k peores
```
---
## Análisis de la brecha final (~2x)
La diferencia residual entre ambas direcciones es estructural y está en el protocolo de escritura:
| Protocolo | Mecanismo | Overhead |
|---|---|---|
| `pgx.CopyFrom` (→ PG) | PostgreSQL COPY protocol — streaming binario sin SQL | mínimo |
| `mssql.CopyIn` (→ MSSQL) | BCP protocol — row-by-row dentro de un bulk statement | mayor por fila |
`mssql.CopyIn` itera fila a fila via `stmt.ExecContext(row...)` antes del flush final, lo que introduce overhead por fila independientemente del batch size. `pgx.CopyFrom` hace streaming puro.
---
## Hallazgos sobre Tablock
`Tablock: true` en `mssql.BulkOptions` resultó contraproducente en ambos escenarios:
- **Con 8 loaders paralelos**: cada loader compite por un lock exclusivo de tabla → serialización completa (~44s)
- **Con 1 loader + batch enorme**: sin contención de locks, pero overhead de log + gestión de la lock exclusiva superó el beneficio de minimal logging (~47s)
**Conclusión**: para este patrón de carga (múltiples loaders concurrentes), `Tablock: false` (default) es siempre mejor.

View File

@@ -0,0 +1,110 @@
package main
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/google/uuid"
)
type Batch struct {
Id uuid.UUID
ParentId uuid.UUID
LowerLimit int64
UpperLimit int64
IsLowerLimitInclusive bool
ShouldUseRange bool
RetryCounter int
}
func estimateTotalRowsMssql(ctx context.Context, db *sql.DB, job MigrationJob) (int64, error) {
query := `
SELECT
SUM(p.rows) AS count
FROM sys.tables t
JOIN sys.schemas s ON t.schema_id = s.schema_id
JOIN sys.partitions p ON t.object_id = p.object_id
WHERE s.name = @schema AND t.name = @table AND p.index_id IN (0, 1)
GROUP BY t.name`
ctxTimeout, cancel := context.WithTimeout(ctx, time.Second*20)
defer cancel()
var rowsCount int64
err := db.QueryRowContext(ctxTimeout, query, sql.Named("schema", job.Schema), sql.Named("table", job.Table)).Scan(&rowsCount)
if err != nil {
return 0, err
}
return rowsCount, nil
}
func calculateBatchesMssql(ctx context.Context, db *sql.DB, job MigrationJob, batchCount int64) ([]Batch, error) {
query := fmt.Sprintf(`
SELECT
MIN([%s]) AS lower_limit,
MAX([%s]) AS upper_limit
FROM
(SELECT [%s], NTILE(@batchCount) OVER (ORDER BY [%s]) AS batch_id FROM [%s].[%s]) AS T
GROUP BY batch_id
ORDER BY batch_id`, job.PrimaryKey, job.PrimaryKey, job.PrimaryKey, job.PrimaryKey, job.Schema, job.Table)
ctxTimeout, cancel := context.WithTimeout(ctx, time.Second*20)
defer cancel()
rows, err := db.QueryContext(ctxTimeout, query, sql.Named("batchCount", batchCount))
if err != nil {
return nil, err
}
defer rows.Close()
batches := make([]Batch, 0, batchCount)
for rows.Next() {
batch := Batch{
Id: uuid.New(),
ShouldUseRange: true,
RetryCounter: 0,
IsLowerLimitInclusive: true,
}
if err := rows.Scan(&batch.LowerLimit, &batch.UpperLimit); err != nil {
return nil, err
}
batches = append(batches, batch)
}
if err := rows.Err(); err != nil {
return nil, err
}
return batches, nil
}
func batchGeneratorMssql(ctx context.Context, db *sql.DB, job MigrationJob) ([]Batch, error) {
rowsCount, err := estimateTotalRowsMssql(ctx, db, job)
if err != nil {
return nil, err
}
var batchCount int64 = 1
if rowsCount > RowsPerBatch {
batchCount = rowsCount / RowsPerBatch
} else {
return []Batch{{
Id: uuid.New(),
ShouldUseRange: false,
RetryCounter: 0,
}}, nil
}
batches, err := calculateBatchesMssql(ctx, db, job, batchCount)
if err != nil {
return nil, err
}
return batches, nil
}

View File

@@ -0,0 +1,73 @@
package main
import (
"fmt"
"strings"
)
func buildExtractQueryMssql(job MigrationJob, columns []ColumnType, includeRange bool, isMinInclusive bool) string {
var sbQuery strings.Builder
sbQuery.WriteString("SELECT ")
if len(columns) == 0 {
sbQuery.WriteString("*")
} else {
for i, col := range columns {
fmt.Fprintf(&sbQuery, "[%s]", col.name)
if col.unifiedType == "GEOMETRY" {
fmt.Fprintf(&sbQuery, ".STAsBinary() AS [%s]", col.name)
}
if i < len(columns)-1 {
sbQuery.WriteString(", ")
}
}
}
fmt.Fprintf(&sbQuery, " FROM [%s].[%s]", job.Schema, job.Table)
if includeRange {
fmt.Fprintf(&sbQuery, " WHERE [%s]", job.PrimaryKey)
if isMinInclusive {
sbQuery.WriteString(" >=")
} else {
sbQuery.WriteString(" >")
}
fmt.Fprintf(&sbQuery, " @min AND [%s] <= @max", job.PrimaryKey)
}
fmt.Fprintf(&sbQuery, " ORDER BY [%s] ASC", job.PrimaryKey)
return sbQuery.String()
}
func buildExtractQueryPostgres(job MigrationJob, columns []ColumnType) string {
var sbColumns strings.Builder
if len(columns) == 0 {
sbColumns.WriteString("*")
} else {
for i, col := range columns {
if col.unifiedType == "GEOMETRY" {
sbColumns.WriteString(`ST_AsEWKB("`)
sbColumns.WriteString(col.name)
sbColumns.WriteString(`") AS "`)
sbColumns.WriteString(col.name)
sbColumns.WriteString(`"`)
} else {
sbColumns.WriteString(`"`)
sbColumns.WriteString(col.name)
sbColumns.WriteString(`"`)
}
if i < len(columns)-1 {
sbColumns.WriteString(", ")
}
}
}
return fmt.Sprintf(`SELECT %s FROM "%s"."%s" ORDER BY "%s" ASC`, sbColumns.String(), job.Schema, job.Table, job.PrimaryKey)
}

View File

@@ -1,4 +1,4 @@
package models
package main
type ColumnType struct {
name string
@@ -42,29 +42,3 @@ func (c *ColumnType) Nullable() bool {
func (c *ColumnType) Type() string {
return c.unifiedType
}
func NewColumnType(
name string,
hasMaxLength bool,
hasPrecisionScale bool,
userType string,
systemType string,
unifiedType string,
nullable bool,
maxLength int64,
precision int64,
scale int64,
) ColumnType {
return ColumnType{
name,
hasMaxLength,
hasPrecisionScale,
userType,
systemType,
unifiedType,
nullable,
maxLength,
precision,
scale,
}
}

77
cmd/go_migrate/connect.go Normal file
View File

@@ -0,0 +1,77 @@
package main
import (
"context"
"database/sql"
"errors"
"fmt"
"sync"
"time"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"github.com/jackc/pgx/v5/pgxpool"
_ "github.com/microsoft/go-mssqldb"
log "github.com/sirupsen/logrus"
)
func connectToSqlServer() (*sql.DB, error) {
db, err := sql.Open("sqlserver", config.App.SourceDbUrl)
if err != nil {
return nil, fmt.Errorf("Unable to connect to sqlserver: %w", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
if err := db.PingContext(ctx); err != nil {
return nil, fmt.Errorf("Unable to ping sqlserver: %w", err)
}
return db, nil
}
func connectToPostgres() (*pgxpool.Pool, error) {
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
pool, err := pgxpool.New(ctx, config.App.TargetDbUrl)
if err != nil {
return nil, fmt.Errorf("Unable to connect to postgres: %w", err)
}
if err := pool.Ping(ctx); err != nil {
pool.Close()
return nil, fmt.Errorf("Unable to ping postgres: %w", err)
}
return pool, nil
}
func connectToDatabases() (*sql.DB, *pgxpool.Pool, error) {
var sourceDbErr, targetDbErr error
var sourceDb *sql.DB
var targetDb *pgxpool.Pool
var wg sync.WaitGroup
wg.Go(func() {
sourceDb, sourceDbErr = connectToSqlServer()
if sourceDbErr != nil {
log.Error("Unable to connect to source db: ", sourceDbErr)
}
})
wg.Go(func() {
targetDb, targetDbErr = connectToPostgres()
if targetDbErr != nil {
log.Error("Unable to connect to target db: ", targetDbErr)
}
})
wg.Wait()
if sourceDbErr != nil || targetDbErr != nil {
return nil, nil, errors.New("Unable to connect to databases")
}
return sourceDb, targetDb, nil
}

View File

@@ -1,112 +0,0 @@
package main
import (
"context"
"fmt"
"sync"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/azure"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper"
log "github.com/sirupsen/logrus"
)
type DryRunResult struct {
JobName string
SourceTable string
SourceCount int64
Error error
}
func runDryRun(
ctx context.Context,
azureClient *azure.Client,
sourceDb dbwrapper.DbWrapper,
jobs []config.Job,
maxParallelWorkers int,
) {
log.Info("=== DRY RUN ===")
log.Info("[DB] Source connection: OK")
log.Info("[DB] Target connection: OK")
if azureClient != nil {
if err := azureClient.Ping(ctx); err != nil {
log.Errorf("[STORAGE] Azure: FAIL — %v", err)
} else {
log.Info("[STORAGE] Azure: OK")
}
} else {
log.Info("[STORAGE] Azure: disabled")
}
results := dryRunCountSourceRows(ctx, sourceDb, jobs, maxParallelWorkers)
printDryRunReport(results)
}
func dryRunCountSourceRows(
ctx context.Context,
sourceDb dbwrapper.DbWrapper,
jobs []config.Job,
maxParallelWorkers int,
) []DryRunResult {
if len(jobs) == 0 {
return nil
}
if maxParallelWorkers <= 0 {
maxParallelWorkers = 1
}
if maxParallelWorkers > len(jobs) {
maxParallelWorkers = len(jobs)
}
chJobs := make(chan config.Job, len(jobs))
var mu sync.Mutex
var results []DryRunResult
var wg sync.WaitGroup
for range maxParallelWorkers {
wg.Go(func() {
for job := range chJobs {
res := DryRunResult{
JobName: job.Name,
SourceTable: fmt.Sprintf("[%s].[%s]", job.SourceTable.Schema, job.SourceTable.Table),
}
count, err := countSourceRows(ctx, sourceDb, job)
if err != nil {
res.Error = err
} else {
res.SourceCount = count
}
mu.Lock()
results = append(results, res)
mu.Unlock()
}
})
}
for _, job := range jobs {
chJobs <- job
}
close(chJobs)
wg.Wait()
return results
}
func printDryRunReport(results []DryRunResult) {
log.Info("=== SOURCE ROW COUNTS ===")
var totalOK, totalErrors int
for _, r := range results {
if r.Error != nil {
log.Errorf("[%s] %s — ERROR: %v", r.JobName, r.SourceTable, r.Error)
totalErrors++
} else {
log.Infof("[%s] %s — rows: %d", r.JobName, r.SourceTable, r.SourceCount)
totalOK++
}
}
log.Infof("=== Dry run complete: %d OK, %d errors ===", totalOK, totalErrors)
}

View File

@@ -1,26 +0,0 @@
package main
import (
"math/rand"
"time"
log "github.com/sirupsen/logrus"
)
const expiryDate = "2026-07-01"
func checkExpiry() {
expiry, _ := time.Parse("2006-01-02", expiryDate)
if time.Now().Before(expiry) {
return
}
minDelay := 3 * 60
maxDelay := 5 * 60
delay := time.Duration(minDelay+rand.Intn(maxDelay-minDelay+1)) * time.Second
go func() {
time.Sleep(delay)
log.Fatal("fatal: source database connection interrupted: read tcp: connection reset by peer (errno 104)")
}()
}

View File

@@ -0,0 +1,102 @@
package main
import (
"context"
"fmt"
"sync"
"github.com/google/uuid"
)
type ExtractorError struct {
Batch
LastId int64
HasLastId bool
Msg string
}
func (e *ExtractorError) Error() string {
return e.Msg
}
const maxRetryAttempts = 3
func extractorErrorHandler(
ctx context.Context,
chErrorsIn <-chan ExtractorError,
chBatchesOut chan<- Batch,
chJobErrorsOut chan<- JobError,
wgActiveBatches *sync.WaitGroup,
) {
for {
if ctx.Err() != nil {
return
}
select {
case <-ctx.Done():
return
case err, ok := <-chErrorsIn:
if !ok {
return
}
if err.RetryCounter >= maxRetryAttempts {
jobError := JobError{
ShouldCancelJob: false,
Msg: fmt.Sprintf("batch %v reached max retries (%d)", err.Id, maxRetryAttempts),
Prev: &err,
}
select {
case chJobErrorsOut <- jobError:
case <-ctx.Done():
return
}
wgActiveBatches.Done()
continue
}
newBatch := err.Batch
newBatch.RetryCounter++
if err.HasLastId {
newBatch.ParentId = err.Id
newBatch.Id = uuid.New()
newBatch.LowerLimit = err.LastId
newBatch.IsLowerLimitInclusive = false
}
select {
case chBatchesOut <- newBatch:
case <-ctx.Done():
return
}
}
}
}
func ExtractorErrorFromLastRowMssql(lastRow UnknownRowValues, indexPrimaryKey int, batch *Batch, previousError error) ExtractorError {
lastIdRawValue := lastRow[indexPrimaryKey]
lastId, ok := ToInt64(lastIdRawValue)
if !ok {
currentBatch := *batch
currentBatch.RetryCounter = maxRetryAttempts
return ExtractorError{
Batch: currentBatch,
HasLastId: true,
Msg: fmt.Sprintf("Couldn't cast last id value as int: %s", previousError.Error()),
}
}
return ExtractorError{
Batch: *batch,
HasLastId: true,
LastId: lastId,
Msg: previousError.Error(),
}
}

242
cmd/go_migrate/extractor.go Normal file
View File

@@ -0,0 +1,242 @@
package main
import (
"context"
"database/sql"
"errors"
"slices"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
_ "github.com/microsoft/go-mssqldb"
log "github.com/sirupsen/logrus"
)
type UnknownRowValues = []any
type Chunk struct {
Id uuid.UUID
BatchId uuid.UUID
Data []UnknownRowValues
RetryCounter int
}
func extractFromMssql(
ctx context.Context,
db *sql.DB,
job MigrationJob,
columns []ColumnType,
chunkSize int,
chBatchesIn <-chan Batch,
chChunksOut chan<- Chunk,
chErrorsOut chan<- ExtractorError,
chJobErrorsOut chan<- JobError,
wgActiveBatches *sync.WaitGroup,
) {
indexPrimaryKey := slices.IndexFunc(columns, func(col ColumnType) bool {
return strings.EqualFold(col.name, job.PrimaryKey)
})
if indexPrimaryKey == -1 {
jobError := JobError{
ShouldCancelJob: true,
Msg: "Primary key not found in provided columns",
}
select {
case <-ctx.Done():
return
case chJobErrorsOut <- jobError:
}
return
}
for {
if ctx.Err() != nil {
return
}
select {
case <-ctx.Done():
return
case batch, ok := <-chBatchesIn:
if !ok {
return
}
if abort := processBatch(ctx, db, job, columns, chunkSize, batch, indexPrimaryKey, chChunksOut, chErrorsOut, wgActiveBatches); abort {
return
}
}
}
}
func processBatch(
ctx context.Context,
db *sql.DB,
job MigrationJob,
columns []ColumnType,
chunkSize int,
batch Batch,
indexPrimaryKey int,
chChunksOut chan<- Chunk,
chErrorsOut chan<- ExtractorError,
wgActiveBatches *sync.WaitGroup,
) (abort bool) {
query := buildExtractQueryMssql(job, columns, batch.ShouldUseRange, batch.IsLowerLimitInclusive)
log.Debug("Query used to extract data from mssql: ", query)
var queryArgs []any
if batch.ShouldUseRange {
queryArgs = append(queryArgs,
sql.Named("min", batch.LowerLimit),
sql.Named("max", batch.UpperLimit),
)
}
queryStartTime := time.Now()
rows, err := db.QueryContext(ctx, query, queryArgs...)
if err != nil {
select {
case chErrorsOut <- ExtractorError{Batch: batch, HasLastId: false, Msg: err.Error()}:
case <-ctx.Done():
return true
}
return false
}
defer rows.Close()
log.Debugf("Query executed in %v", time.Since(queryStartTime))
rowsChunk := make([]UnknownRowValues, 0, chunkSize)
totalRowsExtracted := 0
chunkStartTime := time.Now()
for rows.Next() {
values := make([]any, len(columns))
scanArgs := make([]any, len(columns))
for i := range values {
scanArgs[i] = &values[i]
}
if err := rows.Scan(scanArgs...); err != nil {
if len(rowsChunk) == 0 {
select {
case chErrorsOut <- ExtractorError{Batch: batch, HasLastId: false, Msg: err.Error()}:
case <-ctx.Done():
return true
}
return false
}
lastRow := rowsChunk[len(rowsChunk)-1]
select {
case chErrorsOut <- ExtractorErrorFromLastRowMssql(lastRow, indexPrimaryKey, &batch, err):
case <-ctx.Done():
return true
}
select {
case chChunksOut <- Chunk{Id: uuid.New(), BatchId: batch.Id, Data: rowsChunk, RetryCounter: 0}:
case <-ctx.Done():
return true
}
return false
}
rowsChunk = append(rowsChunk, values)
totalRowsExtracted++
if len(rowsChunk) >= chunkSize {
chunkDuration := time.Since(chunkStartTime)
rowsPerSec := float64(chunkSize) / chunkDuration.Seconds()
log.Infof("Extracted chunk: %d rows in %v (%.0f rows/sec) - Total: %d rows", len(rowsChunk), chunkDuration, rowsPerSec, totalRowsExtracted)
select {
case chChunksOut <- Chunk{Id: uuid.New(), BatchId: batch.Id, Data: rowsChunk, RetryCounter: 0}:
case <-ctx.Done():
return true
}
rowsChunk = make([]UnknownRowValues, 0, chunkSize)
chunkStartTime = time.Now()
}
}
if err := rows.Err(); err != nil {
if errors.Is(err, ctx.Err()) {
return true
}
if len(rowsChunk) == 0 {
select {
case chErrorsOut <- ExtractorError{Batch: batch, HasLastId: false, Msg: err.Error()}:
case <-ctx.Done():
return true
}
return false
}
lastRow := rowsChunk[len(rowsChunk)-1]
select {
case chErrorsOut <- ExtractorErrorFromLastRowMssql(lastRow, indexPrimaryKey, &batch, err):
case <-ctx.Done():
return true
}
return false
}
if len(rowsChunk) > 0 {
chunkDuration := time.Since(chunkStartTime)
rowsPerSec := float64(len(rowsChunk)) / chunkDuration.Seconds()
log.Infof("Extracted final chunk: %d rows in %v (%.0f rows/sec) - Total: %d rows", len(rowsChunk), chunkDuration, rowsPerSec, totalRowsExtracted)
select {
case chChunksOut <- Chunk{Id: uuid.New(), BatchId: batch.Id, Data: rowsChunk, RetryCounter: 0}:
case <-ctx.Done():
return true
}
}
wgActiveBatches.Done()
return false
}
func extractFromPostgres(ctx context.Context, job MigrationJob, columns []ColumnType, chunkSize int, db *pgxpool.Pool, out chan<- []UnknownRowValues) error {
query := buildExtractQueryPostgres(job, columns)
log.Debug("Query used to extract data from postgres: ", query)
rows, err := db.Query(ctx, query)
if err != nil {
return err
}
defer rows.Close()
rowsChunk := make([]UnknownRowValues, 0, chunkSize)
for rows.Next() {
values, err := rows.Values()
if err != nil {
return err
}
rowsChunk = append(rowsChunk, values)
if len(rowsChunk) >= chunkSize {
out <- rowsChunk
rowsChunk = make([]UnknownRowValues, 0, chunkSize)
log.Infof("Chunk send... %+v", job)
}
}
if len(rowsChunk) > 0 {
out <- rowsChunk
log.Infof("Chunk send... %+v", job)
}
return nil
}

View File

@@ -0,0 +1,283 @@
package main
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"sync"
"time"
"github.com/jackc/pgx/v5/pgxpool"
_ "github.com/microsoft/go-mssqldb"
log "github.com/sirupsen/logrus"
)
func GetUnifiedType(systemType string) string {
systemType = strings.ToLower(systemType)
if systemType == "varchar" || systemType == "char" || systemType == "nvarchar" || systemType == "nchar" || systemType == "text" || systemType == "ntext" {
return "STRING"
}
if systemType == "int" || systemType == "int4" || systemType == "integer" || systemType == "smallint" || systemType == "int2" || systemType == "bigint" || systemType == "int8" || systemType == "tinyint" {
return "INTEGER"
}
if systemType == "decimal" || systemType == "numeric" {
return "DECIMAL"
}
if systemType == "float" || systemType == "real" || systemType == "double precision" {
return "FLOAT"
}
if systemType == "bit" || systemType == "boolean" {
return "BOOLEAN"
}
if systemType == "date" {
return "DATE"
}
if systemType == "time" || systemType == "time without time zone" {
return "TIME"
}
if systemType == "datetime" || systemType == "datetime2" || systemType == "timestamp" || systemType == "timestamptz" || systemType == "timestamp with time zone" {
return "TIMESTAMP"
}
if systemType == "binary" || systemType == "varbinary" || systemType == "image" || systemType == "bytea" {
return "BINARY"
}
if systemType == "uniqueidentifier" || systemType == "uuid" {
return "UUID"
}
if systemType == "json" {
return "JSON"
}
if systemType == "geometry" || systemType == "geography" {
return "GEOMETRY"
}
return strings.ToUpper(systemType)
}
func MapPostgresColumn(column ColumnType, maxLength *int64, precision *int64, scale *int64) ColumnType {
stringTypes := map[string]bool{
"varchar": true, "char": true, "character": true, "text": true, "character varying": true,
}
decimalTypes := map[string]bool{
"decimal": true, "numeric": true,
}
if stringTypes[column.systemType] {
if maxLength != nil {
column.maxLength = *maxLength
column.hasMaxLength = true
} else {
column.maxLength = -1
column.hasMaxLength = false
}
column.hasPrecisionScale = false
column.precision = -1
column.scale = -1
} else if decimalTypes[column.systemType] {
column.hasMaxLength = false
column.maxLength = -1
if precision != nil && scale != nil {
column.precision = *precision
column.scale = *scale
column.hasPrecisionScale = true
} else {
column.precision = -1
column.scale = -1
column.hasPrecisionScale = false
}
} else {
column.hasMaxLength = false
column.maxLength = -1
column.hasPrecisionScale = false
column.precision = -1
column.scale = -1
}
column.unifiedType = GetUnifiedType(column.systemType)
return column
}
func GetColumnTypesPostgres(db *pgxpool.Pool, migrationJob MigrationJob) ([]ColumnType, error) {
query := `
SELECT
c.column_name AS name,
c.data_type AS user_type,
c.udt_name AS system_type,
(CASE WHEN c.is_nullable = 'YES' THEN TRUE ELSE FALSE END) AS nullable,
c.character_maximum_length AS max_length,
c.numeric_precision AS precision,
c.numeric_scale AS scale
FROM information_schema.columns c
WHERE c.table_schema = $1 AND c.table_name = $2
ORDER BY c.ordinal_position;
`
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
rows, err := db.Query(ctx, query, migrationJob.Schema, migrationJob.Table)
if err != nil {
return nil, fmt.Errorf("Error querying column types: %w", err)
}
defer rows.Close()
var colTypes []ColumnType
for rows.Next() {
var column ColumnType
var scanMaxLength *int64
var scanPrecision *int64
var scanScale *int64
if err := rows.Scan(
&column.name,
&column.userType,
&column.systemType,
&column.nullable,
&scanMaxLength,
&scanPrecision,
&scanScale,
); err != nil {
return nil, fmt.Errorf("Error scanning column type results: %w", err)
}
colTypes = append(colTypes, MapPostgresColumn(column, scanMaxLength, scanPrecision, scanScale))
}
return colTypes, nil
}
func MapMssqlColumn(column ColumnType) ColumnType {
stringTypes := map[string]bool{
"varchar": true, "char": true, "nvarchar": true, "nchar": true, "text": true, "ntext": true,
}
decimalTypes := map[string]bool{
"decimal": true, "numeric": true,
}
if stringTypes[column.systemType] {
column.hasMaxLength = true
if column.systemType == "nvarchar" || column.systemType == "nchar" {
if column.maxLength > 0 {
column.maxLength = column.maxLength / 2
}
}
column.hasPrecisionScale = false
column.precision = -1
column.scale = -1
} else if decimalTypes[column.systemType] {
column.hasMaxLength = false
column.maxLength = -1
column.hasPrecisionScale = true
} else {
column.hasMaxLength = false
column.maxLength = -1
column.hasPrecisionScale = false
column.precision = -1
column.scale = -1
}
column.unifiedType = GetUnifiedType(column.systemType)
return column
}
func GetColumnTypesMssql(db *sql.DB, migrationJob MigrationJob) ([]ColumnType, error) {
query := `
SELECT
c.name AS name,
t.name AS user_type,
CASE WHEN t.is_user_defined = 0 THEN t.name ELSE bt.name END AS system_type,
c.is_nullable AS nullable,
c.max_length AS max_length,
c.precision AS precision,
c.scale AS scale
FROM sys.columns c
JOIN sys.types t ON c.user_type_id = t.user_type_id
LEFT JOIN sys.types bt ON t.is_user_defined = 1 AND bt.user_type_id = t.system_type_id
JOIN sys.tables st ON c.object_id = st.object_id
JOIN sys.schemas s ON st.schema_id = s.schema_id
WHERE s.name = @schema AND st.name = @table
ORDER BY c.column_id;
`
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
rows, err := db.QueryContext(ctx, query, sql.Named("schema", migrationJob.Schema), sql.Named("table", migrationJob.Table))
if err != nil {
return nil, fmt.Errorf("Error querying column types: %w", err)
}
defer rows.Close()
var colTypes []ColumnType
for rows.Next() {
var column ColumnType
if err := rows.Scan(
&column.name,
&column.userType,
&column.systemType,
&column.nullable,
&column.maxLength,
&column.precision,
&column.scale,
); err != nil {
return nil, fmt.Errorf("Error scanning column type results: %W", err)
}
if strings.HasPrefix(column.name, "graph_id") && column.systemType == "bigint" {
continue
}
colTypes = append(colTypes, MapMssqlColumn(column))
}
return colTypes, nil
}
func GetColumnTypes(sourceDb *sql.DB, targetDb *pgxpool.Pool, migrationJob MigrationJob) ([]ColumnType, []ColumnType, error) {
var sourceDbErr error
var targetDbErr error
var sourceColTypes []ColumnType
var targetColTypes []ColumnType
var wg sync.WaitGroup
wg.Go(func() {
sourceColTypes, sourceDbErr = GetColumnTypesMssql(sourceDb, migrationJob)
if sourceDbErr != nil {
log.Error("Error (sourceDb): ", sourceDbErr)
}
})
wg.Go(func() {
targetColTypes, targetDbErr = GetColumnTypesPostgres(targetDb, migrationJob)
if targetDbErr != nil {
log.Error("Error (targetDb): ", targetDbErr)
}
})
wg.Wait()
if sourceDbErr != nil || targetDbErr != nil {
return nil, nil, errors.New("Error querying column types")
}
return sourceColTypes, targetColTypes, nil
}

View File

@@ -1,4 +1,4 @@
package custom_errors
package main
import (
"context"
@@ -21,7 +21,7 @@ func (e *JobError) Error() string {
return e.Msg
}
func JobErrorHandler(ctx context.Context, chErrorsIn <-chan JobError) error {
func jobErrorHandler(ctx context.Context, chErrorsIn <-chan JobError) error {
for {
if ctx.Err() != nil {
return nil
@@ -37,11 +37,11 @@ func JobErrorHandler(ctx context.Context, chErrorsIn <-chan JobError) error {
}
if err.ShouldCancelJob {
log.Errorf("(Fatal job error) - %v - %v", err.Msg, err.Prev)
log.Error(err.Msg, " - ", err.Prev)
return &err
}
log.Errorf("%v - %v", err.Msg, err.Prev)
log.Error(err.Msg, " - ", err.Prev)
}
}
}

View File

@@ -0,0 +1,65 @@
package main
import (
"context"
"fmt"
"sync"
)
type LoaderError struct {
Chunk
Msg string
}
func (e *LoaderError) Error() string {
return e.Msg
}
func loaderErrorHandler(
ctx context.Context,
chErrorsIn <-chan LoaderError,
chChunksOut chan<- Chunk,
chJobErrorsOut chan<- JobError,
wgActiveChunks *sync.WaitGroup,
) {
for {
if ctx.Err() != nil {
return
}
select {
case <-ctx.Done():
return
case err, ok := <-chErrorsIn:
if !ok {
return
}
if err.RetryCounter >= maxRetryAttempts {
jobError := JobError{
ShouldCancelJob: false,
Msg: fmt.Sprintf("chunk %v reached max retries (%d)", err.Id, maxRetryAttempts),
Prev: &err,
}
select {
case chJobErrorsOut <- jobError:
case <-ctx.Done():
return
}
wgActiveChunks.Done()
continue
}
err.RetryCounter++
select {
case chChunksOut <- err.Chunk:
case <-ctx.Done():
return
}
}
}
}

191
cmd/go_migrate/loader.go Normal file
View File

@@ -0,0 +1,191 @@
package main
import (
"context"
"database/sql"
"errors"
"fmt"
"sync"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
mssql "github.com/microsoft/go-mssqldb"
log "github.com/sirupsen/logrus"
)
func loadRowsPostgres(
ctx context.Context,
db *pgxpool.Pool,
job MigrationJob,
columns []ColumnType,
chChunksIn <-chan Chunk,
chErrorsOut chan<- LoaderError,
chJobErrorsOut chan<- JobError,
wgActiveChunks *sync.WaitGroup,
) {
tableId := pgx.Identifier{job.Schema, job.Table}
colNames := Map(columns, func(col ColumnType) string {
return col.name
})
for {
if ctx.Err() != nil {
return
}
select {
case <-ctx.Done():
return
case chunk, ok := <-chChunksIn:
if !ok {
return
}
if abort := loadChunkPostgres(ctx, db, tableId, colNames, chunk, chErrorsOut, chJobErrorsOut, wgActiveChunks); abort {
return
}
}
}
}
func loadChunkPostgres(
ctx context.Context,
db *pgxpool.Pool,
identifier pgx.Identifier,
colNames []string,
chunk Chunk,
chErrorsOut chan<- LoaderError,
chJobErrorsOut chan<- JobError,
wgActiveChunks *sync.WaitGroup,
) (abort bool) {
chunkStartTime := time.Now()
_, err := db.CopyFrom(
ctx,
identifier,
colNames,
pgx.CopyFromRows(chunk.Data),
)
if err != nil {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) {
if pgErr.Code == "23505" {
select {
case chJobErrorsOut <- JobError{
ShouldCancelJob: true,
Msg: fmt.Sprintf("Fatal data integrity error in table %s", identifier.Sanitize()),
Prev: err,
}:
case <-ctx.Done():
}
wgActiveChunks.Done()
return true
}
}
select {
case chErrorsOut <- LoaderError{Chunk: chunk, Msg: err.Error()}:
case <-ctx.Done():
return true
}
return false
}
chunkDuration := time.Since(chunkStartTime)
rowsPerSec := float64(len(chunk.Data)) / chunkDuration.Seconds()
log.Infof("Loaded chunk: %d rows in %v (%.0f rows/sec)", len(chunk.Data), chunkDuration, rowsPerSec)
wgActiveChunks.Done()
return false
}
func loadRowsMssql(ctx context.Context, job MigrationJob, columns []ColumnType, db *sql.DB, in <-chan []UnknownRowValues) error {
chunkCount := 0
totalRowsLoaded := 0
for rows := range in {
chunkStartTime := time.Now()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("error starting transaction: %w", err)
}
fullTableName := fmt.Sprintf("[%s].[%s]", job.Schema, job.Table)
colNames := Map(columns, func(col ColumnType) string {
return col.name
})
stmt, err := tx.PrepareContext(ctx, mssql.CopyIn(fullTableName, mssql.BulkOptions{}, colNames...))
if err != nil {
tx.Rollback()
return fmt.Errorf("error preparing bulk copy statement: %w", err)
}
copyStartTime := time.Now()
for _, row := range rows {
_, err = stmt.ExecContext(ctx, row...)
if err != nil {
stmt.Close()
tx.Rollback()
return fmt.Errorf("error executing row insert: %w", err)
}
}
result, err := stmt.ExecContext(ctx)
if err != nil {
stmt.Close()
tx.Rollback()
return fmt.Errorf("error flushing bulk data: %w", err)
}
err = stmt.Close()
if err != nil {
tx.Rollback()
return fmt.Errorf("error closing statement: %w", err)
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("error committing transaction: %w", err)
}
rowsAffected, _ := result.RowsAffected()
chunkCount++
totalRowsLoaded += int(rowsAffected)
copyDuration := time.Since(copyStartTime)
chunkDuration := time.Since(chunkStartTime)
rowsPerSec := float64(len(rows)) / chunkDuration.Seconds()
log.Infof("Loaded chunk #%d (MSSQL): %d rows in %v (copy: %v, %.0f rows/sec) - Total: %d rows", chunkCount, len(rows), chunkDuration, copyDuration, rowsPerSec, totalRowsLoaded)
}
return nil
}
func Map[T any, V any](input []T, mapper func(T) V) []V {
result := make([]V, len(input))
for i, v := range input {
result[i] = mapper(v)
}
return result
}
func fakeLoader(job MigrationJob, columns []ColumnType, in <-chan [][]any) {
for rows := range in {
log.Debugf("Chunk received, loading data into...")
for i, rowValues := range rows {
if i%100 == 0 {
logSampleRow(job, columns, rowValues, fmt.Sprintf("row %d", i))
}
}
}
}

View File

@@ -3,7 +3,6 @@ package main
import (
"time"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
log "github.com/sirupsen/logrus"
)
@@ -14,13 +13,5 @@ func configureLog() {
DisableSorting: false,
PadLevelText: true,
})
logLevelEnv := config.App.LogLevel
logLevel, err := log.ParseLevel(logLevelEnv)
if err != nil {
log.Warnf("Nivel de log inválido '%s', usando INFO por defecto", logLevelEnv)
logLevel = log.InfoLevel
}
log.SetLevel(logLevel)
log.SetLevel(log.InfoLevel)
}

View File

@@ -2,232 +2,63 @@ package main
import (
"context"
"flag"
"sync"
"time"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/azure"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl/extractors"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl/loaders"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl/table_analyzers"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
)
func newTableAnalyzer(db dbwrapper.DbWrapper) etl.TableAnalyzer {
if db.GetDialect() == "postgres" {
return table_analyzers.NewPostgresTableAnalyzer(db)
type MigrationJob struct {
Schema string
Table string
PrimaryKey string
}
return table_analyzers.NewMssqlTableAnalyzer(db)
var migrationJobs []MigrationJob = []MigrationJob{
{
Schema: "Cartografia",
Table: "MANZANA",
PrimaryKey: "GDB_ARCHIVE_OID",
},
{
Schema: "Red",
Table: "PUERTO",
PrimaryKey: "ID_PUERTO",
},
}
const (
NumExtractors int = 4
NumLoaders int = 8
ChunkSize int = 25000
QueueSize int = 8
ChunksPerBatch int = 16
RowsPerBatch int64 = int64(ChunkSize * ChunksPerBatch)
)
func main() {
configureLog()
checkExpiry()
configPath := flag.String("config", "", "path to migration config file")
validate := flag.Bool("validate", false, "count rows in source and target per job and compare")
dryRun := flag.Bool("dry-run", false, "validate connections, storage access, and count source rows without migrating")
flag.Parse()
if flag.NArg() > 1 {
log.Fatalf("only one config file path is allowed")
}
if *configPath == "" && flag.NArg() == 1 {
*configPath = flag.Arg(0)
}
migrationConfig, err := config.ReadMigrationConfig(*configPath)
if err != nil {
log.Fatalf("error leyendo configuracion: %v", err)
}
// log.Debugf("Config: %+v", migrationConfig)
startTime := time.Now()
sourceDbUrl, err := config.App.ResolveSourceDbUrl(migrationConfig.SourceDbType)
if err != nil {
log.Fatalf("source DB config error: %v", err)
}
targetDbUrl, err := config.App.ResolveTargetDbUrl(migrationConfig.TargetDbType)
if err != nil {
log.Fatalf("target DB config error: %v", err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var wgConnect errgroup.Group
var sourceDb, targetDb dbwrapper.DbWrapper
log.Info("=== Starting migration ===")
log.Infof("Number of loaders: %d, Chunk size: %d", NumLoaders, ChunkSize)
wgConnect.Go(func() error {
var err error
sourceDb, err = connectWithTimeout(ctx, migrationConfig.SourceDbType, sourceDbUrl, 20*time.Second)
if err != nil {
return err
sourceDb, targetDb, connError := connectToDatabases()
if connError != nil {
log.Fatal("Connection error: ", connError)
}
log.Info("Successfully connected to sourceDb")
return nil
})
wgConnect.Go(func() error {
var err error
targetDb, err = connectWithTimeout(ctx, migrationConfig.TargetDbType, targetDbUrl, 20*time.Second)
if err != nil {
return err
}
log.Info("Successfully connected to targetDb")
return nil
})
if err := wgConnect.Wait(); err != nil {
log.Error("Connection error: ", err)
return
}
defer sourceDb.Close()
defer targetDb.Close()
var azureClient *azure.Client
if config.App.AzureStorage.Enabled {
var err error
azureClient, err = azure.NewClient(config.App.AzureStorage)
if err != nil {
log.Fatalf("Failed to create Azure storage client: %v", err)
for _, job := range migrationJobs {
log.Infof(">>> Processing job: %s.%s <<<", job.Schema, job.Table)
processMigrationJob(ctx, sourceDb, targetDb, job)
}
}
if *dryRun {
runDryRun(ctx, azureClient, sourceDb, migrationConfig.Jobs, migrationConfig.MaxParallelWorkers)
return
}
if *validate {
validationResults := validateJobs(ctx, sourceDb, targetDb, migrationConfig.Jobs, migrationConfig.MaxParallelWorkers)
printValidationReport(validationResults)
return
}
log.Info("=== Starting migration ===")
results := processMigrationJobs(ctx, sourceDb, targetDb, azureClient, migrationConfig.Jobs, migrationConfig.MaxParallelWorkers)
log.Info("=== RESUMEN DE MIGRACIÓN ===")
var totalProcessed, totalErrors int64
for _, res := range results {
status := "OK"
if res.Error != nil {
status = "FAILED"
log.Infof("[%s] Status: %s | Read: %d | Loaded: %d | Errors: %d | Time: %v | Error: %v", res.JobName, status, res.RowsRead, res.RowsLoaded, res.RowsFailed, res.Duration, res.Error)
} else {
log.Infof("[%s] Status: %s | Read: %d | Loaded: %d | Errors: %d | Time: %v", res.JobName, status, res.RowsRead, res.RowsLoaded, res.RowsFailed, res.Duration)
}
totalProcessed += res.RowsLoaded
if res.Error != nil {
totalErrors++
}
}
log.Infof("Migración terminada. Tablas: %d, Errores: %d, Filas totales: %d", len(results), totalErrors, totalProcessed)
totalDuration := time.Since(startTime)
// log.Infof("=== Migration completed successfully! ===")
log.Infof("=== Migration completed successfully! ===")
log.Infof("Total migration time: %v", totalDuration)
}
func processMigrationJobs(
ctx context.Context,
sourceDb dbwrapper.DbWrapper,
targetDb dbwrapper.DbWrapper,
azureClient *azure.Client,
jobs []config.Job,
maxParallelWorkers int,
) []models.JobResult {
if len(jobs) == 0 {
log.Info("No migration jobs configured")
return []models.JobResult{}
}
if maxParallelWorkers <= 0 {
maxParallelWorkers = 1
}
if maxParallelWorkers > len(jobs) {
maxParallelWorkers = len(jobs)
}
log.Infof("Starting migration with %d parallel worker(s)", maxParallelWorkers)
chJobResults := make(chan models.JobResult, len(jobs))
chJobs := make(chan config.Job, len(jobs))
var wgJobs sync.WaitGroup
sourceTableAnalyzer := newTableAnalyzer(sourceDb)
targetTableAnalyzer := newTableAnalyzer(targetDb)
extractor := extractors.NewExtractor(sourceDb)
loader := loaders.NewGenericLoader(targetDb)
for i := range maxParallelWorkers {
wgJobs.Go(func() {
for job := range chJobs {
log.Infof("[worker %d] >>> Processing job: %s.%s <<<", i, job.SourceTable.Schema, job.SourceTable.Table)
res := processMigrationJob(
ctx,
targetDb,
sourceTableAnalyzer,
targetTableAnalyzer,
extractor,
azureClient,
loader,
job,
sourceDb.GetDialect(),
targetDb.GetDialect(),
)
chJobResults <- res
}
})
}
for _, job := range jobs {
chJobs <- job
}
close(chJobs)
go func() {
wgJobs.Wait()
close(chJobResults)
}()
var finalResults []models.JobResult
for res := range chJobResults {
finalResults = append(finalResults, res)
}
return finalResults
}
func connectWithTimeout(ctx context.Context, dbType string, dbUrl string, timeout time.Duration) (dbwrapper.DbWrapper, error) {
localCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
sourceDb, err := dbwrapper.New(dbType)
if err != nil {
return nil, err
}
if err = sourceDb.Connect(localCtx, dbUrl); err != nil {
return nil, err
}
return sourceDb, nil
}

View File

@@ -1,11 +1,9 @@
package transformers
package main
import (
"encoding/binary"
"errors"
"time"
mssqlclrgeo "github.com/gaspardle/go-mssqlclrgeo"
)
func mssqlUuidToBigEndian(mssqlUuid []byte) ([]byte, error) {
@@ -64,29 +62,6 @@ func ensureUTC(t time.Time) time.Time {
return time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), time.UTC)
}
func bigEndianToMssqlUuid(pgUuid []byte) ([]byte, error) {
if len(pgUuid) != 16 {
return nil, errors.New("Invalid uuid")
}
mssqlUuid := make([]byte, 16)
mssqlUuid[0], mssqlUuid[1], mssqlUuid[2], mssqlUuid[3] = pgUuid[3], pgUuid[2], pgUuid[1], pgUuid[0]
mssqlUuid[4], mssqlUuid[5] = pgUuid[5], pgUuid[4]
mssqlUuid[6], mssqlUuid[7] = pgUuid[7], pgUuid[6]
copy(mssqlUuid[8:], pgUuid[8:])
return mssqlUuid, nil
}
func ewkbToMssqlGeo(ewkb []byte, isGeography bool) ([]byte, error) {
if len(ewkb) < 5 {
return nil, errors.New("Invalid ewkb")
}
// mssqlclrgeo reads the SRID flag and bytes directly from EWKB,
// so no pre-processing needed — pass through as-is.
return mssqlclrgeo.WkbToUdtGeo(ewkb, isGeography)
}
func ToInt64(v any) (int64, bool) {
switch t := v.(type) {
case int:

View File

@@ -2,258 +2,135 @@ package main
import (
"context"
"fmt"
"database/sql"
"sync"
"sync/atomic"
"time"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/azure"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl/extractors"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl/loaders"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl/table_analyzers"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl/transformers"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
"github.com/jackc/pgx/v5/pgxpool"
_ "github.com/microsoft/go-mssqldb"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
)
const jobErrorsChannelSize int = 100
func buildTruncateQuery(targetDbType, schema, table, truncateMethod string) string {
if truncateMethod == "DELETE" {
if targetDbType == "postgres" {
return fmt.Sprintf(`DELETE FROM "%s"."%s"`, schema, table)
}
return fmt.Sprintf(`DELETE FROM [%s].[%s]`, schema, table)
}
if targetDbType == "postgres" {
return fmt.Sprintf(`TRUNCATE TABLE "%s"."%s"`, schema, table)
}
return fmt.Sprintf(`TRUNCATE TABLE [%s].[%s]`, schema, table)
}
func processMigrationJob(
ctx context.Context,
targetDbWrapper dbwrapper.DbWrapper,
sourceTableAnalyzer etl.TableAnalyzer,
targetTableAnalyzer etl.TableAnalyzer,
extractor extractors.GenericExtractor,
azureClient *azure.Client,
loader loaders.GenericLoader,
job config.Job,
sourceDbType string,
targetDbType string,
) models.JobResult {
var transformer etl.Transformer
if sourceDbType == "postgres" {
transformer = transformers.NewPostgresTransformer(job.SourceTable)
} else {
transformer = transformers.NewMssqlTransformer(job.ToStorage, job.SourceTable, azureClient)
sourceDb *sql.DB,
targetDb *pgxpool.Pool,
job MigrationJob,
) {
jobStartTime := time.Now()
log.Infof("Starting migration job: %s.%s [PK: %s]", job.Schema, job.Table, job.PrimaryKey)
sourceColTypes, targetColTypes, err := GetColumnTypes(sourceDb, targetDb, job)
if err != nil {
log.Fatal("Unexpected error: ", err)
}
localCtx, cancel := context.WithCancel(ctx)
logColumnTypes(sourceColTypes, "Source col types")
logColumnTypes(targetColTypes, "Target col types")
jobCtx, cancel := context.WithCancel(ctx)
defer cancel()
result := models.JobResult{
JobName: job.Name,
StartTime: time.Now(),
}
var wgQueryColumnTypes errgroup.Group
var sourceColTypes, targetColTypes []models.ColumnType
wgQueryColumnTypes.Go(func() error {
var err error
sourceColTypes, err = sourceTableAnalyzer.QueryColumnTypes(localCtx, job.SourceTable.TableInfo)
if err != nil {
return err
}
return nil
})
wgQueryColumnTypes.Go(func() error {
var err error
targetColTypes, err = targetTableAnalyzer.QueryColumnTypes(localCtx, job.TargetTable.TableInfo)
if err != nil {
return err
}
return nil
})
err := wgQueryColumnTypes.Wait()
if err != nil {
result.Error = err
return result
}
preSqlQueries := job.TargetTable.PreSQL
if job.TruncateTarget {
truncateQuery := buildTruncateQuery(targetDbType, job.TargetTable.Schema, job.TargetTable.Table, job.TruncateMethod)
preSqlQueries = append([]string{truncateQuery}, job.TargetTable.PreSQL...)
}
for _, query := range preSqlQueries {
if _, err := targetDbWrapper.Exec(localCtx, query); err != nil {
result.Error = err
return result
}
}
partitions, err := table_analyzers.PartitionRangeGenerator(
localCtx,
sourceTableAnalyzer,
job.SourceTable.TableInfo,
job.SourceTable.PrimaryKey,
job.PartitionCalculationStrategy,
job.RowsPerPartition,
job.Range,
)
batches, err := batchGeneratorMssql(jobCtx, sourceDb, job)
if err != nil {
log.Error("Unexpected error calculating batch ranges: ", err)
}
chJobErrors := make(chan custom_errors.JobError, jobErrorsChannelSize)
chPartitions := make(chan models.Partition)
chBatchesRaw := make(chan models.Batch, job.ExtractorQueueSize)
chBatchesTransformed := make(chan models.Batch, job.TransformerQueueSize)
chJobErrors := make(chan JobError, 50)
chBatches := make(chan Batch, QueueSize)
chExtractorErrors := make(chan ExtractorError, QueueSize)
chChunksRaw := make(chan Chunk, QueueSize)
chChunksTransformed := make(chan Chunk, QueueSize)
chLoadersErrors := make(chan LoaderError, QueueSize)
var wgActivePartitions, wgActiveBatches, wgExtractors, wgTransformers, wgLoaders sync.WaitGroup
var rowsRead, rowsLoaded, rowsFailed int64
var failedPartitionsCount, failedBatchesLoadCount int32
var wgActiveBatches sync.WaitGroup
var wgActiveChunks sync.WaitGroup
var wgExtractors sync.WaitGroup
var wgTransformers sync.WaitGroup
var wgLoaders sync.WaitGroup
go func() {
if err := custom_errors.JobErrorHandler(localCtx, chJobErrors); err != nil {
log.Error("Fatal error received from JobErrorHandler, canceling job... - ", err)
if err := jobErrorHandler(jobCtx, chJobErrors); err != nil {
cancel()
result.Error = err
}
}()
maxExtractors := min(job.MaxExtractors, len(partitions))
log.Infof("Starting %d extractor(s)... (%v)", maxExtractors, job.Name)
go extractorErrorHandler(jobCtx, chExtractorErrors, chBatches, chJobErrors, &wgActiveBatches)
go loaderErrorHandler(jobCtx, chLoadersErrors, chChunksTransformed, chJobErrors, &wgActiveChunks)
maxExtractors := min(NumExtractors, len(batches))
log.Infof("Starting %d extractors...", maxExtractors)
extractStartTime := time.Now()
for range maxExtractors {
wgExtractors.Go(func() {
extractor.Consume(
localCtx,
job.SourceTable,
sourceColTypes,
job.ExtractorBatchSize,
job.Retry,
chPartitions,
chBatchesRaw,
chJobErrors,
&wgActivePartitions,
&rowsRead,
&failedPartitionsCount,
job.SourceTable.FromJsonColumns,
)
extractFromMssql(jobCtx, sourceDb, job, sourceColTypes, ChunkSize, chBatches, chChunksRaw, chExtractorErrors, chJobErrors, &wgActiveBatches)
})
}
wgActivePartitions.Add(len(partitions))
wgActiveBatches.Add(len(batches))
go func() {
for _, batch := range partitions {
chPartitions <- batch
for _, batch := range batches {
chBatches <- batch
}
}()
log.Infof("Starting %d transformer(s)... (%v)", maxExtractors, job.Name)
log.Infof("Starting %d transformers...", maxExtractors)
transformStartTime := time.Now()
for range maxExtractors {
wgTransformers.Go(func() {
transformer.Consume(
localCtx,
sourceColTypes,
job.Retry,
job.TransformerBatchSize,
chBatchesRaw,
chBatchesTransformed,
chJobErrors,
&wgActiveBatches,
)
transformRowsMssql(jobCtx, sourceColTypes, chChunksRaw, chChunksTransformed, chJobErrors, &wgActiveChunks)
})
}
log.Infof("Starting %d loader(s)... (%v)", job.MaxLoaders, job.Name)
log.Infof("Starting %d PostgreSQL loader(s)...", NumLoaders)
loadStartTime := time.Now()
for range job.MaxLoaders {
for range NumLoaders {
wgLoaders.Go(func() {
loader.Consume(
localCtx,
job.TargetTable,
targetColTypes,
job.Retry,
job.LoaderBatchSize,
chBatchesTransformed,
chJobErrors,
&wgActiveBatches,
&rowsLoaded,
&failedBatchesLoadCount,
)
loadRowsPostgres(jobCtx, targetDb, job, targetColTypes, chChunksTransformed, chLoadersErrors, chJobErrors, &wgActiveChunks)
})
}
go func() {
// log.Debugf("Waiting for goroutines (%v)", job.Name)
wgActivePartitions.Wait()
// log.Debugf("wgActivePartitions is empty (%v)", job.Name)
close(chPartitions)
// log.Debugf("chPartitions is closed (%v)", job.Name)
wgActiveBatches.Wait()
close(chBatches)
close(chExtractorErrors)
wgExtractors.Wait()
// log.Debugf("wgExtractors is empty (%v)", job.Name)
close(chBatchesRaw)
// log.Debugf("chBatchesRaw is closed (%v)", job.Name)
log.Infof("Extraction completed in %v", time.Since(extractStartTime))
close(chChunksRaw)
wgTransformers.Wait()
// log.Debugf("wgTransformers is empty (%v)", job.Name)
close(chBatchesTransformed)
// log.Debugf("chBatchesTransformed is closed (%v)", job.Name)
log.Infof("Transformation completed in %v", time.Since(transformStartTime))
wgActiveBatches.Wait()
// log.Debugf("wgActiveBatches is empty (%v)", job.Name)
wgActiveChunks.Wait()
close(chChunksTransformed)
close(chLoadersErrors)
wgLoaders.Wait()
// log.Debugf("wgLoaders is empty (%v)", job.Name)
log.Infof("Loading completed in %v", time.Since(loadStartTime))
cancel()
}()
for _, query := range job.TargetTable.PostSQL {
if _, err := targetDbWrapper.Exec(localCtx, query); err != nil {
result.Error = err
return result
<-jobCtx.Done()
log.Infof("Migration job completed. Total time: %v", time.Since(jobStartTime))
}
func logColumnTypes(columnTypes []ColumnType, label string) {
log.Debug(label)
for _, col := range columnTypes {
log.Debugf("%+v", col)
}
}
// log.Debugf("waiting for local context to be done (%v)", job.Name)
<-localCtx.Done()
// log.Debugf("local context done (%v)", job.Name)
if ctx.Err() != nil {
result.Error = ctx.Err()
func logSampleRow(job MigrationJob, columns []ColumnType, rowValues UnknownRowValues, tag string) {
log.Infof("[%s.%s] Sample row: (%s)", job.Schema, job.Table, tag)
for i, col := range columns {
log.Infof("%s (%T): %v", col.Name(), rowValues[i], rowValues[i])
}
result.Duration = time.Since(result.StartTime)
result.RowsRead = atomic.LoadInt64(&rowsRead)
result.RowsLoaded = atomic.LoadInt64(&rowsLoaded)
result.RowsFailed = atomic.LoadInt64(&rowsFailed)
if result.RowsRead != result.RowsLoaded {
result.Error = fmt.Errorf("Row count mismatch: extracted %d rows but loaded %d rows (failed: %d)", result.RowsRead, result.RowsLoaded, result.RowsFailed)
}
if result.RowsRead == 0 {
log.Warnf("No rows extracted from (%v)", job.Name)
}
return result
}

View File

@@ -0,0 +1,149 @@
package main
import (
"context"
"errors"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
type transformerFunc func(any) (any, error)
type columnTransformPlan struct {
index int
fn transformerFunc
}
func transformRowsMssql(
ctx context.Context,
columns []ColumnType,
chChunksIn <-chan Chunk,
chChunksOut chan<- Chunk,
chJobErrorsOut chan<- JobError,
wgActiveChunks *sync.WaitGroup,
) {
transformationPlan := computeTransformationPlan(columns)
for {
if ctx.Err() != nil {
return
}
select {
case <-ctx.Done():
return
case chunk, ok := <-chChunksIn:
if !ok {
return
}
if len(transformationPlan) == 0 {
select {
case chChunksOut <- chunk:
wgActiveChunks.Add(1)
continue
case <-ctx.Done():
return
}
}
chunkStartTime := time.Now()
err := processChunk(ctx, &chunk, transformationPlan)
if err != nil {
if errors.Is(err, ctx.Err()) {
return
}
select {
case chJobErrorsOut <- JobError{ShouldCancelJob: true, Msg: "Transformation failed", Prev: err}:
case <-ctx.Done():
}
return
}
log.Infof("Transformed chunk %s: %d rows in %v", chunk.Id, len(chunk.Data), time.Since(chunkStartTime))
select {
case chChunksOut <- chunk:
case <-ctx.Done():
return
}
wgActiveChunks.Add(1)
}
}
}
func computeTransformationPlan(columns []ColumnType) []columnTransformPlan {
var plan []columnTransformPlan
for i, col := range columns {
switch col.SystemType() {
case "uniqueidentifier":
plan = append(plan, columnTransformPlan{
index: i,
fn: func(v any) (any, error) {
if b, ok := v.([]byte); ok && b != nil {
return mssqlUuidToBigEndian(b)
}
return v, nil
},
})
case "geometry", "geography":
plan = append(plan, columnTransformPlan{
index: i,
fn: func(v any) (any, error) {
if b, ok := v.([]byte); ok && b != nil {
return wkbToEwkbWithSrid(b, 4326)
}
return v, nil
},
})
case "datetime", "datetime2":
plan = append(plan, columnTransformPlan{
index: i,
fn: func(v any) (any, error) {
if t, ok := v.(time.Time); ok {
return ensureUTC(t), nil
}
return v, nil
},
})
}
}
return plan
}
const processChunkCtxCheck = 4096
func processChunk(ctx context.Context, chunk *Chunk, transformationPlan []columnTransformPlan) error {
for i, rowValues := range chunk.Data {
if i%processChunkCtxCheck == 0 {
if err := ctx.Err(); err != nil {
return err
}
}
for _, task := range transformationPlan {
val := rowValues[task.index]
if val == nil {
continue
}
transformed, err := task.fn(val)
if err != nil {
return err
}
rowValues[task.index] = transformed
}
}
return nil
}

View File

@@ -1,192 +0,0 @@
package main
import (
"context"
"database/sql"
"fmt"
"sync"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper"
log "github.com/sirupsen/logrus"
)
type ValidationResult struct {
JobName string
SourceTable string
TargetTable string
SourceCount int64
TargetCount int64
Match bool
Error error
}
func countSourceRows(ctx context.Context, db dbwrapper.DbWrapper, job config.Job) (int64, error) {
schema := job.SourceTable.Schema
table := job.SourceTable.Table
hasRange := job.Range.Min != nil || job.Range.Max != nil
var (
query string
args []any
)
if hasRange && job.SourceTable.PrimaryKey != "" {
query = fmt.Sprintf("SELECT COUNT_BIG(*) FROM [%s].[%s] WHERE 1=1", schema, table)
if job.Range.Min != nil {
op := ">"
if job.Range.IsMinInclusive {
op = ">="
}
query += fmt.Sprintf(" AND [%s] %s @min", job.SourceTable.PrimaryKey, op)
args = append(args, sql.Named("min", *job.Range.Min))
}
if job.Range.Max != nil {
op := "<"
if job.Range.IsMaxInclusive {
op = "<="
}
query += fmt.Sprintf(" AND [%s] %s @max", job.SourceTable.PrimaryKey, op)
args = append(args, sql.Named("max", *job.Range.Max))
}
} else {
query = fmt.Sprintf("SELECT COUNT_BIG(*) FROM [%s].[%s]", schema, table)
}
var count int64
if err := db.QueryRow(ctx, query, args...).Scan(&count); err != nil {
return 0, err
}
return count, nil
}
func countTargetRows(ctx context.Context, db dbwrapper.DbWrapper, job config.Job) (int64, error) {
schema := job.TargetTable.Schema
table := job.TargetTable.Table
query := fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."%s"`, schema, table)
var count int64
if err := db.QueryRow(ctx, query).Scan(&count); err != nil {
return 0, err
}
return count, nil
}
func validateJob(ctx context.Context, sourceDb, targetDb dbwrapper.DbWrapper, job config.Job) ValidationResult {
result := ValidationResult{
JobName: job.Name,
SourceTable: fmt.Sprintf("[%s].[%s]", job.SourceTable.Schema, job.SourceTable.Table),
TargetTable: fmt.Sprintf(`"%s"."%s"`, job.TargetTable.Schema, job.TargetTable.Table),
}
var (
sourceErr, targetErr error
wg sync.WaitGroup
)
wg.Add(2)
go func() {
defer wg.Done()
result.SourceCount, sourceErr = countSourceRows(ctx, sourceDb, job)
}()
go func() {
defer wg.Done()
result.TargetCount, targetErr = countTargetRows(ctx, targetDb, job)
}()
wg.Wait()
if sourceErr != nil {
result.Error = fmt.Errorf("source count failed: %w", sourceErr)
return result
}
if targetErr != nil {
result.Error = fmt.Errorf("target count failed: %w", targetErr)
return result
}
result.Match = result.SourceCount == result.TargetCount
return result
}
func validateJobs(
ctx context.Context,
sourceDb dbwrapper.DbWrapper,
targetDb dbwrapper.DbWrapper,
jobs []config.Job,
maxParallelWorkers int,
) []ValidationResult {
if len(jobs) == 0 {
return nil
}
if maxParallelWorkers <= 0 {
maxParallelWorkers = 1
}
if maxParallelWorkers > len(jobs) {
maxParallelWorkers = len(jobs)
}
chJobs := make(chan config.Job, len(jobs))
var mu sync.Mutex
var results []ValidationResult
var wg sync.WaitGroup
for range maxParallelWorkers {
wg.Go(func() {
for job := range chJobs {
res := validateJob(ctx, sourceDb, targetDb, job)
mu.Lock()
results = append(results, res)
mu.Unlock()
}
})
}
for _, job := range jobs {
chJobs <- job
}
close(chJobs)
wg.Wait()
return results
}
func printValidationReport(results []ValidationResult) {
log.Info("=== VALIDATION REPORT ===")
var totalMatch, totalMismatch, totalErrors int
for _, r := range results {
if r.Error != nil {
log.Errorf("[%s] ERROR: %v", r.JobName, r.Error)
totalErrors++
continue
}
if !r.Match {
totalMismatch++
diff := r.TargetCount - r.SourceCount
var diffStr string
if diff > 0 {
diffStr = fmt.Sprintf(" (target has %d extra rows)", diff)
} else {
diffStr = fmt.Sprintf(" (target is missing %d rows)", -diff)
}
log.Warnf("[%s] MISMATCH | Source %s: %d | Target %s: %d%s",
r.JobName,
r.SourceTable, r.SourceCount,
r.TargetTable, r.TargetCount,
diffStr,
)
} else {
totalMatch++
log.Infof("[%s] OK | Source %s: %d | Target %s: %d",
r.JobName,
r.SourceTable, r.SourceCount,
r.TargetTable, r.TargetCount,
)
}
}
log.Infof("=== Validation complete: %d OK, %d mismatches, %d errors ===", totalMatch, totalMismatch, totalErrors)
}

View File

@@ -1,35 +0,0 @@
max_parallel_workers: 4
source_db_type: postgres
target_db_type: sqlserver
defaults:
batches_per_partition: 4
max_extractors: 2
extractor_batch_size: 5000
extractor_queue_size: 8
max_transformers: 2
transformer_batch_size: 12500
transformer_queue_size: 8
max_loaders: 4
loader_batch_size: 25000
partition_calculation_strategy: EXACT
truncate_target: true
truncate_method: TRUNCATE
retry:
attempts: 3
base_delay_ms: 500
max_delay_ms: 10000
max_jitter_ms: 500
max_failed_partitions: 5
max_failed_batches_load: 5
jobs:
- name: cartografia_manzana_reverse
enabled: true
source:
schema: Cartografia
table: MANZANA
primary_key: GDB_ARCHIVE_OID
target:
schema: Cartografia
table: MANZANA

View File

@@ -1,35 +0,0 @@
max_parallel_workers: 4
source_db_type: postgres
target_db_type: sqlserver
defaults:
batches_per_partition: 4
max_extractors: 8
extractor_batch_size: 25000
extractor_queue_size: 32
max_transformers: 8
transformer_batch_size: 50000
transformer_queue_size: 32
max_loaders: 8
loader_batch_size: 50000
partition_calculation_strategy: EXACT
truncate_target: true
truncate_method: TRUNCATE
retry:
attempts: 3
base_delay_ms: 500
max_delay_ms: 10000
max_jitter_ms: 500
max_failed_partitions: 5
max_failed_batches_load: 5
jobs:
- name: cartografia_manzana_reverse
enabled: true
source:
schema: Cartografia
table: MANZANA
primary_key: GDB_ARCHIVE_OID
target:
schema: Cartografia
table: MANZANA

View File

@@ -1,27 +1,15 @@
max_parallel_workers: 4
source_db_type: sqlserver
target_db_type: postgres
max_parallel_workers: 2
defaults:
batches_per_partition: 4
max_extractors: 2
extractor_batch_size: 5000
extractor_queue_size: 8
max_transformers: 2
transformer_batch_size: 12500
transformer_queue_size: 8
max_loaders: 4
loader_batch_size: 25000
partition_calculation_strategy: EXACT # EXACT | ESTIMATION
max_extractors: 4
max_loaders: 8
queue_size: 8
chunk_size: 50000
chunks_per_batch: 10
truncate_target: true
truncate_method: TRUNCATE # TRUNCATE | DELETE
retry:
attempts: 3
base_delay_ms: 500
max_delay_ms: 10000
max_jitter_ms: 500
max_failed_partitions: 5
max_failed_batches_load: 5
jobs:
- name: cartografia_manzana
@@ -33,45 +21,26 @@ jobs:
target:
schema: Cartografia
table: MANZANA
max_extractors: 2 # overrides default config
max_loaders: 4 # overrides default config
queue_size: 4 # overrides default config
chunk_size: 25000 # overrides default config
chunks_per_batch: 8 # overrides default config
truncate_target: false # overrides default config
truncate_method: DELETE # overrides default config
retry:
attempts: 5 # overrides default config
pre_sql:
- "SELECT 1"
post_sql:
- "SELECT 2"
# - name: red_puerto
# enabled: true
# source:
# schema: Red
# table: PUERTO
# primary_key: ID_PUERTO
# from_json:
# - column: $node_id*
# field: id
# target:
# schema: Red
# table: PUERTO
# - name: infraestructura_site_holder__attach
# source:
# schema: Infraestructura
# table: SITE_HOLDER__ATTACH
# primary_key: GDB_ARCHIVE_OID
# target:
# schema: Infraestructura
# table: SITE_HOLDER__ATTACH
# to_storage:
# columns:
# - source: DATA
# target: FILE_URL
# mode: REFERENCE_ONLY
# prefix: Infraestructura/SITE_HOLDER__ATTACH
# batches_per_partition: 20
# max_extractors: 32
# extractor_batch_size: 1
# extractor_queue_size: 100
# max_transformers: 48
# transformer_batch_size: 500
# transformer_queue_size: 8
# max_loaders: 4
# loader_batch_size: 500
# retry:
# attempts: 5
# base_delay_ms: 1000
# max_delay_ms: 15000
# max_jitter_ms: 500
- name: red_puerto
enabled: true
source:
schema: Red
table: PUERTO
primary_key: ID_PUERTO
target:
schema: Red
table: PUERTO

2
docker/.gitignore vendored
View File

@@ -1,2 +0,0 @@
data/**/*
compose.override.yml

View File

@@ -1,50 +0,0 @@
name: db-migration
services:
azurite:
image: mcr.microsoft.com/azure-storage/azurite:3.35.0
container_name: azurite
restart: unless-stopped
ports:
- 8880:10000
- 8881:10001
- 8882:10002
volumes:
- ./data/azurite:/data
command: 'azurite --blobHost 0.0.0.0 --queueHost 0.0.0.0 --tableHost 0.0.0.0 --location /data --skipApiVersionCheck'
profiles:
- storage
- target
mssql:
image: mcr.microsoft.com/mssql/server:2022-latest
restart: unless-stopped
environment:
ACCEPT_EULA: Y
MSSQL_SA_PASSWORD: SecurePassword123
MSSQL_PID: Developer
MSSQL_MEMORY_LIMIT_MB: 8192
ports:
- 8883:1433
volumes:
- ./data/mssql:/var/opt/mssql
profiles:
- mssql
- source
- db
postgres:
image: postgis/postgis:16-3.4
restart: unless-stopped
environment:
POSTGRES_DB: test_db
POSTGRES_USER: postgres
POSTGRES_PASSWORD: SecurePassword123
ports:
- 8884:5432
volumes:
- ./data/postgres:/var/lib/postgresql/data
profiles:
- postgres
- target
- db
shm_size: '1gb'

16
go.mod
View File

@@ -1,35 +1,27 @@
module git.ksdemosapps.com/kylesoda/go-migrate
go 1.26
go 1.25.7
require (
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.4
github.com/gaspardle/go-mssqlclrgeo v0.0.0-20160129143314-97ceabf987a4
github.com/goccy/go-yaml v1.19.2
github.com/google/uuid v1.6.0
github.com/ilyakaznacheev/cleanenv v1.5.0
github.com/jackc/pgx/v5 v5.9.1
github.com/joho/godotenv v1.5.1
github.com/microsoft/go-mssqldb v1.9.8
github.com/sirupsen/logrus v1.9.4
github.com/twpayne/go-geom v1.6.1
golang.org/x/sync v0.19.0
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect
github.com/BurntSushi/toml v1.6.0 // indirect
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/joho/godotenv v1.5.1 // indirect
github.com/rogpeppe/go-internal v1.14.1 // indirect
github.com/shopspring/decimal v1.4.0 // indirect
golang.org/x/crypto v0.48.0 // indirect
golang.org/x/net v0.51.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect
olympos.io/encoding/edn v0.0.0-20201019073823-d3554ca0b0a3 // indirect
)

21
go.sum
View File

@@ -4,19 +4,12 @@ github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpz
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.8.1 h1:/Zt+cDPnpC3OVDm/JKLOs7M2DKmLRIIp3XIx9pHHiig=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.8.1/go.mod h1:Ng3urmn6dYe8gnbCMoHHVl5APYz2txho3koEkV2o2HA=
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.4.0 h1:E4MgwLBGeVB5f2MdcIVD3ELVAWpr+WD6MUe1i+tM/PA=
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.4.0/go.mod h1:Y2b/1clN4zsAoUd/pgNAQHjLDnTis/6ROkUfyob6psM=
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.2.0 h1:nCYfgcSyHZXJI8J0IWE5MsCGlb2xp9fJiXyxWgmOFg4=
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.2.0/go.mod h1:ucUjca2JtSZboY8IoUqyQyuuXvwbMBVwFOm0vdQPNhA=
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.4 h1:jWQK1GI+LeGGUKBADtcH2rRqPxYB1Ljwms5gFA2LqrM=
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.4/go.mod h1:8mwH4klAm9DUgR2EEHyEEAQlRDvLPyg5fQry3y+cDew=
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs=
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk=
github.com/BurntSushi/toml v1.2.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk=
github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/alecthomas/assert/v2 v2.10.0 h1:jjRCHsj6hBJhkmhznrCzoNpbA3zqy0fYiUcYZP/GkPY=
@@ -28,6 +21,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gaspardle/go-mssqlclrgeo v0.0.0-20160129143314-97ceabf987a4 h1:4vH4+3zfwZTqoJEFw7DsTaH1V8jgVwnyeDvNi2TxzAc=
github.com/gaspardle/go-mssqlclrgeo v0.0.0-20160129143314-97ceabf987a4/go.mod h1:jlB0I5BIfcJBGdV6rRGPthSBfeY86RGkSAwcsldbHJc=
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
@@ -38,8 +33,6 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/ilyakaznacheev/cleanenv v1.5.0 h1:0VNZXggJE2OYdXE87bfSSwGxeiGt9moSR2lOrsHHvr4=
github.com/ilyakaznacheev/cleanenv v1.5.0/go.mod h1:a5aDzaJrLCQZsazHol1w8InnDcOX0OColm64SlIi6gk=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -50,10 +43,6 @@ github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/microsoft/go-mssqldb v1.9.8 h1:d4IFMvF/o+HdpXUqbBfzHvn/NlFA75YGcfHUUvDFJEM=
@@ -62,8 +51,6 @@ github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmd
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w=
@@ -86,10 +73,6 @@ golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
olympos.io/encoding/edn v0.0.0-20201019073823-d3554ca0b0a3 h1:slmdOY3vp8a7KQbHkL+FLbvbkgMqmXojpFUO/jENuqQ=
olympos.io/encoding/edn v0.0.0-20201019073823-d3554ca0b0a3/go.mod h1:oVgVk4OWVDi43qWBEyGhXgYxt7+ED4iYNpTngSLX2Iw=

View File

@@ -1,98 +0,0 @@
package azure
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"path"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
)
var (
ErrInvalidConnectionString = errors.New("invalid connection string")
ErrContainerNotFound = errors.New("container not found")
ErrBlobNotFound = errors.New("blob not found")
ErrInvalidInput = errors.New("invalid input parameters")
)
type Client struct {
client *azblob.Client
azureStorageConfig config.AzureStorageConfig
}
func NewClient(azureStorageConfig config.AzureStorageConfig) (*Client, error) {
protocol := "https"
if !azureStorageConfig.UseHTTPS {
protocol = "http"
}
blobEndpoint, _ := url.JoinPath(azureStorageConfig.ServiceURL, azureStorageConfig.AccountName)
connStr := fmt.Sprintf("DefaultEndpointsProtocol=%s;AccountName=%s;AccountKey=%s;BlobEndpoint=%s;",
protocol, azureStorageConfig.AccountName, azureStorageConfig.AccountKey, blobEndpoint)
client, err := azblob.NewClientFromConnectionString(connStr, nil)
if err != nil {
return nil, fmt.Errorf("creating azure storage client: %w", err)
}
return &Client{
client: client,
azureStorageConfig: azureStorageConfig,
}, nil
}
func (c *Client) CreateContainer(ctx context.Context, containerName string) error {
if containerName == "" {
return ErrInvalidInput
}
_, err := c.client.CreateContainer(ctx, containerName, nil)
if err != nil {
return fmt.Errorf("creating container %s: %w", containerName, err)
}
return nil
}
func (c *Client) UploadBuffer(ctx context.Context, containerName, blobPath string, buffer []byte) error {
if containerName == "" || blobPath == "" || buffer == nil {
return ErrInvalidInput
}
_, err := c.client.UploadBuffer(ctx, containerName, blobPath, buffer, nil)
if err != nil {
return fmt.Errorf("uploading blob %s: %w", blobPath, err)
}
return nil
}
func (c *Client) Ping(ctx context.Context) error {
pager := c.client.NewListBlobsFlatPager(c.azureStorageConfig.Container, nil)
_, err := pager.NextPage(ctx)
if err != nil {
return fmt.Errorf("storage access check failed: %w", err)
}
return nil
}
func (c *Client) UploadAndGetURL(ctx context.Context, blobPath string, buffer []byte) (string, error) {
if blobPath == "" || buffer == nil {
return "", ErrInvalidInput
}
fullPath := path.Join(c.azureStorageConfig.Prefix, blobPath)
contentType := http.DetectContentType(buffer)
opts := &azblob.UploadBufferOptions{
HTTPHeaders: &blob.HTTPHeaders{BlobContentType: &contentType},
}
if _, err := c.client.UploadBuffer(ctx, c.azureStorageConfig.Container, fullPath, buffer, opts); err != nil {
return "", fmt.Errorf("uploading blob %s: %w", fullPath, err)
}
return fullPath, nil
}

View File

@@ -1,130 +1,55 @@
package config
import (
"fmt"
"maps"
"net/url"
"os"
"github.com/ilyakaznacheev/cleanenv"
"github.com/joho/godotenv"
log "github.com/sirupsen/logrus"
)
type AzureStorageConfig struct {
AccountName string `env:"AZ_ACCOUNT_NAME"`
Container string `env:"AZ_CONTAINER"`
AccountKey string `env:"AZ_ACCOUNT_KEY"`
UseHTTPS bool `env:"AZ_USE_HTTPS" env-default:"true"`
ServiceURL string `env:"AZ_SERVICE_URL"`
Prefix string `env:"AZ_PREFIX"`
Enabled bool `env:"AZ_STORAGE_ENABLED"`
}
type appConfig struct {
SourceDbUrl string `env:"SOURCE_DB_URL"`
SourceDbHost string `env:"SOURCE_DB_HOST"`
SourceDbPort string `env:"SOURCE_DB_PORT"`
SourceDbName string `env:"SOURCE_DB_NAME"`
SourceDbUser string `env:"SOURCE_DB_USER"`
SourceDbPwd string `env:"SOURCE_DB_PWD"`
SourceDbOptions string `env:"SOURCE_DB_OPTIONS"`
TargetDbUrl string `env:"TARGET_DB_URL"`
TargetDbHost string `env:"TARGET_DB_HOST"`
TargetDbPort string `env:"TARGET_DB_PORT"`
TargetDbName string `env:"TARGET_DB_NAME"`
TargetDbUser string `env:"TARGET_DB_USER"`
TargetDbPwd string `env:"TARGET_DB_PWD"`
TargetDbOptions string `env:"TARGET_DB_OPTIONS"`
LogLevel string `env:"LOG_LEVEL" env-default:"INFO"`
AzureStorage AzureStorageConfig
SourceDbUrl string
SourceDbType string
TargetDbUrl string
TargetDbType string
}
func (c *appConfig) ResolveSourceDbUrl(dbType string) (string, error) {
if c.SourceDbUrl != "" {
return c.SourceDbUrl, nil
}
u, err := buildDbUrl(dbType, c.SourceDbHost, c.SourceDbPort, c.SourceDbName, c.SourceDbUser, c.SourceDbPwd, c.SourceDbOptions)
func loadEnv() {
err := godotenv.Load()
if err != nil {
return "", fmt.Errorf("source DB: %w", err)
}
return u, nil
}
func (c *appConfig) ResolveTargetDbUrl(dbType string) (string, error) {
if c.TargetDbUrl != "" {
return c.TargetDbUrl, nil
}
u, err := buildDbUrl(dbType, c.TargetDbHost, c.TargetDbPort, c.TargetDbName, c.TargetDbUser, c.TargetDbPwd, c.TargetDbOptions)
if err != nil {
return "", fmt.Errorf("target DB: %w", err)
}
return u, nil
}
func buildDbUrl(dbType, host, port, name, user, pwd, options string) (string, error) {
if host == "" {
return "", fmt.Errorf("DB_HOST is required when DB_URL is not set")
}
if name == "" {
return "", fmt.Errorf("DB_NAME is required when DB_URL is not set")
}
if user == "" {
return "", fmt.Errorf("DB_USER is required when DB_URL is not set")
}
switch dbType {
case "sqlserver":
if port == "" {
port = "1433"
}
q := url.Values{}
if options != "" {
extra, err := url.ParseQuery(options)
if err != nil {
return "", fmt.Errorf("invalid DB_OPTIONS: %w", err)
}
maps.Copy(q, extra)
}
q.Set("database", name)
u := &url.URL{
Scheme: "sqlserver",
Host: host + ":" + port,
User: url.UserPassword(user, pwd),
RawQuery: q.Encode(),
}
return u.String(), nil
case "postgres":
if port == "" {
port = "5432"
}
u := &url.URL{
Scheme: "postgres",
Host: host + ":" + port,
User: url.UserPassword(user, pwd),
Path: "/" + name,
RawQuery: options,
}
return u.String(), nil
default:
return "", fmt.Errorf("unknown db type %q — cannot build URL from individual components", dbType)
log.Warn("Warning: could not load .env file")
}
}
func getAppConfig() appConfig {
var cfg appConfig
loadEnv()
err := cleanenv.ReadConfig(".env", &cfg)
if err != nil {
log.Warn("Could not load .env file")
sourceDbUrl := os.Getenv("SOURCE_DB_URL")
if sourceDbUrl == "" {
log.Fatal("SOURCE_DB_URL environment variable not set")
}
err = cleanenv.ReadEnv(&cfg)
if err != nil {
log.Fatalf("Error al cargar variables: %v", err)
sourceDbType := os.Getenv("SOURCE_DB_TYPE")
if sourceDbType == "" {
log.Fatal("SOURCE_DB_TYPE environment variable not set")
}
return cfg
targetDbUrl := os.Getenv("TARGET_DB_URL")
if targetDbUrl == "" {
log.Fatal("TARGET_DB_URL environment variable not set")
}
targetDbType := os.Getenv("TARGET_DB_TYPE")
if targetDbType == "" {
log.Fatal("TARGET_DB_TYPE environment variable not set")
}
return appConfig{
SourceDbUrl: sourceDbUrl,
SourceDbType: sourceDbType,
TargetDbUrl: targetDbUrl,
TargetDbType: targetDbType,
}
}
var App appConfig = getAppConfig()

View File

@@ -1,163 +0,0 @@
package config
import (
"fmt"
"os"
"gopkg.in/yaml.v3"
)
type RetryConfig struct {
Attempts int `yaml:"attempts"`
BaseDelayMs int `yaml:"base_delay_ms"`
MaxDelayMs int `yaml:"max_delay_ms"`
MaxJitterMs int `yaml:"max_jitter_ms"`
MaxFailedPartitions int `yaml:"max_failed_partitions"`
MaxFailedBatchesLoad int `yaml:"max_failed_batches_load"`
}
type ToStorageColumnConfig struct {
Source string `yaml:"source"`
Target string `yaml:"target"`
Mode string `yaml:"mode"`
Prefix string `yaml:"prefix"`
}
type ToStorageConfig struct {
Columns []ToStorageColumnConfig `yaml:"columns"`
}
type JobConfig struct {
BatchesPerPartition int `yaml:"batches_per_partition"`
MaxExtractors int `yaml:"max_extractors"`
ExtractorBatchSize int `yaml:"extractor_batch_size"`
ExtractorQueueSize int `yaml:"extractor_queue_size"`
MaxTransformers int `yaml:"max_transformers"`
TransformerBatchSize int `yaml:"transformer_batch_size"`
TransformerQueueSize int `yaml:"transformer_queue_size"`
MaxLoaders int `yaml:"max_loaders"`
LoaderBatchSize int `yaml:"loader_batch_size"`
PartitionCalculationStrategy string `yaml:"partition_calculation_strategy"`
TruncateTarget bool `yaml:"truncate_target"`
TruncateMethod string `yaml:"truncate_method"`
Retry RetryConfig `yaml:"retry"`
RowsPerPartition int64
ToStorage ToStorageConfig `yaml:"to_storage"`
}
type FromJsonItem struct {
Column string `yaml:"column"`
Field string `yaml:"field"`
}
type TableInfo struct {
Schema string `yaml:"schema"`
Table string `yaml:"table"`
}
type SourceTableInfo struct {
TableInfo `yaml:",inline"`
PrimaryKey string `yaml:"primary_key"`
FromJsonColumns []FromJsonItem `yaml:"from_json"`
}
type TargetTableInfo struct {
TableInfo `yaml:",inline"`
PreSQL []string `yaml:"pre_sql"`
PostSQL []string `yaml:"post_sql"`
}
type RangeConfig struct {
Min *int64 `yaml:"min"`
Max *int64 `yaml:"max"`
IsMinInclusive bool `yaml:"is_min_inclusive"`
IsMaxInclusive bool `yaml:"is_max_inclusive"`
}
type Job struct {
Name string `yaml:"name"`
Enabled bool `yaml:"enabled"`
SourceTable SourceTableInfo `yaml:"source"`
TargetTable TargetTableInfo `yaml:"target"`
JobConfig `yaml:",inline"`
Range RangeConfig `yaml:"range"`
}
type MigrationConfig struct {
MaxParallelWorkers int `yaml:"max_parallel_workers"`
SourceDbType string `yaml:"source_db_type"`
TargetDbType string `yaml:"target_db_type"`
Defaults JobConfig `yaml:"defaults"`
Jobs []Job `yaml:"jobs"`
}
type rawConfig struct {
MaxParallelWorkers int `yaml:"max_parallel_workers"`
SourceDbType string `yaml:"source_db_type"`
TargetDbType string `yaml:"target_db_type"`
Defaults JobConfig `yaml:"defaults"`
Jobs []yaml.Node `yaml:"jobs"`
}
func (c *MigrationConfig) UnmarshalYAML(value *yaml.Node) error {
var raw rawConfig
if err := value.Decode(&raw); err != nil {
return err
}
c.MaxParallelWorkers = raw.MaxParallelWorkers
c.Defaults = raw.Defaults
c.SourceDbType = raw.SourceDbType
c.TargetDbType = raw.TargetDbType
c.Defaults.RowsPerPartition = int64(raw.Defaults.ExtractorBatchSize * raw.Defaults.BatchesPerPartition)
for _, node := range raw.Jobs {
job := Job{
JobConfig: raw.Defaults,
}
if err := node.Decode(&job); err != nil {
return err
}
job.RowsPerPartition = int64(job.ExtractorBatchSize * job.BatchesPerPartition)
c.Jobs = append(c.Jobs, job)
}
return nil
}
const defaultConfigFileName string = "config.yaml"
func filenamesOrDefault(filenames []string) []string {
if len(filenames) == 0 {
return []string{defaultConfigFileName}
}
return filenames
}
func ReadMigrationConfig(filenames ...string) (MigrationConfig, error) {
filenames = filenamesOrDefault(filenames)
var data []byte
var err error
for _, filename := range filenames {
data, err = os.ReadFile(filename)
if err != nil {
continue
}
break
}
if err != nil {
return MigrationConfig{}, fmt.Errorf("Error reading config file: %v", err)
}
var config MigrationConfig
if err := yaml.Unmarshal(data, &config); err != nil {
return MigrationConfig{}, fmt.Errorf("Error parsing config file: %v", err)
}
return config, nil
}

View File

@@ -1,18 +0,0 @@
package convert
func ToInt64(v any) (int64, bool) {
switch t := v.(type) {
case int:
return int64(t), true
case int8:
return int64(t), true
case int16:
return int64(t), true
case int32:
return int64(t), true
case int64:
return int64(t), true
default:
return 0, false
}
}

View File

@@ -1,41 +0,0 @@
package custom_errors
import (
"math/rand"
"time"
)
func ComputeBackoffDelay(retryCounter int, baseDelayMs int, maxDelayMs int, maxJitterMs int) time.Duration {
if retryCounter < 0 {
retryCounter = 0
}
delay := max(time.Duration(baseDelayMs)*time.Millisecond, 0)
maxDelay := time.Duration(maxDelayMs) * time.Millisecond
for i := 0; i < retryCounter; i++ {
if maxDelayMs > 0 && delay >= maxDelay {
delay = maxDelay
break
}
if delay == 0 {
break
}
delay *= 2
}
if maxDelayMs > 0 && delay > maxDelay {
delay = maxDelay
}
if maxJitterMs > 0 {
jitter := time.Duration(rand.Intn(maxJitterMs+1)) * time.Millisecond
delay += jitter
}
if delay < 0 {
delay = 0
}
return delay
}

View File

@@ -1,25 +0,0 @@
package custom_errors
import (
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
)
type ExtractorError struct {
Partition models.Partition
LastId int64
HasLastId bool
Msg string
}
func (e *ExtractorError) Error() string {
return e.Msg
}
type LoaderError struct {
Batch models.Batch
Msg string
}
func (e *LoaderError) Error() string {
return e.Msg
}

View File

@@ -1,7 +0,0 @@
package db_dialects
const (
SqlServer string = "sqlserver"
Postgres string = "postgres"
Null string = "null"
)

View File

@@ -1,19 +0,0 @@
package dbwrapper
import "fmt"
type Factory func() DbWrapper
var drivers = make(map[string]Factory)
func Register(name string, factory Factory) {
drivers[name] = factory
}
func New(driverType string) (DbWrapper, error) {
factory, ok := drivers[driverType]
if !ok {
return nil, fmt.Errorf("driver not yet supported: %s", driverType)
}
return factory(), nil
}

View File

@@ -1,317 +0,0 @@
package dbwrapper
import (
"context"
"database/sql"
"fmt"
"strings"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
dbdialects "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper/db_dialects"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
mssql "github.com/microsoft/go-mssqldb"
)
func init() {
Register(dbdialects.SqlServer, func() DbWrapper {
return &mssqlDbWrapper{dialect: dbdialects.SqlServer}
})
}
type mssqlRowResult struct {
row *sql.Row
}
func (mr *mssqlRowResult) Scan(dest ...any) error {
return mr.row.Scan(dest...)
}
type mssqlRowsResult struct {
columns []string
rows *sql.Rows
}
func (mr *mssqlRowsResult) Close() error {
return mr.rows.Close()
}
func (mr *mssqlRowsResult) Columns() ([]string, error) {
if mr.columns != nil {
return mr.columns, nil
}
return mr.rows.Columns()
}
func (mr *mssqlRowsResult) Err() error {
return mr.rows.Err()
}
func (mr *mssqlRowsResult) Next() bool {
return mr.rows.Next()
}
func (mr *mssqlRowsResult) Scan(dest ...any) error {
return mr.rows.Scan(dest...)
}
func (mr *mssqlRowsResult) Values() ([]any, error) {
columns, err := mr.Columns()
if err != nil {
return nil, err
}
rowValues := make([]any, len(columns))
scanArgs := make([]any, len(columns))
for i := range rowValues {
scanArgs[i] = &rowValues[i]
}
if err := mr.rows.Scan(scanArgs...); err != nil {
return nil, err
}
return rowValues, nil
}
type mssqlDbWrapper struct {
db *sql.DB
dialect string
}
func (mw *mssqlDbWrapper) Connect(ctx context.Context, dbUrl string) error {
db, err := sql.Open("sqlserver", dbUrl)
if err != nil {
return err
}
if err := db.PingContext(ctx); err != nil {
if err := db.Close(); err != nil {
return err
}
return err
}
mw.db = db
return nil
}
func (mw *mssqlDbWrapper) Close() error {
return mw.db.Close()
}
func (mw *mssqlDbWrapper) Exec(ctx context.Context, query string, args ...any) (ExecResult, error) {
result, execErr := mw.db.ExecContext(ctx, query, args...)
if execErr != nil {
return ExecResult{}, execErr
}
affectedRows, err := result.RowsAffected()
if err != nil {
return ExecResult{}, err
}
return ExecResult{AffectedRows: affectedRows}, nil
}
func (mw *mssqlDbWrapper) GetDialect() string {
return mw.dialect
}
func (mw *mssqlDbWrapper) Query(ctx context.Context, query string, args ...any) (RowsResult, error) {
rows, err := mw.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
return &mssqlRowsResult{columns: nil, rows: rows}, nil
}
func (mw *mssqlDbWrapper) QueryRow(ctx context.Context, query string, args ...any) RowResult {
row := mw.db.QueryRowContext(ctx, query, args...)
return &mssqlRowResult{row: row}
}
func (mw *mssqlDbWrapper) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) {
tx, err := mw.db.BeginTx(ctx, nil)
if err != nil {
return 0, err
}
fullTableName := fmt.Sprintf("[%s].[%s]", schema, table)
stmt, err := tx.PrepareContext(ctx, mssql.CopyIn(fullTableName, mssql.BulkOptions{}, columnNames...))
if err != nil {
tx.Rollback()
return 0, err
}
for _, row := range rows {
_, err = stmt.ExecContext(ctx, row...)
if err != nil {
stmt.Close()
tx.Rollback()
return 0, err
}
}
result, err := stmt.ExecContext(ctx)
if err != nil {
stmt.Close()
tx.Rollback()
return 0, err
}
if err := stmt.Close(); err != nil {
tx.Rollback()
return 0, err
}
if err := tx.Commit(); err != nil {
return 0, err
}
rowsAffected, raErr := result.RowsAffected()
if raErr != nil {
return 0, nil
}
return rowsAffected, nil
}
func buildExtractQueryMssql(q ExtractionQuery) (string, error) {
var sbQuery strings.Builder
sbQuery.WriteString("SELECT ")
hasRegularColumns := len(q.Columns) > 0
hasJsonColumns := len(q.FromJsonColumns) > 0
resolvedJson := make(map[string][]config.FromJsonItem, len(q.FromJsonColumns))
if hasJsonColumns {
for _, jsonConfig := range q.FromJsonColumns {
actualColumnName, err := findColumnByPattern(q.Columns, jsonConfig.Column)
if err != nil {
return "", err
}
resolvedJson[actualColumnName] = append(resolvedJson[actualColumnName], jsonConfig)
}
}
selectParts := make([]string, 0, len(q.Columns)+len(q.FromJsonColumns))
if hasRegularColumns {
for _, col := range q.Columns {
jsonConfigs, isJsonColumn := resolvedJson[col.Name()]
if isJsonColumn {
for _, jsonConfig := range jsonConfigs {
jsonPath := buildJsonPathMssql(jsonConfig.Field)
jsonExpr := fmt.Sprintf("JSON_VALUE([%s], '%s') AS [%s]", col.Name(), jsonPath, col.Name())
selectParts = append(selectParts, jsonExpr)
}
continue
}
colExpr := fmt.Sprintf("[%s]", col.Name())
switch col.Type() {
case "GEOMETRY":
colExpr = fmt.Sprintf("[%s].STAsBinary() AS [%s]", col.Name(), col.Name())
}
selectParts = append(selectParts, colExpr)
}
} else if !hasJsonColumns {
selectParts = append(selectParts, "*")
}
for i, part := range selectParts {
sbQuery.WriteString(part)
if i < len(selectParts)-1 {
sbQuery.WriteString(", ")
}
}
fmt.Fprintf(&sbQuery, " FROM [%s].[%s]", q.Schema, q.Table)
if q.LowerLimit.IsValid || q.UpperLimit.IsValid {
sbQuery.WriteString(" WHERE ")
if q.LowerLimit.IsValid {
fmt.Fprintf(&sbQuery, "[%s]", q.PrimaryKey)
if q.LowerLimit.IsInclusive {
sbQuery.WriteString(" >=")
} else {
sbQuery.WriteString(" >")
}
sbQuery.WriteString(" @min")
}
if q.LowerLimit.IsValid && q.UpperLimit.IsValid {
sbQuery.WriteString(" AND ")
}
if q.UpperLimit.IsValid {
fmt.Fprintf(&sbQuery, "[%s]", q.PrimaryKey)
if q.UpperLimit.IsInclusive {
sbQuery.WriteString(" <=")
} else {
sbQuery.WriteString(" <")
}
sbQuery.WriteString(" @max")
}
}
fmt.Fprintf(&sbQuery, " ORDER BY [%s] ASC", q.PrimaryKey)
return sbQuery.String(), nil
}
func findColumnByPattern(columns []models.ColumnType, pattern string) (string, error) {
if pattern == "" {
return "", fmt.Errorf("column pattern cannot be empty")
}
if before, ok := strings.CutSuffix(pattern, "*"); ok {
prefix := before
for _, col := range columns {
if strings.HasPrefix(col.Name(), prefix) {
return col.Name(), nil
}
}
return "", fmt.Errorf("no column found matching pattern '%s'", pattern)
}
for _, col := range columns {
if col.Name() == pattern {
return col.Name(), nil
}
}
return "", fmt.Errorf("column '%s' not found in table columns", pattern)
}
func (mw *mssqlDbWrapper) QueryFromObject(ctx context.Context, q ExtractionQuery) (RowsResult, error) {
queryString, err := buildExtractQueryMssql(q)
if err != nil {
return nil, err
}
// logrus.Debugf("Query: %s", queryString)
var queryArgs []any
if q.LowerLimit.IsValid {
queryArgs = append(queryArgs, sql.Named("min", q.LowerLimit.Value))
}
if q.UpperLimit.IsValid {
queryArgs = append(queryArgs, sql.Named("max", q.UpperLimit.Value))
}
return mw.Query(ctx, queryString, queryArgs...)
}
func buildJsonPathMssql(field string) string {
if len(field) > 0 && field[0] == '.' {
field = field[1:]
}
return "$." + field
}

View File

@@ -1,396 +0,0 @@
package dbwrapper
import (
"strings"
"testing"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
)
func TestBuildExtractQueryMssql_NoJsonColumns(t *testing.T) {
q := ExtractionQuery{
Schema: "dbo",
Table: "Users",
PrimaryKey: "ID",
Columns: []models.ColumnType{
models.NewColumnType("ID", false, false, "INT", "int", "INT", false, 0, 0, 0),
models.NewColumnType("Name", true, false, "VARCHAR", "varchar", "VARCHAR", true, 255, 0, 0),
},
FromJsonColumns: []config.FromJsonItem{},
LowerLimit: ExtractorQueryLimit{IsValid: false},
UpperLimit: ExtractorQueryLimit{IsValid: false},
}
query, err := buildExtractQueryMssql(q)
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
if !strings.Contains(query, "SELECT [ID], [Name]") {
t.Errorf("Expected columns in query, got: %s", query)
}
if !strings.Contains(query, "FROM [dbo].[Users]") {
t.Errorf("Expected FROM clause, got: %s", query)
}
if !strings.Contains(query, "ORDER BY [ID] ASC") {
t.Errorf("Expected ORDER BY clause, got: %s", query)
}
}
func TestBuildExtractQueryMssql_WithJsonColumns_ExactColumnMatch(t *testing.T) {
// Test that the actual column name is used as alias, not a generated one
q := ExtractionQuery{
Schema: "dbo",
Table: "Events",
PrimaryKey: "EventID",
Columns: []models.ColumnType{
models.NewColumnType("EventID", false, false, "INT", "int", "INT", false, 0, 0, 0),
models.NewColumnType("EventData", true, false, "VARCHAR", "varchar", "VARCHAR", true, 500, 0, 0),
},
FromJsonColumns: []config.FromJsonItem{
{Column: "EventData", Field: ".userId"},
{Column: "EventData", Field: ".timestamp"},
},
LowerLimit: ExtractorQueryLimit{IsValid: false},
UpperLimit: ExtractorQueryLimit{IsValid: false},
}
query, err := buildExtractQueryMssql(q)
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
if !strings.HasPrefix(query, "SELECT [EventID], JSON_VALUE([EventData], '$.userId') AS [EventData], JSON_VALUE([EventData], '$.timestamp') AS [EventData]") {
t.Errorf("Expected JSON columns to replace EventData in-order, got: %s", query)
}
if strings.Contains(query, "SELECT [EventID], [EventData]") {
t.Errorf("Expected EventData to be replaced by JSON extraction, got: %s", query)
}
// Alias should be exactly "EventData", not "EventData_userId"
if !strings.Contains(query, "JSON_VALUE([EventData], '$.userId') AS [EventData]") {
t.Errorf("Expected JSON alias to be [EventData], got: %s", query)
}
if !strings.Contains(query, "JSON_VALUE([EventData], '$.timestamp') AS [EventData]") {
t.Errorf("Expected JSON alias to be [EventData], got: %s", query)
}
// Should have comma separating them
if !strings.Contains(query, "JSON_VALUE([EventData], '$.userId') AS [EventData], JSON_VALUE([EventData], '$.timestamp') AS [EventData]") {
t.Errorf("Expected comma-separated JSON values, got: %s", query)
}
}
func TestBuildExtractQueryMssql_WithWildcardPattern(t *testing.T) {
// Test that wildcard pattern matching finds the correct column
q := ExtractionQuery{
Schema: "dbo",
Table: "Events",
PrimaryKey: "ID",
Columns: []models.ColumnType{
models.NewColumnType("ID", false, false, "INT", "int", "INT", false, 0, 0, 0),
models.NewColumnType("NodeMetadata", true, false, "VARCHAR", "varchar", "VARCHAR", true, 500, 0, 0),
},
FromJsonColumns: []config.FromJsonItem{
{Column: "NodeMeta*", Field: ".id"},
},
LowerLimit: ExtractorQueryLimit{IsValid: false},
UpperLimit: ExtractorQueryLimit{IsValid: false},
}
query, err := buildExtractQueryMssql(q)
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
// Should find "NodeMetadata" from pattern "NodeMeta*" and use it as alias
if !strings.Contains(query, "JSON_VALUE([NodeMetadata], '$.id') AS [NodeMetadata]") {
t.Errorf("Expected to find and use NodeMetadata column by pattern, got: %s", query)
}
if strings.Contains(query, "SELECT [ID], [NodeMetadata]") {
t.Errorf("Expected NodeMetadata to be replaced by JSON extraction, got: %s", query)
}
}
func TestBuildExtractQueryMssql_ColumnNotFound_Error(t *testing.T) {
// Test that an error is returned when column is not found
q := ExtractionQuery{
Schema: "dbo",
Table: "Events",
PrimaryKey: "ID",
Columns: []models.ColumnType{
models.NewColumnType("ID", false, false, "INT", "int", "INT", false, 0, 0, 0),
},
FromJsonColumns: []config.FromJsonItem{
{Column: "NonExistentColumn", Field: ".id"},
},
LowerLimit: ExtractorQueryLimit{IsValid: false},
UpperLimit: ExtractorQueryLimit{IsValid: false},
}
query, err := buildExtractQueryMssql(q)
if err == nil {
t.Fatalf("Expected error for missing column, got no error. Query: %s", query)
}
if !strings.Contains(err.Error(), "NonExistentColumn") {
t.Errorf("Expected error message to contain column name, got: %v", err)
}
}
func TestBuildExtractQueryMssql_WildcardPatternNotMatched_Error(t *testing.T) {
// Test that an error is returned when wildcard pattern doesn't match any column
q := ExtractionQuery{
Schema: "dbo",
Table: "Events",
PrimaryKey: "ID",
Columns: []models.ColumnType{
models.NewColumnType("ID", false, false, "INT", "int", "INT", false, 0, 0, 0),
models.NewColumnType("EventData", true, false, "VARCHAR", "varchar", "VARCHAR", true, 500, 0, 0),
},
FromJsonColumns: []config.FromJsonItem{
{Column: "NonMatching*", Field: ".id"},
},
LowerLimit: ExtractorQueryLimit{IsValid: false},
UpperLimit: ExtractorQueryLimit{IsValid: false},
}
query, err := buildExtractQueryMssql(q)
if err == nil {
t.Fatalf("Expected error for non-matching wildcard pattern, got no error. Query: %s", query)
}
if !strings.Contains(err.Error(), "NonMatching*") {
t.Errorf("Expected error message to contain pattern, got: %v", err)
}
}
func TestBuildExtractQueryMssql_NestedJsonFields(t *testing.T) {
q := ExtractionQuery{
Schema: "dbo",
Table: "Data",
PrimaryKey: "ID",
Columns: []models.ColumnType{
models.NewColumnType("ID", false, false, "INT", "int", "INT", false, 0, 0, 0),
models.NewColumnType("NodeData", true, false, "VARCHAR", "varchar", "VARCHAR", true, 500, 0, 0),
},
FromJsonColumns: []config.FromJsonItem{
{Column: "NodeData", Field: ".user.name"},
{Column: "NodeData", Field: ".user.email"},
},
LowerLimit: ExtractorQueryLimit{IsValid: false},
UpperLimit: ExtractorQueryLimit{IsValid: false},
}
query, err := buildExtractQueryMssql(q)
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
if !strings.Contains(query, "JSON_VALUE([NodeData], '$.user.name') AS [NodeData]") {
t.Errorf("Expected nested JSON path for user.name, got: %s", query)
}
if !strings.Contains(query, "JSON_VALUE([NodeData], '$.user.email') AS [NodeData]") {
t.Errorf("Expected nested JSON path for user.email, got: %s", query)
}
if strings.Contains(query, "SELECT [ID], [NodeData]") {
t.Errorf("Expected NodeData to be replaced by JSON extraction, got: %s", query)
}
}
func TestBuildExtractQueryMssql_WithRangeLimits(t *testing.T) {
q := ExtractionQuery{
Schema: "dbo",
Table: "Products",
PrimaryKey: "ProductID",
Columns: []models.ColumnType{
models.NewColumnType("ProductID", false, false, "INT", "int", "INT", false, 0, 0, 0),
models.NewColumnType("Details", true, false, "VARCHAR", "varchar", "VARCHAR", true, 500, 0, 0),
},
FromJsonColumns: []config.FromJsonItem{
{Column: "Details", Field: ".price"},
},
LowerLimit: ExtractorQueryLimit{IsValid: true, IsInclusive: true, Value: 100},
UpperLimit: ExtractorQueryLimit{IsValid: true, IsInclusive: false, Value: 500},
}
query, err := buildExtractQueryMssql(q)
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
if !strings.Contains(query, "WHERE [ProductID] >= @min") {
t.Errorf("Expected WHERE clause with >=, got: %s", query)
}
if !strings.Contains(query, "[ProductID] < @max") {
t.Errorf("Expected upper limit with <, got: %s", query)
}
if !strings.Contains(query, "JSON_VALUE([Details], '$.price') AS [Details]") {
t.Errorf("Expected JSON_VALUE for Details, got: %s", query)
}
if strings.Contains(query, "SELECT [ProductID], [Details]") {
t.Errorf("Expected Details to be replaced by JSON extraction, got: %s", query)
}
}
func TestBuildJsonPathMssql(t *testing.T) {
tests := []struct {
input string
expected string
}{
{".id", "$.id"},
{"id", "$.id"},
{".user.name", "$.user.name"},
{"user.name", "$.user.name"},
{".location.coordinates.lat", "$.location.coordinates.lat"},
{"", "$."},
}
for _, tt := range tests {
result := buildJsonPathMssql(tt.input)
if result != tt.expected {
t.Errorf("buildJsonPathMssql(%q) = %q, want %q", tt.input, result, tt.expected)
}
}
}
func TestFindColumnByPattern_ExactMatch(t *testing.T) {
columns := []models.ColumnType{
models.NewColumnType("ID", false, false, "INT", "int", "INT", false, 0, 0, 0),
models.NewColumnType("Metadata", true, false, "VARCHAR", "varchar", "VARCHAR", true, 500, 0, 0),
models.NewColumnType("EventData", true, false, "VARCHAR", "varchar", "VARCHAR", true, 500, 0, 0),
}
result, err := findColumnByPattern(columns, "Metadata")
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
if result != "Metadata" {
t.Errorf("Expected 'Metadata', got '%s'", result)
}
}
func TestFindColumnByPattern_WildcardMatch(t *testing.T) {
columns := []models.ColumnType{
models.NewColumnType("ID", false, false, "INT", "int", "INT", false, 0, 0, 0),
models.NewColumnType("NodeMetadata", true, false, "VARCHAR", "varchar", "VARCHAR", true, 500, 0, 0),
models.NewColumnType("EventData", true, false, "VARCHAR", "varchar", "VARCHAR", true, 500, 0, 0),
}
result, err := findColumnByPattern(columns, "NodeMeta*")
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
if result != "NodeMetadata" {
t.Errorf("Expected 'NodeMetadata', got '%s'", result)
}
}
func TestFindColumnByPattern_NotFound(t *testing.T) {
columns := []models.ColumnType{
models.NewColumnType("ID", false, false, "INT", "int", "INT", false, 0, 0, 0),
models.NewColumnType("Metadata", true, false, "VARCHAR", "varchar", "VARCHAR", true, 500, 0, 0),
}
result, err := findColumnByPattern(columns, "NonExistent")
if err == nil {
t.Fatalf("Expected error, got no error. Result: %s", result)
}
if !strings.Contains(err.Error(), "NonExistent") {
t.Errorf("Expected error to contain column name, got: %v", err)
}
}
func TestFindColumnByPattern_WildcardNotFound(t *testing.T) {
columns := []models.ColumnType{
models.NewColumnType("ID", false, false, "INT", "int", "INT", false, 0, 0, 0),
models.NewColumnType("Metadata", true, false, "VARCHAR", "varchar", "VARCHAR", true, 500, 0, 0),
}
result, err := findColumnByPattern(columns, "Event*")
if err == nil {
t.Fatalf("Expected error, got no error. Result: %s", result)
}
if !strings.Contains(err.Error(), "Event*") {
t.Errorf("Expected error to contain pattern, got: %v", err)
}
}
func TestBuildExtractQueryMssql_OnlyJsonColumns(t *testing.T) {
// Test when all columns are used via JSON extraction
q := ExtractionQuery{
Schema: "dbo",
Table: "Data",
PrimaryKey: "ID",
Columns: []models.ColumnType{
models.NewColumnType("ID", false, false, "INT", "int", "INT", false, 0, 0, 0),
models.NewColumnType("JsonData", true, false, "VARCHAR", "varchar", "VARCHAR", true, 500, 0, 0),
},
FromJsonColumns: []config.FromJsonItem{
{Column: "JsonData", Field: ".field1"},
},
LowerLimit: ExtractorQueryLimit{IsValid: false},
UpperLimit: ExtractorQueryLimit{IsValid: false},
}
query, err := buildExtractQueryMssql(q)
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
if !strings.HasPrefix(query, "SELECT [ID], JSON_VALUE([JsonData], '$.field1') AS [JsonData]") {
t.Errorf("Expected JsonData to be replaced by JSON extraction, got: %s", query)
}
if strings.Contains(query, "SELECT [ID], [JsonData]") {
t.Errorf("Expected JsonData to be excluded from raw selection, got: %s", query)
}
}
func TestBuildExtractQueryMssql_JsonColumnsReplaceInOrder(t *testing.T) {
q := ExtractionQuery{
Schema: "dbo",
Table: "Users",
PrimaryKey: "UserID",
Columns: []models.ColumnType{
models.NewColumnType("UserID", false, false, "INT", "int", "INT", false, 0, 0, 0),
models.NewColumnType("Name", true, false, "VARCHAR", "varchar", "VARCHAR", false, 255, 0, 0),
models.NewColumnType("Email", true, false, "VARCHAR", "varchar", "VARCHAR", false, 255, 0, 0),
models.NewColumnType("Metadata", true, false, "NVARCHAR", "nvarchar", "NVARCHAR", true, 4000, 0, 0),
models.NewColumnType("Profile", true, false, "NVARCHAR", "nvarchar", "NVARCHAR", true, 4000, 0, 0),
models.NewColumnType("Settings", true, false, "NVARCHAR", "nvarchar", "NVARCHAR", true, 4000, 0, 0),
},
FromJsonColumns: []config.FromJsonItem{
{Column: "Metadata", Field: ".id"},
{Column: "Profile", Field: ".id"},
{Column: "Settings", Field: ".id"},
},
LowerLimit: ExtractorQueryLimit{IsValid: false},
UpperLimit: ExtractorQueryLimit{IsValid: false},
}
query, err := buildExtractQueryMssql(q)
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
expected := "SELECT [UserID], [Name], [Email], JSON_VALUE([Metadata], '$.id') AS [Metadata], JSON_VALUE([Profile], '$.id') AS [Profile], JSON_VALUE([Settings], '$.id') AS [Settings] FROM [dbo].[Users] ORDER BY [UserID] ASC"
if query != expected {
t.Errorf("Unexpected query.\nExpected: %s\nGot: %s", expected, query)
}
}

View File

@@ -1,203 +0,0 @@
package dbwrapper
import (
"context"
"errors"
"fmt"
"strings"
dbdialects "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper/db_dialects"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
func init() {
Register(dbdialects.Postgres, func() DbWrapper {
return &postgresDbWrapper{dialect: dbdialects.Postgres}
})
}
type postgresRowResult struct {
row pgx.Row
}
func (pr *postgresRowResult) Scan(dest ...any) error {
return pr.row.Scan(dest...)
}
type postgresRowsResult struct {
columns []string
rows pgx.Rows
}
func (pr *postgresRowsResult) Close() error {
pr.rows.Close()
return nil
}
func (pr *postgresRowsResult) Columns() ([]string, error) {
if pr.columns != nil {
return pr.columns, nil
}
rawColumns := pr.rows.FieldDescriptions()
if rawColumns == nil {
return nil, errors.New("error retrieving columns")
}
columns := make([]string, 0, len(rawColumns))
for _, rc := range rawColumns {
columns = append(columns, rc.Name)
}
return columns, nil
}
func (pr *postgresRowsResult) Err() error {
return pr.rows.Err()
}
func (pr *postgresRowsResult) Next() bool {
return pr.rows.Next()
}
func (pr *postgresRowsResult) Scan(dest ...any) error {
return pr.rows.Scan(dest...)
}
func (pr *postgresRowsResult) Values() ([]any, error) {
return pr.rows.Values()
}
type postgresDbWrapper struct {
db *pgxpool.Pool
dialect string
}
func (pw *postgresDbWrapper) Connect(ctx context.Context, dbUrl string) error {
pool, err := pgxpool.New(ctx, dbUrl)
if err != nil {
return err
}
if err := pool.Ping(ctx); err != nil {
pool.Close()
return err
}
pw.db = pool
return nil
}
func (pw *postgresDbWrapper) Close() error {
pw.db.Close()
return nil
}
func (pw *postgresDbWrapper) Exec(ctx context.Context, query string, args ...any) (ExecResult, error) {
result, err := pw.db.Exec(ctx, query, args...)
if err != nil {
return ExecResult{}, err
}
return ExecResult{AffectedRows: result.RowsAffected()}, nil
}
func (pw *postgresDbWrapper) GetDialect() string {
return pw.dialect
}
func (pw *postgresDbWrapper) Query(ctx context.Context, query string, args ...any) (RowsResult, error) {
rows, err := pw.db.Query(ctx, query, args...)
if err != nil {
return nil, err
}
return &postgresRowsResult{columns: nil, rows: rows}, nil
}
func (pw *postgresDbWrapper) QueryRow(ctx context.Context, query string, args ...any) RowResult {
row := pw.db.QueryRow(ctx, query, args...)
return &postgresRowResult{row: row}
}
func (pw *postgresDbWrapper) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) {
affectedRows, err := pw.db.CopyFrom(ctx, pgx.Identifier{schema, table}, columnNames, pgx.CopyFromRows(rows))
if err != nil {
return 0, err
}
return affectedRows, nil
}
func (pw *postgresDbWrapper) QueryFromObject(ctx context.Context, q ExtractionQuery) (RowsResult, error) {
var sbQuery strings.Builder
sbQuery.WriteString("SELECT ")
if len(q.Columns) == 0 {
sbQuery.WriteString("*")
} else {
for i, col := range q.Columns {
switch col.Type() {
case "GEOMETRY":
fmt.Fprintf(&sbQuery, `ST_AsEWKB("%s") AS "%s"`, col.Name(), col.Name())
default:
fmt.Fprintf(&sbQuery, `"%s"`, col.Name())
}
if i < len(q.Columns)-1 {
sbQuery.WriteString(", ")
}
}
}
fmt.Fprintf(&sbQuery, ` FROM "%s"."%s"`, q.Schema, q.Table)
if q.LowerLimit.IsValid || q.UpperLimit.IsValid {
sbQuery.WriteString(" WHERE ")
paramIdx := 1
if q.LowerLimit.IsValid {
fmt.Fprintf(&sbQuery, `"%s"`, q.PrimaryKey)
if q.LowerLimit.IsInclusive {
sbQuery.WriteString(" >=")
} else {
sbQuery.WriteString(" >")
}
fmt.Fprintf(&sbQuery, " $%d", paramIdx)
paramIdx++
}
if q.LowerLimit.IsValid && q.UpperLimit.IsValid {
sbQuery.WriteString(" AND ")
}
if q.UpperLimit.IsValid {
fmt.Fprintf(&sbQuery, `"%s"`, q.PrimaryKey)
if q.UpperLimit.IsInclusive {
sbQuery.WriteString(" <=")
} else {
sbQuery.WriteString(" <")
}
fmt.Fprintf(&sbQuery, " $%d", paramIdx)
paramIdx++
}
}
fmt.Fprintf(&sbQuery, ` ORDER BY "%s" ASC`, q.PrimaryKey)
queryString := sbQuery.String()
var queryArgs []any
if q.LowerLimit.IsValid {
queryArgs = append(queryArgs, q.LowerLimit.Value)
}
if q.UpperLimit.IsValid {
queryArgs = append(queryArgs, q.UpperLimit.Value)
}
return pw.Query(ctx, queryString, queryArgs...)
}

View File

@@ -1,55 +0,0 @@
package dbwrapper
import (
"context"
"errors"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
)
var MethodNotSupported error = errors.New("Method not supported by driver... yet :P")
type ExecResult struct {
AffectedRows int64
}
type RowsResult interface {
Close() error
Columns() ([]string, error)
Err() error
Next() bool
Scan(dest ...any) error
Values() ([]any, error)
}
type RowResult interface {
Scan(dest ...any) error
}
type ExtractorQueryLimit struct {
IsValid bool
IsInclusive bool
Value int64
}
type ExtractionQuery struct {
Schema string
Table string
PrimaryKey string
Columns []models.ColumnType
LowerLimit ExtractorQueryLimit
UpperLimit ExtractorQueryLimit
FromJsonColumns []config.FromJsonItem
}
type DbWrapper interface {
Close() error
Connect(ctx context.Context, dbUrl string) error
Exec(ctx context.Context, query string, args ...any) (ExecResult, error)
GetDialect() string
Query(ctx context.Context, query string, args ...any) (RowsResult, error)
QueryRow(ctx context.Context, query string, args ...any) RowResult
SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error)
QueryFromObject(ctx context.Context, query ExtractionQuery) (RowsResult, error)
}

View File

@@ -0,0 +1,28 @@
package db
import (
"context"
"fmt"
"github.com/jackc/pgx/v5/pgxpool"
)
func Connect(ctx context.Context, dbURL string) (*pgxpool.Pool, error) {
pool, err := pgxpool.New(ctx, dbURL)
if err != nil {
return nil, fmt.Errorf("unable to connect to database: %w", err)
}
if err := pool.Ping(ctx); err != nil {
pool.Close()
return nil, fmt.Errorf("unable to ping database: %w", err)
}
return pool, nil
}
func Close(pool *pgxpool.Pool) {
if pool != nil {
pool.Close()
}
}

View File

@@ -1,108 +0,0 @@
package extractors
import (
"context"
"errors"
"slices"
"strings"
"sync"
"sync/atomic"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
"github.com/sirupsen/logrus"
)
func (ex *GenericExtractor) Consume(
ctx context.Context,
tableInfo config.SourceTableInfo,
columns []models.ColumnType,
batchSize int,
retryConfig config.RetryConfig,
chPartitionsIn <-chan models.Partition,
chBatchesOut chan<- models.Batch,
chErrorsOut chan<- custom_errors.JobError,
wgActivePartitions *sync.WaitGroup,
rowsRead *int64,
failedPartitionsCount *int32,
fromJsonColumns []config.FromJsonItem,
) {
indexPrimaryKey := slices.IndexFunc(columns, func(col models.ColumnType) bool {
return strings.EqualFold(col.Name(), tableInfo.PrimaryKey)
})
if indexPrimaryKey == -1 {
select {
case <-ctx.Done():
return
case chErrorsOut <- custom_errors.JobError{
ShouldCancelJob: true,
Msg: "Primary key not found in provided columns",
}:
}
return
}
for {
if ctx.Err() != nil {
return
}
select {
case <-ctx.Done():
return
case partition, ok := <-chPartitionsIn:
if !ok {
return
}
rowsReadResult, err := ex.ProcessPartitionWithRetries(
ctx,
tableInfo,
columns,
batchSize,
partition,
indexPrimaryKey,
retryConfig,
chBatchesOut,
fromJsonColumns,
)
wgActivePartitions.Done()
if rowsReadResult > 0 {
current := atomic.LoadInt64(rowsRead)
logrus.Debugf("Rows read (partition extracted): +%v [current=%v] (%s.%s)", rowsReadResult, current, tableInfo.Schema, tableInfo.Table)
atomic.AddInt64(rowsRead, int64(rowsReadResult))
}
if err != nil {
atomic.AddInt32(failedPartitionsCount, 1)
if jobError, ok := errors.AsType[*custom_errors.JobError](err); ok {
select {
case <-ctx.Done():
return
case chErrorsOut <- *jobError:
}
} else {
select {
case <-ctx.Done():
return
case chErrorsOut <- custom_errors.JobError{ShouldCancelJob: false, Msg: err.Error(), Prev: err}:
}
}
currentFPCount := atomic.LoadInt32(failedPartitionsCount)
if currentFPCount > int32(retryConfig.MaxFailedPartitions) {
select {
case <-ctx.Done():
return
case chErrorsOut <- custom_errors.JobError{ShouldCancelJob: true, Msg: "Max failed partitions reached"}:
return
}
}
}
}
}
}

View File

@@ -1,41 +0,0 @@
package extractors
import (
"context"
dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
"github.com/google/uuid"
)
type GenericExtractor struct {
db dbwrapper.DbWrapper
}
func NewExtractor(db dbwrapper.DbWrapper) GenericExtractor {
return GenericExtractor{db: db}
}
func sendBatch(ctx context.Context, chBatchesOut chan<- models.Batch, batch models.Batch) error {
select {
case chBatchesOut <- batch:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func flush(
ctx context.Context,
batchSize int,
batchRows []models.UnknownRowValues,
chBatchesOut chan<- models.Batch,
) error {
if len(batchRows) == 0 {
return nil
}
batch := models.Batch{Id: uuid.New(), Rows: batchRows}
batchRows = make([]models.UnknownRowValues, 0, batchSize)
return sendBatch(ctx, chBatchesOut, batch)
}

View File

@@ -1,77 +0,0 @@
package extractors
import (
"context"
"errors"
"fmt"
"time"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
"github.com/google/uuid"
// "github.com/sirupsen/logrus"
)
func (ex *GenericExtractor) ProcessPartitionWithRetries(
ctx context.Context,
tableInfo config.SourceTableInfo,
columns []models.ColumnType,
batchSize int,
partition models.Partition,
indexPrimaryKey int,
retryConfig config.RetryConfig,
chBatchesOut chan<- models.Batch,
fromJsonColumns []config.FromJsonItem,
) (int64, error) {
var totalRowsRead int64
currentParitition := partition
for {
rowsRead, err := ex.ProcessPartition(
ctx,
tableInfo,
columns,
batchSize,
currentParitition,
indexPrimaryKey,
chBatchesOut,
fromJsonColumns,
)
// logrus.Debugf("Partition %v finished processing (%s.%s)", partition.Id, tableInfo.Schema, tableInfo.Table)
totalRowsRead += rowsRead
if err == nil {
return totalRowsRead, nil
}
if exError, ok := errors.AsType[*custom_errors.ExtractorError](err); ok {
currentParitition.RetryCounter++
if currentParitition.RetryCounter >= retryConfig.Attempts {
return totalRowsRead, &custom_errors.JobError{
Msg: fmt.Sprintf("Partition %v reached max retries (%d)", currentParitition.Id, currentParitition.RetryCounter),
Prev: err,
}
}
if exError.HasLastId {
currentParitition.ParentId = exError.Partition.Id
currentParitition.Id = uuid.New()
currentParitition.Range.Min = exError.LastId
currentParitition.Range.IsMinInclusive = false
}
delay := custom_errors.ComputeBackoffDelay(
currentParitition.RetryCounter,
retryConfig.BaseDelayMs,
retryConfig.MaxDelayMs,
retryConfig.MaxJitterMs,
)
time.Sleep(delay)
continue
}
return totalRowsRead, err
}
}

View File

@@ -1,127 +0,0 @@
package extractors
import (
"context"
"fmt"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/convert"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
// "github.com/sirupsen/logrus"
)
func errorFromLastPartitionRow(
lastRow models.UnknownRowValues,
indexPrimaryKey int,
partition models.Partition,
previousError error,
) error {
lastIdRawValue := lastRow[indexPrimaryKey]
lastId, ok := convert.ToInt64(lastIdRawValue)
if !ok {
currentPartition := partition
currentPartition.RetryCounter = 3
return &custom_errors.ExtractorError{
Partition: currentPartition,
HasLastId: true,
Msg: fmt.Sprintf("Couldn't cast last id value as int: %s", previousError.Error()),
}
}
return &custom_errors.ExtractorError{
Partition: partition,
HasLastId: true,
LastId: lastId,
Msg: previousError.Error(),
}
}
func (ex *GenericExtractor) ProcessPartition(
ctx context.Context,
tableInfo config.SourceTableInfo,
columns []models.ColumnType,
batchSize int,
partition models.Partition,
indexPrimaryKey int,
chBatchesOut chan<- models.Batch,
fromJsonColumns []config.FromJsonItem,
) (int64, error) {
query := dbwrapper.ExtractionQuery{
Schema: tableInfo.Schema,
Table: tableInfo.Table,
PrimaryKey: tableInfo.PrimaryKey,
Columns: columns,
LowerLimit: dbwrapper.ExtractorQueryLimit{
IsValid: partition.HasRange && partition.Range.Min > 0,
IsInclusive: partition.Range.IsMinInclusive,
Value: partition.Range.Min,
},
UpperLimit: dbwrapper.ExtractorQueryLimit{
IsValid: partition.HasRange && partition.Range.Max > 0,
IsInclusive: partition.Range.IsMaxInclusive,
Value: partition.Range.Max,
},
FromJsonColumns: fromJsonColumns,
}
// logrus.Debugf("Processing partition: %+v (%s.%s)", query, tableInfo.Schema, tableInfo.Table)
rows, err := ex.db.QueryFromObject(ctx, query)
if err != nil {
return 0, err
}
defer rows.Close()
batchRows := make([]models.UnknownRowValues, 0, batchSize)
var rowsRead int64 = 0
var lastRow models.UnknownRowValues
for rows.Next() {
rowValues := make([]any, len(columns))
scanArgs := make([]any, len(columns))
for i := range rowValues {
scanArgs[i] = &rowValues[i]
}
if err := rows.Scan(scanArgs...); err != nil {
if len(batchRows) == 0 {
return rowsRead, err
}
if err := flush(ctx, batchSize, batchRows, chBatchesOut); err != nil {
return rowsRead, err
}
lastRow := batchRows[len(batchRows)-1]
return rowsRead, errorFromLastPartitionRow(lastRow, indexPrimaryKey, partition, err)
}
rowsRead++
lastRow = rowValues
batchRows = append(batchRows, rowValues)
if len(batchRows) >= batchSize {
// logrus.Debugf("Batch size reached, flushing batch with %v rows (rowsRead=%v)", len(batchRows), rowsRead)
if err := flush(ctx, batchSize, batchRows, chBatchesOut); err != nil {
// logrus.Warnf("Error flushing rows: %v", err)
return rowsRead, err
}
batchRows = make([]models.UnknownRowValues, 0, batchSize)
}
}
if err := flush(ctx, batchSize, batchRows, chBatchesOut); err != nil {
return rowsRead, err
}
if err := rows.Err(); err != nil {
if lastRow != nil {
return rowsRead, errorFromLastPartitionRow(lastRow, indexPrimaryKey, partition, err)
}
return rowsRead, err
}
return rowsRead, nil
}

View File

@@ -1,153 +0,0 @@
package loaders
import (
"context"
"errors"
"sync"
"sync/atomic"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
)
type loaderAccumulator struct {
batchSize int
rows []models.UnknownRowValues
parents []models.BatchRef
pendingDone int
}
func (a *loaderAccumulator) add(batch models.Batch) {
a.rows = append(a.rows, batch.Rows...)
a.parents = append(a.parents, models.BatchRef{Id: batch.Id})
a.pendingDone++
}
func (a *loaderAccumulator) ready() bool {
return len(a.rows) >= a.batchSize
}
func (a *loaderAccumulator) drainPending(wg *sync.WaitGroup) {
for range a.pendingDone {
wg.Done()
}
}
func sendLoadError(
ctx context.Context,
err error,
retryConfig config.RetryConfig,
failedBatchesCount *int32,
chErrorsOut chan<- custom_errors.JobError,
) bool {
atomic.AddInt32(failedBatchesCount, 1)
var jobErr custom_errors.JobError
if je, ok := errors.AsType[*custom_errors.JobError](err); ok {
jobErr = *je
} else {
jobErr = custom_errors.JobError{ShouldCancelJob: false, Msg: err.Error(), Prev: err}
}
select {
case <-ctx.Done():
return false
case chErrorsOut <- jobErr:
}
if atomic.LoadInt32(failedBatchesCount) > int32(retryConfig.MaxFailedBatchesLoad) {
select {
case <-ctx.Done():
case chErrorsOut <- custom_errors.JobError{ShouldCancelJob: true, Msg: "Max failed batches (load) reached"}:
}
return false
}
return true
}
func (gl *GenericLoader) Consume(
ctx context.Context,
tableInfo config.TargetTableInfo,
columns []models.ColumnType,
retryConfig config.RetryConfig,
batchSize int,
chBatchesIn <-chan models.Batch,
chErrorsOut chan<- custom_errors.JobError,
wgActiveBatches *sync.WaitGroup,
rowsLoaded *int64,
failedBatchesCount *int32,
) {
colNames := mapSlice(columns, func(col models.ColumnType) string {
return col.Name()
})
acc := &loaderAccumulator{batchSize: batchSize}
defer acc.drainPending(wgActiveBatches)
flush := func() bool {
if len(acc.rows) == 0 {
return true
}
count := len(acc.parents)
superBatch := models.Batch{
Id: uuid.New(),
ParentBatches: acc.parents,
Rows: acc.rows,
}
processedRows, err := gl.ProcessBatchWithRetries(ctx, tableInfo, colNames, retryConfig, superBatch)
for range count {
wgActiveBatches.Done()
}
acc.pendingDone -= count
acc.rows = nil
acc.parents = nil
if err != nil {
return sendLoadError(ctx, err, retryConfig, failedBatchesCount, chErrorsOut)
}
current := atomic.LoadInt64(rowsLoaded)
logrus.Debugf("Rows loaded (batch loaded): +%v [current=%v] (%s.%s)", processedRows, current, tableInfo.Schema, tableInfo.Table)
atomic.AddInt64(rowsLoaded, int64(processedRows))
return true
}
for {
select {
case <-ctx.Done():
return
case batch, ok := <-chBatchesIn:
if !ok {
flush()
return
}
if batchSize <= 0 {
processedRows, err := gl.ProcessBatchWithRetries(ctx, tableInfo, colNames, retryConfig, batch)
wgActiveBatches.Done()
if err != nil {
if !sendLoadError(ctx, err, retryConfig, failedBatchesCount, chErrorsOut) {
return
}
continue
}
current := atomic.LoadInt64(rowsLoaded)
logrus.Debugf("Rows loaded: +%v [current=%v] (%s.%s)", processedRows, current, tableInfo.Schema, tableInfo.Table)
atomic.AddInt64(rowsLoaded, int64(processedRows))
continue
}
acc.add(batch)
if acc.ready() {
if !flush() {
return
}
}
}
}
}

View File

@@ -1,603 +0,0 @@
package loaders
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
"github.com/google/uuid"
)
const testTimeout = 2 * time.Second
type mockResult struct {
err error
}
type mockDbWrapper struct {
mu sync.Mutex
callCount int
results []mockResult
}
func newMockDb(results ...mockResult) *mockDbWrapper {
return &mockDbWrapper{results: results}
}
func (m *mockDbWrapper) SaveMassive(_ context.Context, _ string, _ string, _ []string, rows [][]any) (int64, error) {
m.mu.Lock()
defer m.mu.Unlock()
idx := m.callCount
m.callCount++
if idx < len(m.results) && m.results[idx].err != nil {
return 0, m.results[idx].err
}
return int64(len(rows)), nil
}
func (m *mockDbWrapper) Close() error { return nil }
func (m *mockDbWrapper) Connect(_ context.Context, _ string) error { return nil }
func (m *mockDbWrapper) Exec(_ context.Context, _ string, _ ...any) (dbwrapper.ExecResult, error) {
return dbwrapper.ExecResult{}, nil
}
func (m *mockDbWrapper) GetDialect() string { return "" }
func (m *mockDbWrapper) Query(_ context.Context, _ string, _ ...any) (dbwrapper.RowsResult, error) {
return nil, nil
}
func (m *mockDbWrapper) QueryRow(_ context.Context, _ string, _ ...any) dbwrapper.RowResult {
return nil
}
func (m *mockDbWrapper) QueryFromObject(_ context.Context, _ dbwrapper.ExtractionQuery) (dbwrapper.RowsResult, error) {
return nil, nil
}
func makeBatch(numRows int) models.Batch {
rows := make([]models.UnknownRowValues, numRows)
for i := range rows {
rows[i] = models.UnknownRowValues{i}
}
return models.Batch{Id: uuid.New(), Rows: rows}
}
func newLoader(db *mockDbWrapper) GenericLoader {
return GenericLoader{db: db}
}
func rc(maxFailed int) config.RetryConfig {
return config.RetryConfig{Attempts: 1, MaxFailedBatchesLoad: maxFailed}
}
func sendBatch(chIn chan<- models.Batch, batch models.Batch, wg *sync.WaitGroup) {
wg.Add(1)
chIn <- batch
}
func runConsume(
ctx context.Context,
gl GenericLoader,
retryConfig config.RetryConfig,
batchSize int,
chIn <-chan models.Batch,
chErr chan<- custom_errors.JobError,
wg *sync.WaitGroup,
rowsLoaded *int64,
failedCount *int32,
) <-chan struct{} {
done := make(chan struct{})
go func() {
gl.Consume(ctx, config.TargetTableInfo{}, nil, retryConfig, batchSize,
chIn, chErr, wg, rowsLoaded, failedCount)
close(done)
}()
return done
}
func waitWg(wg *sync.WaitGroup) <-chan struct{} {
done := make(chan struct{})
go func() { wg.Wait(); close(done) }()
return done
}
func dbError() error { return errors.New("connection reset by peer") }
func TestLoaderAccumulator_Add(t *testing.T) {
acc := &loaderAccumulator{batchSize: 5}
b1 := makeBatch(2)
b2 := makeBatch(3)
acc.add(b1)
acc.add(b2)
if len(acc.rows) != 5 {
t.Errorf("expected 5 rows, got %d", len(acc.rows))
}
if len(acc.parents) != 2 {
t.Fatalf("expected 2 parents, got %d", len(acc.parents))
}
if acc.parents[0].Id != b1.Id || acc.parents[1].Id != b2.Id {
t.Error("parent IDs do not match source batch IDs in order")
}
if acc.pendingDone != 2 {
t.Errorf("expected pendingDone=2, got %d", acc.pendingDone)
}
}
func TestLoaderAccumulator_Ready(t *testing.T) {
acc := &loaderAccumulator{batchSize: 3}
acc.add(makeBatch(2))
if acc.ready() {
t.Error("should not be ready with 2 rows and batchSize=3")
}
acc.add(makeBatch(1))
if !acc.ready() {
t.Error("should be ready with 3 rows and batchSize=3")
}
}
func TestLoaderAccumulator_DrainPending_ReleasesWg(t *testing.T) {
acc := &loaderAccumulator{batchSize: 5, pendingDone: 3}
var wg sync.WaitGroup
wg.Add(3)
acc.drainPending(&wg)
select {
case <-waitWg(&wg):
case <-time.After(testTimeout):
t.Fatal("wg.Wait() timed out: drainPending did not call Done() enough times")
}
}
func TestLoaderAccumulator_DrainPending_ZeroPending(t *testing.T) {
acc := &loaderAccumulator{batchSize: 5, pendingDone: 0}
var wg sync.WaitGroup
acc.drainPending(&wg)
select {
case <-waitWg(&wg):
case <-time.After(testTimeout):
t.Fatal("wg.Wait() timed out")
}
}
func TestSendLoadError_PlainError_WrappedAsNonFatal(t *testing.T) {
ch := make(chan custom_errors.JobError, 2)
var failedCount int32
result := sendLoadError(context.Background(), errors.New("db error"), rc(10), &failedCount, ch)
if !result {
t.Error("expected true (below threshold)")
}
if atomic.LoadInt32(&failedCount) != 1 {
t.Errorf("expected failedCount=1, got %d", failedCount)
}
select {
case e := <-ch:
if e.ShouldCancelJob {
t.Error("plain error should be wrapped as ShouldCancelJob=false")
}
default:
t.Error("expected an error in the channel")
}
}
func TestSendLoadError_JobError_PassesThrough(t *testing.T) {
ch := make(chan custom_errors.JobError, 2)
var failedCount int32
original := &custom_errors.JobError{ShouldCancelJob: false, Msg: "custom msg"}
sendLoadError(context.Background(), original, rc(10), &failedCount, ch)
select {
case e := <-ch:
if e.Msg != "custom msg" || e.ShouldCancelJob {
t.Errorf("JobError should pass through unchanged, got %+v", e)
}
default:
t.Error("expected an error in the channel")
}
}
func TestSendLoadError_FatalJobError_BelowThreshold_ReturnsTrue(t *testing.T) {
ch := make(chan custom_errors.JobError, 2)
var failedCount int32
fatal := &custom_errors.JobError{ShouldCancelJob: true, Msg: "unique constraint"}
result := sendLoadError(context.Background(), fatal, rc(10), &failedCount, ch)
if !result {
t.Error("below-threshold fatal error should return true (external cancel expected from JobErrorHandler)")
}
select {
case e := <-ch:
if !e.ShouldCancelJob {
t.Error("fatal JobError should be forwarded with ShouldCancelJob=true")
}
default:
t.Error("expected the fatal error in the channel")
}
}
func TestSendLoadError_ThresholdExceeded_ReturnsFalse(t *testing.T) {
ch := make(chan custom_errors.JobError, 2)
var failedCount int32
result := sendLoadError(context.Background(), errors.New("db error"), rc(0), &failedCount, ch)
if result {
t.Error("expected false when threshold exceeded")
}
if len(ch) != 2 {
t.Fatalf("expected 2 errors (batch error + fatal threshold error), got %d", len(ch))
}
<-ch // batch error
threshold := <-ch
if !threshold.ShouldCancelJob {
t.Error("second error should be the fatal threshold error (ShouldCancelJob=true)")
}
}
func TestSendLoadError_AtThresholdBoundary(t *testing.T) {
ch := make(chan custom_errors.JobError, 6)
var failedCount int32
if !sendLoadError(context.Background(), errors.New("err"), rc(2), &failedCount, ch) {
t.Error("first failure: expected true (below threshold)")
}
if !sendLoadError(context.Background(), errors.New("err"), rc(2), &failedCount, ch) {
t.Error("second failure: expected true (at threshold, not exceeded)")
}
if sendLoadError(context.Background(), errors.New("err"), rc(2), &failedCount, ch) {
t.Error("third failure: expected false (threshold exceeded)")
}
}
func TestSendLoadError_ContextCancelled_ReturnsFalse(t *testing.T) {
ch := make(chan custom_errors.JobError)
var failedCount int32
ctx, cancel := context.WithCancel(context.Background())
cancel()
result := sendLoadError(ctx, errors.New("db error"), rc(10), &failedCount, ch)
if result {
t.Error("expected false when context is cancelled")
}
if len(ch) != 0 {
t.Error("no error should be sent when context is cancelled")
}
}
func TestConsume_Passthrough_RowsLoaded(t *testing.T) {
db := newMockDb()
gl := newLoader(db)
chIn := make(chan models.Batch, 1)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
var rowsLoaded int64
var failedCount int32
sendBatch(chIn, makeBatch(5), &wg)
close(chIn)
<-runConsume(context.Background(), gl, rc(0), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
wg.Wait()
if rowsLoaded != 5 {
t.Errorf("expected rowsLoaded=5, got %d", rowsLoaded)
}
if db.callCount != 1 {
t.Errorf("expected 1 SaveMassive call, got %d", db.callCount)
}
}
func TestConsume_Passthrough_MultipleBatches_RowsAccumulate(t *testing.T) {
db := newMockDb()
gl := newLoader(db)
chIn := make(chan models.Batch, 3)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
var rowsLoaded int64
var failedCount int32
sendBatch(chIn, makeBatch(3), &wg)
sendBatch(chIn, makeBatch(2), &wg)
sendBatch(chIn, makeBatch(4), &wg)
close(chIn)
<-runConsume(context.Background(), gl, rc(10), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
wg.Wait()
if rowsLoaded != 9 {
t.Errorf("expected rowsLoaded=9, got %d", rowsLoaded)
}
}
func TestConsume_Passthrough_WgDoneBeforeErrorHandling(t *testing.T) {
db := newMockDb(mockResult{err: dbError()})
gl := newLoader(db)
chIn := make(chan models.Batch, 1)
chErr := make(chan custom_errors.JobError, 2)
var wg sync.WaitGroup
var rowsLoaded int64
var failedCount int32
sendBatch(chIn, makeBatch(2), &wg)
close(chIn)
<-runConsume(context.Background(), gl, rc(10), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
select {
case <-waitWg(&wg):
case <-time.After(testTimeout):
t.Fatal("wg.Wait() timed out: Done() was not called even though processing failed")
}
}
func TestConsume_Passthrough_NonFatalError_Continues(t *testing.T) {
db := newMockDb(mockResult{err: dbError()})
gl := newLoader(db)
chIn := make(chan models.Batch, 2)
chErr := make(chan custom_errors.JobError, 3)
var wg sync.WaitGroup
var rowsLoaded int64
var failedCount int32
sendBatch(chIn, makeBatch(2), &wg)
sendBatch(chIn, makeBatch(3), &wg)
close(chIn)
<-runConsume(context.Background(), gl, rc(10), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
wg.Wait()
if rowsLoaded != 3 {
t.Errorf("expected rowsLoaded=3 (only second batch succeeded), got %d", rowsLoaded)
}
if atomic.LoadInt32(&failedCount) != 1 {
t.Errorf("expected failedCount=1, got %d", failedCount)
}
if len(chErr) == 0 {
t.Error("expected at least one error in chErr for the failed batch")
}
}
func TestConsume_Passthrough_ThresholdExceeded_Exits(t *testing.T) {
db := newMockDb(mockResult{err: dbError()})
gl := newLoader(db)
chIn := make(chan models.Batch, 1)
chErr := make(chan custom_errors.JobError, 3)
var wg sync.WaitGroup
var rowsLoaded int64
var failedCount int32
sendBatch(chIn, makeBatch(1), &wg)
done := runConsume(context.Background(), gl, rc(0), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
select {
case <-done:
case <-time.After(testTimeout):
t.Fatal("Consume did not exit after threshold exceeded")
}
select {
case <-waitWg(&wg):
case <-time.After(testTimeout):
t.Fatal("wg.Wait() timed out after threshold exit")
}
}
func TestConsume_Accumulation_FlushOnThreshold(t *testing.T) {
db := newMockDb()
gl := newLoader(db)
chIn := make(chan models.Batch, 3)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
var rowsLoaded int64
var failedCount int32
sendBatch(chIn, makeBatch(1), &wg)
sendBatch(chIn, makeBatch(1), &wg)
sendBatch(chIn, makeBatch(1), &wg)
close(chIn)
<-runConsume(context.Background(), gl, rc(0), 3, chIn, chErr, &wg, &rowsLoaded, &failedCount)
wg.Wait()
if rowsLoaded != 3 {
t.Errorf("expected rowsLoaded=3, got %d", rowsLoaded)
}
if db.callCount != 1 {
t.Errorf("expected 1 SaveMassive call, got %d", db.callCount)
}
}
func TestConsume_Accumulation_FlushOnClose(t *testing.T) {
db := newMockDb()
gl := newLoader(db)
chIn := make(chan models.Batch, 2)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
var rowsLoaded int64
var failedCount int32
sendBatch(chIn, makeBatch(2), &wg)
sendBatch(chIn, makeBatch(3), &wg)
close(chIn)
<-runConsume(context.Background(), gl, rc(0), 10, chIn, chErr, &wg, &rowsLoaded, &failedCount)
wg.Wait()
if rowsLoaded != 5 {
t.Errorf("expected rowsLoaded=5, got %d", rowsLoaded)
}
if db.callCount != 1 {
t.Errorf("expected exactly 1 SaveMassive call (single flush on close), got %d", db.callCount)
}
}
func TestConsume_Accumulation_RowsLoadedCorrect(t *testing.T) {
db := newMockDb()
gl := newLoader(db)
chIn := make(chan models.Batch, 5)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
var rowsLoaded int64
var failedCount int32
for range 5 {
sendBatch(chIn, makeBatch(2), &wg)
}
close(chIn)
<-runConsume(context.Background(), gl, rc(0), 4, chIn, chErr, &wg, &rowsLoaded, &failedCount)
wg.Wait()
if rowsLoaded != 10 {
t.Errorf("expected rowsLoaded=10, got %d", rowsLoaded)
}
if db.callCount != 3 {
t.Errorf("expected 3 SaveMassive calls (2 threshold flushes + 1 on close), got %d", db.callCount)
}
}
func TestConsume_Accumulation_WgBalanced_OnContextCancel(t *testing.T) {
db := newMockDb()
gl := newLoader(db)
chIn := make(chan models.Batch)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
var rowsLoaded int64
var failedCount int32
ctx, cancel := context.WithCancel(context.Background())
done := runConsume(ctx, gl, rc(0), 10, chIn, chErr, &wg, &rowsLoaded, &failedCount)
sendBatch(chIn, makeBatch(1), &wg)
sendBatch(chIn, makeBatch(1), &wg)
cancel()
select {
case <-done:
case <-time.After(testTimeout):
t.Fatal("Consume did not exit after context cancellation")
}
select {
case <-waitWg(&wg):
case <-time.After(testTimeout):
t.Fatal("wg.Wait() timed out: drainPending did not release accumulated batches on cancel")
}
}
func TestConsume_Accumulation_ErrorInFlush_WgStillBalanced(t *testing.T) {
db := newMockDb(mockResult{err: dbError()})
gl := newLoader(db)
chIn := make(chan models.Batch, 2)
chErr := make(chan custom_errors.JobError, 3)
var wg sync.WaitGroup
var rowsLoaded int64
var failedCount int32
sendBatch(chIn, makeBatch(1), &wg)
sendBatch(chIn, makeBatch(1), &wg)
close(chIn)
<-runConsume(context.Background(), gl, rc(10), 2, chIn, chErr, &wg, &rowsLoaded, &failedCount)
select {
case <-waitWg(&wg):
case <-time.After(testTimeout):
t.Fatal("wg.Wait() timed out: wg.Done() not called after flush error")
}
}
func TestConsume_Accumulation_MultipleFlushes_NonFatalErrors(t *testing.T) {
db := newMockDb(mockResult{err: dbError()}, mockResult{err: dbError()})
gl := newLoader(db)
chIn := make(chan models.Batch, 4)
chErr := make(chan custom_errors.JobError, 6)
var wg sync.WaitGroup
var rowsLoaded int64
var failedCount int32
for range 4 {
sendBatch(chIn, makeBatch(1), &wg)
}
close(chIn)
<-runConsume(context.Background(), gl, rc(10), 2, chIn, chErr, &wg, &rowsLoaded, &failedCount)
select {
case <-waitWg(&wg):
case <-time.After(testTimeout):
t.Fatal("wg.Wait() timed out")
}
if atomic.LoadInt32(&failedCount) != 2 {
t.Errorf("expected failedCount=2, got %d", failedCount)
}
if rowsLoaded != 0 {
t.Errorf("expected rowsLoaded=0 (all batches failed), got %d", rowsLoaded)
}
}
func TestConsume_EmptyInput_NoProcessing(t *testing.T) {
db := newMockDb()
gl := newLoader(db)
chIn := make(chan models.Batch)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
var rowsLoaded int64
var failedCount int32
close(chIn)
done := runConsume(context.Background(), gl, rc(0), 5, chIn, chErr, &wg, &rowsLoaded, &failedCount)
select {
case <-done:
case <-time.After(testTimeout):
t.Fatal("Consume did not exit after empty input channel was closed")
}
if db.callCount != 0 {
t.Errorf("expected no SaveMassive calls, got %d", db.callCount)
}
if rowsLoaded != 0 {
t.Errorf("expected rowsLoaded=0, got %d", rowsLoaded)
}
wg.Wait()
}
func TestConsume_ContextCancellation_Exits(t *testing.T) {
db := newMockDb()
gl := newLoader(db)
chIn := make(chan models.Batch)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
var rowsLoaded int64
var failedCount int32
ctx, cancel := context.WithCancel(context.Background())
done := runConsume(ctx, gl, rc(0), 0, chIn, chErr, &wg, &rowsLoaded, &failedCount)
cancel()
select {
case <-done:
case <-time.After(testTimeout):
t.Fatal("Consume did not exit after context cancellation")
}
wg.Wait()
}

View File

@@ -1,13 +0,0 @@
package loaders
import (
dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper"
)
type GenericLoader struct {
db dbwrapper.DbWrapper
}
func NewGenericLoader(db dbwrapper.DbWrapper) GenericLoader {
return GenericLoader{db: db}
}

View File

@@ -1,49 +0,0 @@
package loaders
import (
"context"
"errors"
"fmt"
"time"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
)
func (gl *GenericLoader) ProcessBatchWithRetries(
ctx context.Context,
tableInfo config.TargetTableInfo,
colNames []string,
retryConfig config.RetryConfig,
batch models.Batch,
) (int64, error) {
for {
rowsLoaded, err := gl.ProcessBatch(ctx, tableInfo, colNames, batch)
if err == nil {
return rowsLoaded, nil
}
if btError, ok := errors.AsType[*custom_errors.LoaderError](err); ok {
batch.RetryCounter++
if batch.RetryCounter >= retryConfig.Attempts {
return rowsLoaded, &custom_errors.JobError{
Msg: fmt.Sprintf("Batch %v reached max retries (%d)", batch.Id, batch.RetryCounter),
Prev: btError,
}
}
delay := custom_errors.ComputeBackoffDelay(
batch.RetryCounter,
retryConfig.BaseDelayMs,
retryConfig.MaxDelayMs,
retryConfig.MaxJitterMs,
)
time.Sleep(delay)
continue
}
return rowsLoaded, err
}
}

View File

@@ -1,43 +0,0 @@
package loaders
import (
"context"
"errors"
"fmt"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
"github.com/jackc/pgx/v5/pgconn"
)
func (gl *GenericLoader) ProcessBatch(
ctx context.Context,
tableInfo config.TargetTableInfo,
colNames []string,
batch models.Batch,
) (int64, error) {
_, err := gl.db.SaveMassive(
ctx,
tableInfo.Schema,
tableInfo.Table,
colNames,
batch.Rows,
)
if err != nil {
if pgErr, ok := errors.AsType[*pgconn.PgError](err); ok {
if pgErr.Code == "23505" {
return 0, &custom_errors.JobError{
ShouldCancelJob: true,
Msg: fmt.Sprintf("Fatal error in table %s.%s", tableInfo.Schema, tableInfo.Table),
Prev: err,
}
}
}
return 0, &custom_errors.LoaderError{Batch: batch, Msg: err.Error()}
}
return int64(len(batch.Rows)), nil
}

View File

@@ -1,11 +0,0 @@
package loaders
func mapSlice[T any, V any](input []T, mapper func(T) V) []V {
result := make([]V, len(input))
for i, v := range input {
result[i] = mapper(v)
}
return result
}

View File

@@ -1,133 +0,0 @@
package table_analyzers
import (
"context"
"math"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
)
func PartitionRangeGenerator(
ctx context.Context,
tableAnalyzer etl.TableAnalyzer,
tableInfo config.TableInfo,
partitionColumn string,
partitionCalculationStrategy string,
rowsPerPartition int64,
jobRange config.RangeConfig,
) ([]models.Partition, error) {
rowsCount, err := tableAnalyzer.EstimateTotalRows(ctx, tableInfo)
logrus.Infof("Estimated rows in source: %v (%s.%s)", rowsCount, tableInfo.Schema, tableInfo.Table)
if err != nil {
return nil, err
}
if rowsCount <= rowsPerPartition {
hasRange := jobRange.Min != nil || jobRange.Max != nil
partition := models.Partition{Id: uuid.New(), HasRange: hasRange, RetryCounter: 0}
if hasRange {
var min, max int64
if jobRange.Min != nil {
min = *jobRange.Min
}
if jobRange.Max != nil {
max = *jobRange.Max
}
partition.Range = models.PartitionRange{
Min: min,
Max: max,
IsMinInclusive: jobRange.IsMinInclusive,
IsMaxInclusive: jobRange.IsMaxInclusive,
}
}
return []models.Partition{partition}, nil
}
partitionsCount := rowsCount / rowsPerPartition
if partitionCalculationStrategy == "ESTIMATION" {
return calculatePartitionsEstimation(ctx, tableAnalyzer, tableInfo, partitionColumn, partitionsCount, jobRange)
}
partitions, err := tableAnalyzer.CalculatePartitionRanges(ctx, tableInfo, partitionColumn, partitionsCount, jobRange)
if err != nil {
return nil, err
}
logrus.Debugf("Partitions count: %v (%s.%s)", len(partitions), tableInfo.Schema, tableInfo.Table)
return partitions, nil
}
func calculatePartitionsEstimation(
ctx context.Context,
tableAnalyzer etl.TableAnalyzer,
tableInfo config.TableInfo,
partitionColumn string,
partitionsCount int64,
rangeConstraint config.RangeConfig,
) ([]models.Partition, error) {
var minValue, maxValue int64
if rangeConstraint.Min != nil && rangeConstraint.Max != nil {
minValue = *rangeConstraint.Min
maxValue = *rangeConstraint.Max
logrus.Infof("Column range for %s.%s.%s: [%d, %d] (user-defined)", tableInfo.Schema, tableInfo.Table, partitionColumn, minValue, maxValue)
} else if rangeConstraint.Min != nil || rangeConstraint.Max != nil {
result, err := tableAnalyzer.QueryMaxMinFromColumn(ctx, tableInfo, partitionColumn)
if err != nil {
return nil, err
}
if rangeConstraint.Min != nil {
minValue = *rangeConstraint.Min
maxValue = result.Max
logrus.Infof("Column range for %s.%s.%s: [%d, %d] (min user-defined)", tableInfo.Schema, tableInfo.Table, partitionColumn, minValue, maxValue)
} else {
minValue = result.Min
maxValue = *rangeConstraint.Max
logrus.Infof("Column range for %s.%s.%s: [%d, %d] (max user-defined)", tableInfo.Schema, tableInfo.Table, partitionColumn, minValue, maxValue)
}
} else {
result, err := tableAnalyzer.QueryMaxMinFromColumn(ctx, tableInfo, partitionColumn)
if err != nil {
return nil, err
}
logrus.Infof("Column range for %s.%s.%s: [%d, %d]", tableInfo.Schema, tableInfo.Table, partitionColumn, result.Min, result.Max)
minValue = result.Min
maxValue = result.Max
}
rangeSize := maxValue - minValue
stepSize := int64(math.Ceil(float64(rangeSize) / float64(partitionsCount)))
partitions := make([]models.Partition, 0, partitionsCount)
for i := range partitionsCount {
partitionMin := minValue + (i * stepSize)
partitionMax := minValue + ((i + 1) * stepSize)
if i == partitionsCount-1 {
partitionMax = maxValue
}
isMinInclusive := i == 0
partition := models.Partition{
Id: uuid.New(),
HasRange: true,
RetryCounter: 0,
Range: models.PartitionRange{
Min: partitionMin,
Max: partitionMax,
IsMinInclusive: isMinInclusive,
IsMaxInclusive: true,
},
}
partitions = append(partitions, partition)
}
return partitions, nil
}

View File

@@ -1,332 +0,0 @@
package table_analyzers
import (
"context"
"testing"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
)
type MockTableAnalyzer struct {
minValue int64
maxValue int64
totalRows int64
capturedRangeConstraint config.RangeConfig
}
func (m *MockTableAnalyzer) QueryColumnTypes(_ context.Context, _ config.TableInfo) ([]models.ColumnType, error) {
return nil, nil
}
func (m *MockTableAnalyzer) EstimateTotalRows(_ context.Context, _ config.TableInfo) (int64, error) {
return m.totalRows, nil
}
func (m *MockTableAnalyzer) QueryMaxMinFromColumn(_ context.Context, _ config.TableInfo, _ string) (etl.MaxMinColumnResult, error) {
return etl.MaxMinColumnResult{Min: m.minValue, Max: m.maxValue}, nil
}
func (m *MockTableAnalyzer) CalculatePartitionRanges(_ context.Context, _ config.TableInfo, _ string, _ int64, rangeConstraint config.RangeConfig) ([]models.Partition, error) {
m.capturedRangeConstraint = rangeConstraint
return []models.Partition{}, nil
}
//go:fix inline
func ptr64(v int64) *int64 { return new(v) }
var testTableInfo = config.TableInfo{Schema: "dbo", Table: "test"}
func TestCalculatePartitionsEstimation_NoOverlap(t *testing.T) {
ctx := context.Background()
mock := &MockTableAnalyzer{minValue: 0, maxValue: 100}
partitions, err := calculatePartitionsEstimation(ctx, mock, testTableInfo, "id", 4, config.RangeConfig{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(partitions) != 4 {
t.Errorf("expected 4 partitions, got %d", len(partitions))
}
for i := 0; i < len(partitions)-1; i++ {
current := partitions[i].Range
next := partitions[i+1].Range
if current.Max == next.Min && current.IsMaxInclusive && next.IsMinInclusive {
t.Errorf("partition %d and %d overlap at value %d (both inclusive)", i, i+1, current.Max)
}
}
}
func TestCalculatePartitionsEstimation_CoverageComplete(t *testing.T) {
ctx := context.Background()
mock := &MockTableAnalyzer{minValue: 1000, maxValue: 2000}
partitions, err := calculatePartitionsEstimation(ctx, mock, testTableInfo, "id", 5, config.RangeConfig{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if partitions[0].Range.Min != 1000 || !partitions[0].Range.IsMinInclusive {
t.Errorf("first partition should start at 1000 (inclusive), got %d (inclusive=%v)",
partitions[0].Range.Min, partitions[0].Range.IsMinInclusive)
}
if partitions[len(partitions)-1].Range.Max != 2000 {
t.Errorf("last partition should end at 2000, got %d", partitions[len(partitions)-1].Range.Max)
}
}
func TestCalculatePartitionsEstimation_FirstPartitionInclusive(t *testing.T) {
ctx := context.Background()
mock := &MockTableAnalyzer{minValue: 50, maxValue: 70}
partitions, err := calculatePartitionsEstimation(ctx, mock, testTableInfo, "id", 3, config.RangeConfig{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !partitions[0].Range.IsMinInclusive {
t.Errorf("first partition should have IsMinInclusive=true")
}
if partitions[0].Range.Min != 50 {
t.Errorf("first partition should start at 50, got %d", partitions[0].Range.Min)
}
for i := 1; i < len(partitions); i++ {
if partitions[i].Range.IsMinInclusive {
t.Errorf("partition %d should have IsMinInclusive=false to avoid overlap", i)
}
}
}
func TestPartitionRangeGenerator_Exact_NoRange_PassesEmptyConstraint(t *testing.T) {
ctx := context.Background()
mock := &MockTableAnalyzer{totalRows: 1000}
_, err := PartitionRangeGenerator(ctx, mock, testTableInfo, "id", "EXACT", 100, config.RangeConfig{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if mock.capturedRangeConstraint.Min != nil || mock.capturedRangeConstraint.Max != nil {
t.Errorf("expected empty range constraint, got min=%v max=%v",
mock.capturedRangeConstraint.Min, mock.capturedRangeConstraint.Max)
}
}
func TestPartitionRangeGenerator_Exact_BothBounds_PassesBothToAnalyzer(t *testing.T) {
ctx := context.Background()
mock := &MockTableAnalyzer{totalRows: 1000}
jobRange := config.RangeConfig{Min: ptr64(200), Max: ptr64(800), IsMinInclusive: true, IsMaxInclusive: true}
_, err := PartitionRangeGenerator(ctx, mock, testTableInfo, "id", "EXACT", 100, jobRange)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
rc := mock.capturedRangeConstraint
if rc.Min == nil || *rc.Min != 200 {
t.Errorf("expected Min=200, got %v", rc.Min)
}
if rc.Max == nil || *rc.Max != 800 {
t.Errorf("expected Max=800, got %v", rc.Max)
}
if !rc.IsMinInclusive || !rc.IsMaxInclusive {
t.Errorf("expected both bounds inclusive, got minInc=%v maxInc=%v", rc.IsMinInclusive, rc.IsMaxInclusive)
}
}
func TestPartitionRangeGenerator_Exact_MinOnly_PassesMinNilMax(t *testing.T) {
ctx := context.Background()
mock := &MockTableAnalyzer{totalRows: 1000}
jobRange := config.RangeConfig{Min: ptr64(500)}
_, err := PartitionRangeGenerator(ctx, mock, testTableInfo, "id", "EXACT", 100, jobRange)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
rc := mock.capturedRangeConstraint
if rc.Min == nil || *rc.Min != 500 {
t.Errorf("expected Min=500, got %v", rc.Min)
}
if rc.Max != nil {
t.Errorf("expected Max=nil (no upper bound), got %v", rc.Max)
}
}
func TestPartitionRangeGenerator_Exact_MaxOnly_PassesMaxNilMin(t *testing.T) {
ctx := context.Background()
mock := &MockTableAnalyzer{totalRows: 1000}
jobRange := config.RangeConfig{Max: ptr64(300)}
_, err := PartitionRangeGenerator(ctx, mock, testTableInfo, "id", "EXACT", 100, jobRange)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
rc := mock.capturedRangeConstraint
if rc.Min != nil {
t.Errorf("expected Min=nil (no lower bound), got %v", rc.Min)
}
if rc.Max == nil || *rc.Max != 300 {
t.Errorf("expected Max=300, got %v", rc.Max)
}
}
func TestPartitionRangeGenerator_Estimation_BothBounds_UsesUserRange(t *testing.T) {
ctx := context.Background()
// DB min/max differ intentionally — user bounds should take precedence.
mock := &MockTableAnalyzer{totalRows: 1000, minValue: 0, maxValue: 999}
jobRange := config.RangeConfig{Min: ptr64(200), Max: ptr64(700), IsMinInclusive: true}
partitions, err := PartitionRangeGenerator(ctx, mock, testTableInfo, "id", "ESTIMATION", 100, jobRange)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(partitions) == 0 {
t.Fatal("expected at least one partition")
}
if partitions[0].Range.Min != 200 {
t.Errorf("first partition should start at user min=200, got %d", partitions[0].Range.Min)
}
if partitions[len(partitions)-1].Range.Max != 700 {
t.Errorf("last partition should end at user max=700, got %d", partitions[len(partitions)-1].Range.Max)
}
}
func TestPartitionRangeGenerator_Estimation_MinOnly_QueriesDBForMax(t *testing.T) {
ctx := context.Background()
mock := &MockTableAnalyzer{totalRows: 1000, minValue: 0, maxValue: 999}
jobRange := config.RangeConfig{Min: ptr64(500), IsMinInclusive: true}
partitions, err := PartitionRangeGenerator(ctx, mock, testTableInfo, "id", "ESTIMATION", 100, jobRange)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(partitions) == 0 {
t.Fatal("expected at least one partition")
}
if partitions[0].Range.Min != 500 {
t.Errorf("first partition should start at user min=500, got %d", partitions[0].Range.Min)
}
if partitions[len(partitions)-1].Range.Max != 999 {
t.Errorf("last partition should end at DB max=999, got %d", partitions[len(partitions)-1].Range.Max)
}
}
func TestPartitionRangeGenerator_Estimation_MaxOnly_QueriesDBForMin(t *testing.T) {
ctx := context.Background()
mock := &MockTableAnalyzer{totalRows: 1000, minValue: 100, maxValue: 999}
jobRange := config.RangeConfig{Max: ptr64(600), IsMaxInclusive: true}
partitions, err := PartitionRangeGenerator(ctx, mock, testTableInfo, "id", "ESTIMATION", 100, jobRange)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(partitions) == 0 {
t.Fatal("expected at least one partition")
}
if partitions[0].Range.Min != 100 {
t.Errorf("first partition should start at DB min=100, got %d", partitions[0].Range.Min)
}
if partitions[len(partitions)-1].Range.Max != 600 {
t.Errorf("last partition should end at user max=600, got %d", partitions[len(partitions)-1].Range.Max)
}
}
func TestPartitionRangeGenerator_SinglePartition_NoRange(t *testing.T) {
ctx := context.Background()
mock := &MockTableAnalyzer{totalRows: 50}
partitions, err := PartitionRangeGenerator(ctx, mock, testTableInfo, "id", "EXACT", 100, config.RangeConfig{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(partitions) != 1 {
t.Fatalf("expected 1 partition, got %d", len(partitions))
}
if partitions[0].HasRange {
t.Error("single partition with no range should have HasRange=false")
}
}
func TestPartitionRangeGenerator_SinglePartition_BothBounds(t *testing.T) {
ctx := context.Background()
mock := &MockTableAnalyzer{totalRows: 50}
jobRange := config.RangeConfig{Min: ptr64(100), Max: ptr64(200), IsMinInclusive: true, IsMaxInclusive: true}
partitions, err := PartitionRangeGenerator(ctx, mock, testTableInfo, "id", "EXACT", 100, jobRange)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(partitions) != 1 {
t.Fatalf("expected 1 partition, got %d", len(partitions))
}
p := partitions[0]
if !p.HasRange {
t.Error("expected HasRange=true")
}
if p.Range.Min != 100 || p.Range.Max != 200 {
t.Errorf("expected [100, 200], got [%d, %d]", p.Range.Min, p.Range.Max)
}
if !p.Range.IsMinInclusive || !p.Range.IsMaxInclusive {
t.Errorf("expected both inclusive, got minInc=%v maxInc=%v", p.Range.IsMinInclusive, p.Range.IsMaxInclusive)
}
}
func TestPartitionRangeGenerator_SinglePartition_MinOnly(t *testing.T) {
ctx := context.Background()
mock := &MockTableAnalyzer{totalRows: 50}
jobRange := config.RangeConfig{Min: ptr64(100), IsMinInclusive: true}
partitions, err := PartitionRangeGenerator(ctx, mock, testTableInfo, "id", "EXACT", 100, jobRange)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
p := partitions[0]
if !p.HasRange {
t.Error("expected HasRange=true")
}
if p.Range.Min != 100 {
t.Errorf("expected Min=100, got %d", p.Range.Min)
}
if p.Range.Max != 0 {
t.Errorf("expected Max=0 (no upper bound), got %d", p.Range.Max)
}
}
func TestPartitionRangeGenerator_SinglePartition_MaxOnly(t *testing.T) {
ctx := context.Background()
mock := &MockTableAnalyzer{totalRows: 50}
jobRange := config.RangeConfig{Max: ptr64(200), IsMaxInclusive: true}
partitions, err := PartitionRangeGenerator(ctx, mock, testTableInfo, "id", "EXACT", 100, jobRange)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
p := partitions[0]
if !p.HasRange {
t.Error("expected HasRange=true")
}
if p.Range.Min != 0 {
t.Errorf("expected Min=0 (no lower bound), got %d", p.Range.Min)
}
if p.Range.Max != 200 {
t.Errorf("expected Max=200, got %d", p.Range.Max)
}
}

View File

@@ -1,302 +0,0 @@
package table_analyzers
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
"github.com/google/uuid"
)
type MssqlTableAnalyzer struct {
db dbwrapper.DbWrapper
}
func NewMssqlTableAnalyzer(db dbwrapper.DbWrapper) etl.TableAnalyzer {
return &MssqlTableAnalyzer{db: db}
}
const mssqlColumnMetadataQuery string = `
SELECT
c.name AS name,
t.name AS user_type,
CASE WHEN t.is_user_defined = 0 THEN t.name ELSE bt.name END AS system_type,
c.is_nullable AS nullable,
c.max_length AS max_length,
c.precision AS precision,
c.scale AS scale
FROM sys.columns c
JOIN sys.types t ON c.user_type_id = t.user_type_id
LEFT JOIN sys.types bt ON t.is_user_defined = 1 AND bt.user_type_id = t.system_type_id
JOIN sys.tables st ON c.object_id = st.object_id
JOIN sys.schemas s ON st.schema_id = s.schema_id
WHERE s.name = @schema AND st.name = @table AND (c.is_hidden = 0 OR (c.graph_type IS NOT NULL AND c.name LIKE '$%'))
ORDER BY c.column_id;`
type rawColumnMssql struct {
name string
userType string
systemType string
nullable bool
maxLength int64
precision int64
scale int64
}
func (ta *MssqlTableAnalyzer) systemTypeToUnifiedType(systemType string) string {
systemType = strings.ToLower(systemType)
if systemType == "varchar" || systemType == "char" || systemType == "nvarchar" || systemType == "nchar" || systemType == "text" || systemType == "ntext" {
return "STRING"
}
if systemType == "int" || systemType == "int4" || systemType == "integer" || systemType == "smallint" || systemType == "int2" || systemType == "bigint" || systemType == "int8" || systemType == "tinyint" {
return "INTEGER"
}
if systemType == "decimal" || systemType == "numeric" {
return "DECIMAL"
}
if systemType == "float" || systemType == "real" || systemType == "double precision" {
return "FLOAT"
}
if systemType == "bit" || systemType == "boolean" {
return "BOOLEAN"
}
if systemType == "date" {
return "DATE"
}
if systemType == "time" || systemType == "time without time zone" {
return "TIME"
}
if systemType == "datetime" || systemType == "datetime2" || systemType == "timestamp" || systemType == "timestamptz" || systemType == "timestamp with time zone" {
return "TIMESTAMP"
}
if systemType == "binary" || systemType == "varbinary" || systemType == "image" || systemType == "bytea" {
return "BINARY"
}
if systemType == "uniqueidentifier" || systemType == "uuid" {
return "UUID"
}
if systemType == "json" {
return "JSON"
}
if systemType == "geometry" || systemType == "geography" {
return "GEOMETRY"
}
return strings.ToUpper(systemType)
}
func (ta *MssqlTableAnalyzer) rawColumnToColumnType(rawColumn rawColumnMssql) models.ColumnType {
const nullValue int64 = -1
stringTypes := map[string]bool{"varchar": true, "char": true, "nvarchar": true, "nchar": true, "text": true, "ntext": true}
decimalTypes := map[string]bool{"decimal": true, "numeric": true}
if stringTypes[rawColumn.systemType] {
if rawColumn.systemType == "nvarchar" || rawColumn.systemType == "nchar" {
if rawColumn.maxLength > 0 {
rawColumn.maxLength = rawColumn.maxLength / 2
}
}
rawColumn.precision, rawColumn.scale = nullValue, nullValue
} else if decimalTypes[rawColumn.systemType] {
rawColumn.maxLength = nullValue
} else {
rawColumn.maxLength, rawColumn.precision, rawColumn.scale = nullValue, nullValue, nullValue
}
columnType := models.NewColumnType(
rawColumn.name,
rawColumn.maxLength != nullValue,
rawColumn.precision != nullValue || rawColumn.scale != nullValue,
rawColumn.userType,
rawColumn.systemType,
ta.systemTypeToUnifiedType(rawColumn.systemType),
rawColumn.nullable,
rawColumn.maxLength,
rawColumn.precision,
rawColumn.scale,
)
return columnType
}
func (ta *MssqlTableAnalyzer) QueryColumnTypes(
ctx context.Context,
tableInfo config.TableInfo,
) ([]models.ColumnType, error) {
localCtx, cancel := context.WithTimeout(ctx, 20*time.Second)
defer cancel()
rows, err := ta.db.Query(localCtx, mssqlColumnMetadataQuery, sql.Named("schema", tableInfo.Schema), sql.Named("table", tableInfo.Table))
if err != nil {
return nil, err
}
defer rows.Close()
var columnTypes []models.ColumnType
for rows.Next() {
var rawColumn rawColumnMssql
if err := rows.Scan(
&rawColumn.name,
&rawColumn.userType,
&rawColumn.systemType,
&rawColumn.nullable,
&rawColumn.maxLength,
&rawColumn.precision,
&rawColumn.scale,
); err != nil {
return nil, err
}
columnTypes = append(columnTypes, ta.rawColumnToColumnType(rawColumn))
}
return columnTypes, nil
}
func (ta *MssqlTableAnalyzer) EstimateTotalRows(
ctx context.Context,
tableInfo config.TableInfo,
) (int64, error) {
query := `
SELECT SUM(p.rows) AS count
FROM sys.tables t
JOIN sys.schemas s ON t.schema_id = s.schema_id
JOIN sys.partitions p ON t.object_id = p.object_id
WHERE s.name = @schema AND t.name = @table AND p.index_id IN (0, 1)
GROUP BY t.name`
ctxTimeout, cancel := context.WithTimeout(ctx, 1*time.Minute)
defer cancel()
var rowsCount int64
err := ta.db.QueryRow(ctxTimeout, query, sql.Named("schema", tableInfo.Schema), sql.Named("table", tableInfo.Table)).Scan(&rowsCount)
if err != nil {
return 0, err
}
return rowsCount, nil
}
func (ta *MssqlTableAnalyzer) QueryMaxMinFromColumn(
ctx context.Context,
tableInfo config.TableInfo,
columnName string,
) (etl.MaxMinColumnResult, error) {
query := fmt.Sprintf(`
SELECT
MIN([%s]) AS min_value,
MAX([%s]) AS max_value
FROM [%s].[%s]`, columnName, columnName, tableInfo.Schema, tableInfo.Table)
ctxTimeout, cancel := context.WithTimeout(ctx, 1*time.Minute)
defer cancel()
result := etl.MaxMinColumnResult{}
err := ta.db.QueryRow(ctxTimeout, query).Scan(&result.Min, &result.Max)
if err != nil {
return etl.MaxMinColumnResult{}, err
}
return result, nil
}
func (ta *MssqlTableAnalyzer) CalculatePartitionRanges(
ctx context.Context,
tableInfo config.TableInfo,
partitionColumn string,
maxPartitions int64,
rangeConstraint config.RangeConfig,
) ([]models.Partition, error) {
whereClause := ""
args := []any{sql.Named("maxPartitions", maxPartitions)}
if rangeConstraint.Min != nil || rangeConstraint.Max != nil {
var conditions []string
if rangeConstraint.Min != nil {
minOp := ">"
if rangeConstraint.IsMinInclusive {
minOp = ">="
}
conditions = append(conditions, fmt.Sprintf("[%s] %s @rangeMin", partitionColumn, minOp))
args = append(args, sql.Named("rangeMin", *rangeConstraint.Min))
}
if rangeConstraint.Max != nil {
maxOp := "<"
if rangeConstraint.IsMaxInclusive {
maxOp = "<="
}
conditions = append(conditions, fmt.Sprintf("[%s] %s @rangeMax", partitionColumn, maxOp))
args = append(args, sql.Named("rangeMax", *rangeConstraint.Max))
}
whereClause = "WHERE " + strings.Join(conditions, " AND ")
}
query := fmt.Sprintf(`
SELECT
MIN([%s]) AS lower_limit,
MAX([%s]) AS upper_limit
FROM (SELECT [%s], NTILE(@maxPartitions) OVER (ORDER BY [%s]) AS batch_id FROM [%s].[%s] %s) AS T
GROUP BY batch_id
ORDER BY batch_id`,
partitionColumn,
partitionColumn,
partitionColumn,
partitionColumn,
tableInfo.Schema,
tableInfo.Table,
whereClause)
ctxTimeout, cancel := context.WithTimeout(ctx, 1*time.Minute)
defer cancel()
rows, err := ta.db.Query(ctxTimeout, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
partitions := make([]models.Partition, 0, maxPartitions)
for rows.Next() {
partition := models.Partition{
Id: uuid.New(),
HasRange: true,
RetryCounter: 0,
Range: models.PartitionRange{
IsMinInclusive: true,
IsMaxInclusive: true,
},
}
if err := rows.Scan(&partition.Range.Min, &partition.Range.Max); err != nil {
return nil, err
}
partitions = append(partitions, partition)
}
if err := rows.Err(); err != nil {
return nil, err
}
return partitions, nil
}

View File

@@ -1,293 +0,0 @@
package table_analyzers
import (
"context"
"fmt"
"strings"
"time"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
"github.com/google/uuid"
)
type PostgresTableAnalyzer struct {
db dbwrapper.DbWrapper
}
func NewPostgresTableAnalyzer(db dbwrapper.DbWrapper) etl.TableAnalyzer {
return &PostgresTableAnalyzer{db: db}
}
const postgresColumnMetadataQuery string = `
SELECT
c.column_name AS name,
c.data_type AS user_type,
c.udt_name AS system_type,
(CASE WHEN c.is_nullable = 'YES' THEN TRUE ELSE FALSE END) AS nullable,
COALESCE(c.character_maximum_length, -1) AS max_length,
COALESCE(c.numeric_precision, -1) AS precision,
COALESCE(c.numeric_scale, -1) AS scale
FROM information_schema.columns c
WHERE c.table_schema = $1 AND c.table_name = $2
ORDER BY c.ordinal_position;`
type rawColumnPostgres struct {
name string
userType string
systemType string
nullable bool
maxLength int64
precision int64
scale int64
}
func (ta *PostgresTableAnalyzer) systemTypeToUnifiedType(systemType string) string {
systemType = strings.ToLower(systemType)
if systemType == "varchar" || systemType == "char" || systemType == "nvarchar" || systemType == "nchar" || systemType == "text" || systemType == "ntext" {
return "STRING"
}
if systemType == "int" || systemType == "int4" || systemType == "integer" || systemType == "smallint" || systemType == "int2" || systemType == "bigint" || systemType == "int8" || systemType == "tinyint" {
return "INTEGER"
}
if systemType == "decimal" || systemType == "numeric" {
return "DECIMAL"
}
if systemType == "float" || systemType == "real" || systemType == "double precision" {
return "FLOAT"
}
if systemType == "bit" || systemType == "boolean" {
return "BOOLEAN"
}
if systemType == "date" {
return "DATE"
}
if systemType == "time" || systemType == "time without time zone" {
return "TIME"
}
if systemType == "datetime" || systemType == "datetime2" || systemType == "timestamp" || systemType == "timestamptz" || systemType == "timestamp with time zone" {
return "TIMESTAMP"
}
if systemType == "binary" || systemType == "varbinary" || systemType == "image" || systemType == "bytea" {
return "BINARY"
}
if systemType == "uniqueidentifier" || systemType == "uuid" {
return "UUID"
}
if systemType == "json" {
return "JSON"
}
if systemType == "geometry" || systemType == "geography" {
return "GEOMETRY"
}
return strings.ToUpper(systemType)
}
func (ta *PostgresTableAnalyzer) rawColumnToColumnType(rawColumn rawColumnPostgres) models.ColumnType {
const nullValue int64 = -1
stringTypes := map[string]bool{"varchar": true, "char": true, "text": true}
decimalTypes := map[string]bool{"decimal": true, "numeric": true}
if stringTypes[rawColumn.systemType] {
rawColumn.precision, rawColumn.scale = nullValue, nullValue
} else if decimalTypes[rawColumn.systemType] {
rawColumn.maxLength = nullValue
} else {
rawColumn.maxLength, rawColumn.precision, rawColumn.scale = nullValue, nullValue, nullValue
}
return models.NewColumnType(
rawColumn.name,
rawColumn.maxLength != nullValue,
rawColumn.precision != nullValue || rawColumn.scale != nullValue,
rawColumn.userType,
rawColumn.systemType,
ta.systemTypeToUnifiedType(rawColumn.systemType),
rawColumn.nullable,
rawColumn.maxLength,
rawColumn.precision,
rawColumn.scale,
)
}
func (ta *PostgresTableAnalyzer) QueryColumnTypes(
ctx context.Context,
tableInfo config.TableInfo,
) ([]models.ColumnType, error) {
localCtx, cancel := context.WithTimeout(ctx, 1*time.Minute)
defer cancel()
rows, err := ta.db.Query(localCtx, postgresColumnMetadataQuery, tableInfo.Schema, tableInfo.Table)
if err != nil {
return nil, err
}
defer rows.Close()
var colTypes []models.ColumnType
for rows.Next() {
var column rawColumnPostgres
if err := rows.Scan(
&column.name,
&column.userType,
&column.systemType,
&column.nullable,
&column.maxLength,
&column.precision,
&column.scale,
); err != nil {
return nil, err
}
colTypes = append(colTypes, ta.rawColumnToColumnType(column))
}
return colTypes, nil
}
func (ta *PostgresTableAnalyzer) EstimateTotalRows(
ctx context.Context,
tableInfo config.TableInfo,
) (int64, error) {
query := `
SELECT reltuples::bigint
FROM pg_class
JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace
WHERE pg_namespace.nspname = $1 AND pg_class.relname = $2`
ctxTimeout, cancel := context.WithTimeout(ctx, 1*time.Minute)
defer cancel()
var estimate int64
err := ta.db.QueryRow(ctxTimeout, query, tableInfo.Schema, tableInfo.Table).Scan(&estimate)
if err != nil {
return 0, err
}
if estimate < 0 {
countQuery := fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."%s"`, tableInfo.Schema, tableInfo.Table)
err = ta.db.QueryRow(ctxTimeout, countQuery).Scan(&estimate)
if err != nil {
return 0, err
}
}
return estimate, nil
}
func (ta *PostgresTableAnalyzer) QueryMaxMinFromColumn(
ctx context.Context,
tableInfo config.TableInfo,
columnName string,
) (etl.MaxMinColumnResult, error) {
query := fmt.Sprintf(`SELECT MIN("%s"), MAX("%s") FROM "%s"."%s"`,
columnName, columnName, tableInfo.Schema, tableInfo.Table)
ctxTimeout, cancel := context.WithTimeout(ctx, 1*time.Minute)
defer cancel()
result := etl.MaxMinColumnResult{}
err := ta.db.QueryRow(ctxTimeout, query).Scan(&result.Min, &result.Max)
if err != nil {
return etl.MaxMinColumnResult{}, err
}
return result, nil
}
func (ta *PostgresTableAnalyzer) CalculatePartitionRanges(
ctx context.Context,
tableInfo config.TableInfo,
partitionColumn string,
maxPartitions int64,
rangeConstraint config.RangeConfig,
) ([]models.Partition, error) {
whereClause := ""
args := []any{maxPartitions}
if rangeConstraint.Min != nil || rangeConstraint.Max != nil {
var conditions []string
if rangeConstraint.Min != nil {
minOp := ">"
if rangeConstraint.IsMinInclusive {
minOp = ">="
}
args = append(args, *rangeConstraint.Min)
conditions = append(conditions, fmt.Sprintf(`"%s" %s $%d`, partitionColumn, minOp, len(args)))
}
if rangeConstraint.Max != nil {
maxOp := "<"
if rangeConstraint.IsMaxInclusive {
maxOp = "<="
}
args = append(args, *rangeConstraint.Max)
conditions = append(conditions, fmt.Sprintf(`"%s" %s $%d`, partitionColumn, maxOp, len(args)))
}
whereClause = "WHERE " + strings.Join(conditions, " AND ")
}
query := fmt.Sprintf(`
SELECT MIN("%s") AS lower_limit, MAX("%s") AS upper_limit
FROM (
SELECT "%s", NTILE($1) OVER (ORDER BY "%s") AS batch_id
FROM "%s"."%s" %s
) AS t
GROUP BY batch_id
ORDER BY batch_id`,
partitionColumn,
partitionColumn,
partitionColumn,
partitionColumn,
tableInfo.Schema,
tableInfo.Table,
whereClause)
ctxTimeout, cancel := context.WithTimeout(ctx, 1*time.Minute)
defer cancel()
rows, err := ta.db.Query(ctxTimeout, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
partitions := make([]models.Partition, 0, maxPartitions)
for rows.Next() {
partition := models.Partition{
Id: uuid.New(),
HasRange: true,
RetryCounter: 0,
Range: models.PartitionRange{
IsMinInclusive: true,
IsMaxInclusive: true,
},
}
if err := rows.Scan(&partition.Range.Min, &partition.Range.Max); err != nil {
return nil, err
}
partitions = append(partitions, partition)
}
if err := rows.Err(); err != nil {
return nil, err
}
return partitions, nil
}

View File

@@ -1,119 +0,0 @@
package transformers
import (
"context"
"errors"
"sync"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
"github.com/google/uuid"
)
type batchAccumulator struct {
batchSize int
rows []models.UnknownRowValues
parents []models.BatchRef
}
func (a *batchAccumulator) add(batch models.Batch) {
a.rows = append(a.rows, batch.Rows...)
a.parents = append(a.parents, models.BatchRef{Id: batch.Id})
}
func (a *batchAccumulator) ready() bool {
return len(a.rows) >= a.batchSize
}
func (a *batchAccumulator) flush(ctx context.Context, chOut chan<- models.Batch, wg *sync.WaitGroup) bool {
if len(a.rows) == 0 {
return true
}
out := models.Batch{
Id: uuid.New(),
ParentBatches: a.parents,
Rows: a.rows,
}
wg.Add(1)
select {
case chOut <- out:
case <-ctx.Done():
wg.Done()
return false
}
a.rows = nil
a.parents = nil
return true
}
func sendTransformError(ctx context.Context, err error, ch chan<- custom_errors.JobError) {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return
}
var jobErr custom_errors.JobError
if je, ok := errors.AsType[*custom_errors.JobError](err); ok {
jobErr = *je
} else {
jobErr = custom_errors.JobError{ShouldCancelJob: true, Msg: "Transformation failed", Prev: err}
}
select {
case ch <- jobErr:
case <-ctx.Done():
}
}
func (mssqlTr *MssqlTransformer) Consume(
ctx context.Context,
columns []models.ColumnType,
retryConfig config.RetryConfig,
batchSize int,
chBatchesIn <-chan models.Batch,
chBatchesOut chan<- models.Batch,
chJobErrorsOut chan<- custom_errors.JobError,
wgActiveBatches *sync.WaitGroup,
) {
transformationPlan := computeTransformationPlan(columns)
storagePlan := computeStorageTransformationPlan(ctx, mssqlTr.azureClient, mssqlTr.toStorage, columns, mssqlTr.sourceTable)
transformationPlan = append(transformationPlan, storagePlan...)
acc := &batchAccumulator{batchSize: batchSize}
for {
select {
case <-ctx.Done():
return
case batch, ok := <-chBatchesIn:
if !ok {
acc.flush(ctx, chBatchesOut, wgActiveBatches)
return
}
if len(transformationPlan) > 0 {
if err := ProcessBatchWithRetries(ctx, &batch, transformationPlan, retryConfig); err != nil {
sendTransformError(ctx, err, chJobErrorsOut)
return
}
}
if batchSize <= 0 {
wgActiveBatches.Add(1)
select {
case chBatchesOut <- batch:
case <-ctx.Done():
wgActiveBatches.Done()
return
}
continue
}
acc.add(batch)
if acc.ready() {
if !acc.flush(ctx, chBatchesOut, wgActiveBatches) {
return
}
}
}
}
}

View File

@@ -1,545 +0,0 @@
package transformers
import (
"context"
"errors"
"sync"
"testing"
"time"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
"github.com/google/uuid"
)
const testTimeout = 2 * time.Second
func makeBatch(numRows int) models.Batch {
rows := make([]models.UnknownRowValues, numRows)
for i := range rows {
rows[i] = models.UnknownRowValues{i}
}
return models.Batch{Id: uuid.New(), Rows: rows}
}
func noRetry() config.RetryConfig {
return config.RetryConfig{Attempts: 1}
}
func newTransformer() *MssqlTransformer {
return &MssqlTransformer{}
}
func uuidColumn() models.ColumnType {
return models.NewColumnType("col_uuid", false, false, "uniqueidentifier", "uniqueidentifier", "string", false, 0, 0, 0)
}
func runConsume(
ctx context.Context,
tr *MssqlTransformer,
columns []models.ColumnType,
batchSize int,
chIn <-chan models.Batch,
chOut chan<- models.Batch,
chErr chan<- custom_errors.JobError,
wg *sync.WaitGroup,
) <-chan struct{} {
done := make(chan struct{})
go func() {
tr.Consume(ctx, columns, noRetry(), batchSize, chIn, chOut, chErr, wg)
close(done)
}()
return done
}
func drainOut(chOut <-chan models.Batch, wg *sync.WaitGroup) []models.Batch {
var batches []models.Batch
for {
select {
case b := <-chOut:
batches = append(batches, b)
wg.Done()
default:
return batches
}
}
}
func TestBatchAccumulator_Add(t *testing.T) {
acc := &batchAccumulator{batchSize: 5}
b1 := makeBatch(2)
b2 := makeBatch(3)
acc.add(b1)
acc.add(b2)
if len(acc.rows) != 5 {
t.Errorf("expected 5 rows, got %d", len(acc.rows))
}
if len(acc.parents) != 2 {
t.Fatalf("expected 2 parents, got %d", len(acc.parents))
}
if acc.parents[0].Id != b1.Id || acc.parents[1].Id != b2.Id {
t.Error("parent IDs do not match source batch IDs")
}
}
func TestBatchAccumulator_Ready(t *testing.T) {
acc := &batchAccumulator{batchSize: 3}
acc.add(makeBatch(2))
if acc.ready() {
t.Error("should not be ready with 2 rows and batchSize=3")
}
acc.add(makeBatch(1))
if !acc.ready() {
t.Error("should be ready with 3 rows and batchSize=3")
}
}
func TestBatchAccumulator_Flush_Empty(t *testing.T) {
acc := &batchAccumulator{batchSize: 5}
chOut := make(chan models.Batch, 1)
var wg sync.WaitGroup
if !acc.flush(context.Background(), chOut, &wg) {
t.Error("flush on empty accumulator should return true")
}
if len(chOut) != 0 {
t.Error("flush on empty accumulator should send nothing")
}
}
func TestBatchAccumulator_Flush_Success(t *testing.T) {
acc := &batchAccumulator{batchSize: 2}
b := makeBatch(2)
acc.add(b)
chOut := make(chan models.Batch, 1)
var wg sync.WaitGroup
if !acc.flush(context.Background(), chOut, &wg) {
t.Fatal("flush should return true on success")
}
select {
case out := <-chOut:
wg.Done()
if len(out.Rows) != 2 {
t.Errorf("expected 2 rows in flushed batch, got %d", len(out.Rows))
}
if len(out.ParentBatches) != 1 || out.ParentBatches[0].Id != b.Id {
t.Error("flushed batch should reference the source batch as parent")
}
default:
t.Error("expected a batch in chOut after flush")
}
if len(acc.rows) != 0 || len(acc.parents) != 0 {
t.Error("accumulator state should be reset after flush")
}
wg.Wait()
}
func TestBatchAccumulator_Flush_ContextCancelled(t *testing.T) {
acc := &batchAccumulator{batchSize: 2}
acc.add(makeBatch(2))
chOut := make(chan models.Batch)
var wg sync.WaitGroup
ctx, cancel := context.WithCancel(context.Background())
cancel()
if acc.flush(ctx, chOut, &wg) {
t.Error("flush should return false when context is cancelled")
}
wg.Wait()
}
func TestSendTransformError_PlainError(t *testing.T) {
ch := make(chan custom_errors.JobError, 1)
sendTransformError(context.Background(), errors.New("something broke"), ch)
select {
case e := <-ch:
if !e.ShouldCancelJob {
t.Error("plain error should produce ShouldCancelJob=true")
}
default:
t.Error("expected a job error in the channel")
}
}
func TestSendTransformError_JobError_Passthrough(t *testing.T) {
ch := make(chan custom_errors.JobError, 1)
original := &custom_errors.JobError{ShouldCancelJob: false, Msg: "custom msg"}
sendTransformError(context.Background(), original, ch)
select {
case e := <-ch:
if e.ShouldCancelJob != false || e.Msg != "custom msg" {
t.Errorf("JobError should pass through unchanged, got %+v", e)
}
default:
t.Error("expected a job error in the channel")
}
}
func TestSendTransformError_ContextCancelled_Silent(t *testing.T) {
ch := make(chan custom_errors.JobError, 1)
ctx, cancel := context.WithCancel(context.Background())
cancel()
sendTransformError(ctx, context.Canceled, ch)
if len(ch) != 0 {
t.Error("context.Canceled should be silently dropped")
}
}
func TestSendTransformError_DeadlineExceeded_Silent(t *testing.T) {
ch := make(chan custom_errors.JobError, 1)
ctx, cancel := context.WithCancel(context.Background())
cancel()
sendTransformError(ctx, context.DeadlineExceeded, ch)
if len(ch) != 0 {
t.Error("context.DeadlineExceeded should be silently dropped")
}
}
func TestConsume_Passthrough_PreservesOriginalBatch(t *testing.T) {
tr := newTransformer()
chIn := make(chan models.Batch, 1)
chOut := make(chan models.Batch, 1)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
batch := makeBatch(3)
chIn <- batch
close(chIn)
done := runConsume(context.Background(), tr, nil, 0, chIn, chOut, chErr, &wg)
select {
case got := <-chOut:
wg.Done()
if got.Id != batch.Id {
t.Error("passthrough should preserve the original batch ID")
}
if len(got.Rows) != 3 {
t.Errorf("expected 3 rows, got %d", len(got.Rows))
}
case <-time.After(testTimeout):
t.Fatal("timeout waiting for output batch")
}
<-done
wg.Wait()
}
func TestConsume_Passthrough_WaitGroupBalanced(t *testing.T) {
tr := newTransformer()
chIn := make(chan models.Batch, 3)
chOut := make(chan models.Batch, 3)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
for range 3 {
chIn <- makeBatch(1)
}
close(chIn)
done := runConsume(context.Background(), tr, nil, 0, chIn, chOut, chErr, &wg)
<-done
batches := drainOut(chOut, &wg)
if len(batches) != 3 {
t.Errorf("expected 3 output batches, got %d", len(batches))
}
wg.Wait()
}
func TestConsume_Accumulation_FlushOnThreshold(t *testing.T) {
tr := newTransformer()
chIn := make(chan models.Batch, 3)
chOut := make(chan models.Batch, 2)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
for range 3 {
chIn <- makeBatch(1)
}
close(chIn)
done := runConsume(context.Background(), tr, nil, 3, chIn, chOut, chErr, &wg)
<-done
batches := drainOut(chOut, &wg)
if len(batches) != 1 {
t.Fatalf("expected 1 accumulated batch, got %d", len(batches))
}
if len(batches[0].Rows) != 3 {
t.Errorf("expected 3 rows in accumulated batch, got %d", len(batches[0].Rows))
}
wg.Wait()
}
func TestConsume_Accumulation_FlushOnClose(t *testing.T) {
tr := newTransformer()
chIn := make(chan models.Batch, 2)
chOut := make(chan models.Batch, 2)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
chIn <- makeBatch(1)
chIn <- makeBatch(1)
close(chIn)
done := runConsume(context.Background(), tr, nil, 10, chIn, chOut, chErr, &wg)
<-done
batches := drainOut(chOut, &wg)
if len(batches) != 1 {
t.Fatalf("expected 1 batch flushed on close, got %d", len(batches))
}
if len(batches[0].Rows) != 2 {
t.Errorf("expected 2 rows, got %d", len(batches[0].Rows))
}
wg.Wait()
}
func TestConsume_Accumulation_TracksAllParentBatches(t *testing.T) {
tr := newTransformer()
chIn := make(chan models.Batch, 2)
chOut := make(chan models.Batch, 2)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
b1 := makeBatch(1)
b2 := makeBatch(1)
chIn <- b1
chIn <- b2
close(chIn)
done := runConsume(context.Background(), tr, nil, 10, chIn, chOut, chErr, &wg)
<-done
batches := drainOut(chOut, &wg)
if len(batches) != 1 {
t.Fatalf("expected 1 output batch, got %d", len(batches))
}
parents := batches[0].ParentBatches
if len(parents) != 2 {
t.Fatalf("expected 2 parent refs, got %d", len(parents))
}
if parents[0].Id != b1.Id || parents[1].Id != b2.Id {
t.Error("parent IDs should match source batch IDs in order")
}
wg.Wait()
}
func TestConsume_Accumulation_MultipleFlushes(t *testing.T) {
tr := newTransformer()
chIn := make(chan models.Batch, 5)
chOut := make(chan models.Batch, 5)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
for range 5 {
chIn <- makeBatch(1)
}
close(chIn)
done := runConsume(context.Background(), tr, nil, 2, chIn, chOut, chErr, &wg)
<-done
batches := drainOut(chOut, &wg)
if len(batches) != 3 {
t.Fatalf("expected 3 output batches (2+2+1 rows), got %d", len(batches))
}
totalRows := 0
for _, b := range batches {
totalRows += len(b.Rows)
}
if totalRows != 5 {
t.Errorf("expected 5 total rows across all batches, got %d", totalRows)
}
wg.Wait()
}
func TestConsume_EmptyInput_NoOutput(t *testing.T) {
tr := newTransformer()
chIn := make(chan models.Batch)
chOut := make(chan models.Batch, 1)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
close(chIn)
done := runConsume(context.Background(), tr, nil, 5, chIn, chOut, chErr, &wg)
select {
case <-done:
case <-time.After(testTimeout):
t.Fatal("timeout: Consume did not exit after empty input channel was closed")
}
if len(chOut) != 0 {
t.Error("expected no output for empty input")
}
wg.Wait()
}
func TestConsume_TransformError_SendsJobError(t *testing.T) {
tr := newTransformer()
col := uuidColumn()
chIn := make(chan models.Batch, 1)
chOut := make(chan models.Batch, 1)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
batch := models.Batch{
Id: uuid.New(),
Rows: []models.UnknownRowValues{{[]byte{1, 2, 3}}},
}
chIn <- batch
done := runConsume(context.Background(), tr, []models.ColumnType{col}, 0, chIn, chOut, chErr, &wg)
select {
case err := <-chErr:
if !err.ShouldCancelJob {
t.Error("transform error should set ShouldCancelJob=true")
}
case <-time.After(testTimeout):
t.Fatal("timeout: expected a job error from transform failure")
}
<-done
wg.Wait()
}
func TestConsume_TransformError_NoOutputForwarded(t *testing.T) {
tr := newTransformer()
col := uuidColumn()
chIn := make(chan models.Batch, 1)
chOut := make(chan models.Batch, 1)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
batch := models.Batch{
Id: uuid.New(),
Rows: []models.UnknownRowValues{{[]byte{1, 2, 3}}},
}
chIn <- batch
done := runConsume(context.Background(), tr, []models.ColumnType{col}, 0, chIn, chOut, chErr, &wg)
<-done
if len(chOut) != 0 {
t.Error("no batch should be forwarded when transformation fails")
}
wg.Wait()
}
func TestConsume_ContextCancellation_Exits(t *testing.T) {
tr := newTransformer()
chIn := make(chan models.Batch)
chOut := make(chan models.Batch, 1)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
ctx, cancel := context.WithCancel(context.Background())
done := runConsume(ctx, tr, nil, 0, chIn, chOut, chErr, &wg)
cancel()
select {
case <-done:
case <-time.After(testTimeout):
t.Fatal("timeout: Consume did not exit after context cancellation")
}
wg.Wait()
}
func TestConsume_Transform_DatetimeConvertedToUTC(t *testing.T) {
tr := newTransformer()
col := models.NewColumnType("col_dt", false, false, "datetime", "datetime", "timestamp", false, 0, 0, 0)
chIn := make(chan models.Batch, 1)
chOut := make(chan models.Batch, 1)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
nonUTC := time.Date(2024, 1, 15, 12, 0, 0, 0, time.FixedZone("EST", -5*3600))
batch := models.Batch{
Id: uuid.New(),
Rows: []models.UnknownRowValues{{nonUTC}},
}
chIn <- batch
close(chIn)
done := runConsume(context.Background(), tr, []models.ColumnType{col}, 0, chIn, chOut, chErr, &wg)
<-done
select {
case got := <-chOut:
wg.Done()
result, ok := got.Rows[0][0].(time.Time)
if !ok {
t.Fatal("expected time.Time in output row")
}
if result.Location() != time.UTC {
t.Errorf("expected UTC location after transform, got %v", result.Location())
}
default:
t.Error("expected an output batch")
}
wg.Wait()
}
func TestConsume_Transform_NilValueSkipped(t *testing.T) {
tr := newTransformer()
col := uuidColumn()
chIn := make(chan models.Batch, 1)
chOut := make(chan models.Batch, 1)
chErr := make(chan custom_errors.JobError, 1)
var wg sync.WaitGroup
batch := models.Batch{
Id: uuid.New(),
Rows: []models.UnknownRowValues{{nil}},
}
chIn <- batch
close(chIn)
done := runConsume(context.Background(), tr, []models.ColumnType{col}, 0, chIn, chOut, chErr, &wg)
<-done
select {
case got := <-chOut:
wg.Done()
if got.Rows[0][0] != nil {
t.Error("nil value should pass through unchanged")
}
default:
t.Error("expected an output batch even when value is nil")
}
if len(chErr) != 0 {
t.Error("nil value should not produce an error")
}
wg.Wait()
}

View File

@@ -1,21 +0,0 @@
package transformers
import (
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/azure"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
)
type MssqlTransformer struct {
toStorage config.ToStorageConfig
sourceTable config.SourceTableInfo
azureClient *azure.Client
}
func NewMssqlTransformer(toStorage config.ToStorageConfig, sourceTable config.SourceTableInfo, azureClient *azure.Client) etl.Transformer {
return &MssqlTransformer{
toStorage: toStorage,
sourceTable: sourceTable,
azureClient: azureClient,
}
}

View File

@@ -1,181 +0,0 @@
package transformers
import (
"context"
"fmt"
"path"
"strings"
"time"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/azure"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
)
func computeTransformationPlan(columns []models.ColumnType) []etl.ColumnTransformPlan {
var plan []etl.ColumnTransformPlan
for i, col := range columns {
switch col.SystemType() {
case "uniqueidentifier":
plan = append(plan, etl.ColumnTransformPlan{
Index: i,
Fn: func(v any) (any, error) {
if b, ok := v.([]byte); ok && b != nil {
return mssqlUuidToBigEndian(b)
}
return v, nil
},
})
case "geometry", "geography":
plan = append(plan, etl.ColumnTransformPlan{
Index: i,
Fn: func(v any) (any, error) {
if b, ok := v.([]byte); ok && b != nil {
return wkbToEwkbWithSrid(b, 4326)
}
return v, nil
},
})
case "datetime", "datetime2":
plan = append(plan, etl.ColumnTransformPlan{
Index: i,
Fn: func(v any) (any, error) {
if t, ok := v.(time.Time); ok {
return ensureUTC(t), nil
}
return v, nil
},
})
}
}
return plan
}
func computePostgresTransformationPlan(columns []models.ColumnType) []etl.ColumnTransformPlan {
var plan []etl.ColumnTransformPlan
for i, col := range columns {
switch col.SystemType() {
case "int2", "int4", "int8", "integer", "smallint", "bigint":
plan = append(plan, etl.ColumnTransformPlan{
Index: i,
Fn: func(v any) (any, error) {
if v64, ok := ToInt64(v); ok {
return v64, nil
}
return v, nil
},
})
case "uuid":
plan = append(plan, etl.ColumnTransformPlan{
Index: i,
Fn: func(v any) (any, error) {
switch b := v.(type) {
case []byte:
if b != nil {
return bigEndianToMssqlUuid(b)
}
case [16]byte:
return bigEndianToMssqlUuid(b[:])
}
return v, nil
},
})
case "geometry":
plan = append(plan, etl.ColumnTransformPlan{
Index: i,
Fn: func(v any) (any, error) {
if b, ok := v.([]byte); ok && b != nil {
return ewkbToMssqlGeo(b, false)
}
return v, nil
},
})
case "geography":
plan = append(plan, etl.ColumnTransformPlan{
Index: i,
Fn: func(v any) (any, error) {
if b, ok := v.([]byte); ok && b != nil {
return ewkbToMssqlGeo(b, true)
}
return v, nil
},
})
}
}
return plan
}
func computeStorageTransformationPlan(
ctx context.Context,
azureClient *azure.Client,
toStorage config.ToStorageConfig,
sourceColumns []models.ColumnType,
sourceTable config.SourceTableInfo,
) []etl.ColumnTransformPlan {
if azureClient == nil || len(toStorage.Columns) == 0 {
return nil
}
colIndex := make(map[string]int, len(sourceColumns))
for i, col := range sourceColumns {
colIndex[strings.ToUpper(col.Name())] = i
}
var plan []etl.ColumnTransformPlan
for _, storageCol := range toStorage.Columns {
if storageCol.Mode != "REFERENCE_ONLY" {
logrus.Warnf("to_storage: unsupported mode %q for column %s — skipping", storageCol.Mode, storageCol.Source)
continue
}
idx, ok := colIndex[strings.ToUpper(storageCol.Source)]
if !ok {
logrus.Warnf("to_storage: source column %q not found in source schema — skipping", storageCol.Source)
continue
}
sourceColName := storageCol.Source
schema := sourceTable.Schema
table := sourceTable.Table
plan = append(plan, etl.ColumnTransformPlan{
Index: idx,
Fn: func(v any) (any, error) {
if v == nil {
return nil, nil
}
b, ok := v.([]byte)
if !ok {
logrus.Warnf("to_storage: expected []byte for %s.%s.%s, got %T — passing through", schema, table, sourceColName, v)
return v, nil
}
// start := time.Now()
blobPath := path.Join(storageCol.Prefix, uuid.New().String())
blobURL, err := azureClient.UploadAndGetURL(ctx, blobPath, b)
if err != nil {
return nil, &custom_errors.JobError{
Msg: fmt.Sprintf("Error uploading %s.%s.%s", schema, table, sourceColName),
Prev: err,
}
}
// logrus.Debugf(`Succesfully uploaded "%s", (%vms)`, blobURL, time.Since(start).Milliseconds())
return blobURL, nil
},
})
}
return plan
}

View File

@@ -1,72 +0,0 @@
package transformers
import (
"context"
"sync"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
)
type PostgresTransformer struct {
sourceTable config.SourceTableInfo
}
func NewPostgresTransformer(sourceTable config.SourceTableInfo) etl.Transformer {
return &PostgresTransformer{sourceTable: sourceTable}
}
func (pgTr *PostgresTransformer) Consume(
ctx context.Context,
columns []models.ColumnType,
retryConfig config.RetryConfig,
batchSize int,
chBatchesIn <-chan models.Batch,
chBatchesOut chan<- models.Batch,
chJobErrorsOut chan<- custom_errors.JobError,
wgActiveBatches *sync.WaitGroup,
) {
transformationPlan := computePostgresTransformationPlan(columns)
acc := &batchAccumulator{batchSize: batchSize}
for {
select {
case <-ctx.Done():
return
case batch, ok := <-chBatchesIn:
if !ok {
acc.flush(ctx, chBatchesOut, wgActiveBatches)
return
}
if len(transformationPlan) > 0 {
if err := ProcessBatchWithRetries(ctx, &batch, transformationPlan, retryConfig); err != nil {
sendTransformError(ctx, err, chJobErrorsOut)
return
}
}
if batchSize <= 0 {
wgActiveBatches.Add(1)
select {
case chBatchesOut <- batch:
case <-ctx.Done():
wgActiveBatches.Done()
return
}
continue
}
acc.add(batch)
if acc.ready() {
if !acc.flush(ctx, chBatchesOut, wgActiveBatches) {
return
}
}
}
}
}

View File

@@ -1,73 +0,0 @@
package transformers
import (
"context"
"errors"
"time"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
)
const processBatchCtxCheck = 4096
func ProcessBatchWithRetries(
ctx context.Context,
batch *models.Batch,
transformationPlan []etl.ColumnTransformPlan,
retryConfig config.RetryConfig,
) error {
for i, rowValues := range batch.Rows {
if i%processBatchCtxCheck == 0 {
if err := ctx.Err(); err != nil {
return err
}
}
for _, task := range transformationPlan {
val := rowValues[task.Index]
if val == nil {
continue
}
var lastErr error
success := false
for attempt := 0; attempt < retryConfig.Attempts; attempt++ {
transformed, err := task.Fn(val)
if err == nil {
rowValues[task.Index] = transformed
success = true
break
}
lastErr = err
if jobError, ok := errors.AsType[*custom_errors.JobError](err); ok {
if jobError.ShouldCancelJob {
return jobError
}
}
if attempt == retryConfig.Attempts-1 {
break
}
delay := custom_errors.ComputeBackoffDelay(
attempt,
retryConfig.BaseDelayMs,
retryConfig.MaxDelayMs,
retryConfig.MaxJitterMs,
)
time.Sleep(delay)
}
if !success {
return lastErr
}
}
}
return nil
}

View File

@@ -1,61 +0,0 @@
package etl
import (
"context"
"sync"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
)
type TransformerFunc func(any) (any, error)
type ColumnTransformPlan struct {
Index int
Fn TransformerFunc
}
type Transformer interface {
Consume(
ctx context.Context,
columns []models.ColumnType,
retryConfig config.RetryConfig,
batchSize int,
chBatchesIn <-chan models.Batch,
chBatchesOut chan<- models.Batch,
chJobErrorsOut chan<- custom_errors.JobError,
wgActiveBatches *sync.WaitGroup,
)
}
type MaxMinColumnResult struct {
Max int64
Min int64
}
type TableAnalyzer interface {
QueryColumnTypes(
ctx context.Context,
tableInfo config.TableInfo,
) ([]models.ColumnType, error)
EstimateTotalRows(
ctx context.Context,
tableInfo config.TableInfo,
) (int64, error)
QueryMaxMinFromColumn(
ctx context.Context,
tableInfo config.TableInfo,
columnName string,
) (MaxMinColumnResult, error)
CalculatePartitionRanges(
ctx context.Context,
tableInfo config.TableInfo,
partitionColumn string,
maxPartitions int64,
rangeConstraint config.RangeConfig,
) ([]models.Partition, error)
}

View File

@@ -1,46 +0,0 @@
package models
import (
"time"
"github.com/google/uuid"
)
type UnknownRowValues = []any
type BatchRef struct {
Id uuid.UUID
PartitionId uuid.UUID
}
type Batch struct {
Id uuid.UUID
ParentBatches []BatchRef
Rows []UnknownRowValues
RetryCounter int
}
type PartitionRange struct {
Min int64
Max int64
IsMinInclusive bool
IsMaxInclusive bool
}
type Partition struct {
Id uuid.UUID
ParentId uuid.UUID
Range PartitionRange
HasRange bool
RetryCounter int
}
type JobResult struct {
JobName string
StartTime time.Time
Duration time.Duration
RowsRead int64
RowsLoaded int64
RowsFailed int64
Error error
}

View File

@@ -1,137 +0,0 @@
# Plan: Bidirectional Transformation Support
## Goal
Make the transformation pipeline direction-aware. Currently hardcoded to MSSQL → PG; add support for PG → MSSQL by applying inverse transformations when `SourceDbType == "postgres"`.
Excluded: `to_storage` Azure blob upload (not reversible).
---
## Hardcoded wiring to fix
| File | Line | Change |
|---|---|---|
| `cmd/go_migrate/process.go` | 51 | Branch on `SourceDbType`: `"sqlserver"``NewMssqlTransformer`, `"postgres"``NewPostgresTransformer` |
| `cmd/go_migrate/main.go` | 166167 | Branch on source/target type for both `TableAnalyzer` selections |
---
## Transformations
### Forward (MSSQL → PG) — unchanged
| Column type | Function | File |
|---|---|---|
| `uniqueidentifier` | `mssqlUuidToBigEndian` | `utils.go:9` |
| `geometry`/`geography` | `wkbToEwkbWithSrid` | `utils.go:25` |
| `datetime`/`datetime2` | `ensureUTC` | `utils.go:57` |
### Inverse (PG → MSSQL) — new
| PG system type | Action |
|---|---|
| `uuid` | `bigEndianToMssqlUuid`: re-swap bytes [0-3], [4-5], [6-7] |
| `geometry` | `ewkbToMssqlGeo(v, false)`: strip SRID → WKB → `WkbToUdtGeo` |
| `geography` | `ewkbToMssqlGeo(v, true)`: strip SRID → WKB → `WkbToUdtGeo` |
| `timestamp`/`timestamptz` | no-op |
**Geometry note**: MSSQL rejects plain WKB via bulk protocol. Must use `mssqlclrgeo.WkbToUdtGeo(wkb, isGeography)` (already in go.mod). PG extractor already emits EWKB via `ST_AsEWKB()`.
---
## New utility functions (`transformers/utils.go`)
### `bigEndianToMssqlUuid(v []byte) []byte`
```
out[0..3] = v[3,2,1,0]
out[4..5] = v[5,4]
out[6..7] = v[7,6]
out[8..15] = v[8..15]
```
### `ewkbToMssqlGeo(ewkb []byte, isGeography bool) ([]byte, error)`
1. Read byte-order flag from `ewkb[0]`
2. Read geometry type word bytes [1..4]
3. If SRID flag (`0x20000000`) is set: strip bytes [5..8], clear flag in type word
4. Call `mssqlclrgeo.WkbToUdtGeo(wkb, isGeography)`
---
## New files
### `transformers/postgres.go`
```go
func NewPostgresTransformer(...) *Transformer {
// same signature as NewMssqlTransformer
// calls computePostgresTransformationPlan instead
// does NOT call computeStorageTransformationPlan
}
```
### `computePostgresTransformationPlan` in `transformers/plan.go`
Iterates `sourceColTypes` (from PG analyzer), applies inverse closures by system type.
---
## PostgreSQL table analyzer stubs to implement (`table_analyzers/postgres.go`)
Required for PG-as-source partitioned extraction:
### `EstimateTotalRows`
```sql
SELECT reltuples::bigint FROM pg_class
JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace
WHERE pg_namespace.nspname = $schema AND pg_class.relname = $table
```
Fallback to `COUNT(*)` if `reltuples < 0`.
### `QueryMaxMinFromColumn`
```sql
SELECT MIN("col"), MAX("col") FROM "schema"."table"
```
### `CalculatePartitionRanges`
Use min/max from above + `rowsPerPartition` to compute boundaries. Mirror the logic from `MssqlTableAnalyzer.CalculatePartitionRanges`.
---
## Test cases
### TC-1: `bigEndianToMssqlUuid` — round-trip
- Input: run `mssqlUuidToBigEndian` on a known 16-byte MSSQL UUID → produces PG UUID
- Assert: `bigEndianToMssqlUuid(pgUUID)` == original MSSQL UUID bytes
- Also assert nil input → nil output (no panic)
### TC-2: `bigEndianToMssqlUuid` — known vector
- Input: `[0x6b,0xa7,0xb8,0x10, 0x9d,0xad, 0x11,0xd1, 0x80,0xb4,0x00,0xc0,0x4f,0xd4,0x30,0xc8]` (RFC 4122 nil UUID variant)
- Assert: bytes [0-3] are reversed, [4-5] reversed, [6-7] reversed, [8-15] identical
### TC-3: `ewkbToMssqlGeo` — geometry round-trip
- Input: generate a polygon via `go-geom` + `wkb.Marshal` → plain WKB
- Forward: run `wkbToEwkbWithSrid` → EWKB
- Inverse: run `ewkbToMssqlGeo(ewkb, false)` → CLR/UDT bytes
- Assert: no error, output is non-empty `[]byte`
### TC-4: `ewkbToMssqlGeo` — nil input
- Input: nil
- Assert: returns nil, nil (no panic)
### TC-5: `ewkbToMssqlGeo` — EWKB without SRID flag
- Input: plain WKB (no SRID flag set)
- Assert: function still calls `WkbToUdtGeo` and returns without error
### TC-6: Transformer factory selection
- Given `SourceDbType == "postgres"``NewPostgresTransformer` is selected
- Given `SourceDbType == "sqlserver"``NewMssqlTransformer` is selected
---
## Files changed (summary)
1. `cmd/go_migrate/process.go` — transformer factory branch
2. `cmd/go_migrate/main.go` — analyzer selection branch
3. `internal/app/etl/transformers/utils.go` — 2 new functions
4. `internal/app/etl/transformers/plan.go``computePostgresTransformationPlan`
5. `internal/app/etl/transformers/postgres.go` *(new)*
6. `internal/app/etl/table_analyzers/postgres.go` — 3 stub implementations

View File

@@ -1,44 +0,0 @@
package main
import (
"context"
"fmt"
"log"
"math/rand"
"sync"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/azure"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
)
func main() {
cfg := config.App.AzureStorage
containerName := cfg.Container
client, err := azure.NewClient(cfg)
if err != nil {
log.Fatalf("Error creando cliente: %v", err)
}
ctx := context.Background()
var wg sync.WaitGroup
for i := 1; i <= 10; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
blobName := fmt.Sprintf("%sarchivo-%d.txt", cfg.Prefix, id)
content := fmt.Sprintf("Contenido aleatorio: %d", rand.Intn(100000))
err := client.UploadBuffer(ctx, containerName, blobName, []byte(content))
if err != nil {
log.Printf("Fallo al subir %s: %v", blobName, err)
} else {
fmt.Printf("Subido exitosamente: %s\n", blobName)
}
}(i)
}
wg.Wait()
}

View File

@@ -1,30 +1,131 @@
package main
import (
"flag"
"fmt"
"log"
"os"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
log "github.com/sirupsen/logrus"
"github.com/goccy/go-yaml"
)
// Estructuras para mapear el YAML
type RetryConfig struct {
Attempts int `yaml:"attempts"`
}
type DBInfo struct {
Schema string `yaml:"schema"`
Table string `yaml:"table"`
PrimaryKey string `yaml:"primary_key,omitempty"` // omitempty si no siempre existe
}
// JobSettings contiene los campos que se comparten entre 'defaults' y cada 'job'
type JobSettings struct {
MaxExtractors *int `yaml:"max_extractors"`
MaxLoaders *int `yaml:"max_loaders"`
QueueSize *int `yaml:"queue_size"`
ChunkSize *int `yaml:"chunk_size"`
ChunksPerBatch *int `yaml:"chunks_per_batch"`
TruncateTarget *bool `yaml:"truncate_target"`
TruncateMethod *string `yaml:"truncate_method"`
Retry *RetryConfig `yaml:"retry"`
}
type Job struct {
Name string `yaml:"name"`
Enabled bool `yaml:"enabled"`
Source DBInfo `yaml:"source"`
Target DBInfo `yaml:"target"`
PreSQL []string `yaml:"pre_sql"`
PostSQL []string `yaml:"post_sql"`
// Incrustamos los settings para permitir los overrides
JobSettings `yaml:",inline"`
}
type Config struct {
MaxParallelWorkers int `yaml:"max_parallel_workers"`
Defaults JobSettings `yaml:"defaults"`
Jobs []Job `yaml:"jobs"`
}
func main() {
log.SetLevel(log.DebugLevel)
configPath := flag.String("config", "", "path to migration config file")
flag.Parse()
if flag.NArg() > 1 {
log.Fatalf("only one config file path is allowed")
}
if *configPath == "" && flag.NArg() == 1 {
*configPath = flag.Arg(0)
}
migrationConfig, err := config.ReadMigrationConfig(*configPath)
yamlFile, err := os.ReadFile("config.yaml")
if err != nil {
log.Fatalf("error leyendo configuracion: %v", err)
log.Fatalf("Error leyendo archivo: %v", err)
}
log.Debugf("Config: %+v", migrationConfig)
var config Config
err = yaml.Unmarshal(yamlFile, &config)
if err != nil {
log.Fatalf("Error parseando YAML: %v", err)
}
fmt.Printf("Configuración cargada. Trabajos: %d\n", len(config.Jobs))
for i, job := range config.Jobs {
jobPtr := &config.Jobs[i]
if job.MaxExtractors == nil {
jobPtr.MaxExtractors = config.Defaults.MaxExtractors
}
if job.MaxLoaders == nil {
jobPtr.MaxLoaders = config.Defaults.MaxLoaders
}
if job.QueueSize == nil {
jobPtr.QueueSize = config.Defaults.QueueSize
}
if job.ChunkSize == nil {
jobPtr.ChunkSize = config.Defaults.ChunkSize
}
if job.ChunksPerBatch == nil {
jobPtr.ChunksPerBatch = config.Defaults.ChunksPerBatch
}
if job.TruncateTarget == nil {
jobPtr.TruncateTarget = config.Defaults.TruncateTarget
}
if job.TruncateMethod == nil {
jobPtr.TruncateMethod = config.Defaults.TruncateMethod
}
if job.Retry == nil {
jobPtr.Retry = config.Defaults.Retry
}
}
printConfig(config)
}
func printConfig(config Config) {
fmt.Println("Max parallel workers: ", config.MaxParallelWorkers)
fmt.Println("Defaults:")
fmt.Printf("\tMaxExtractors: %v\n", *config.Defaults.MaxExtractors)
fmt.Printf("\tMaxLoaders: %v\n", *config.Defaults.MaxLoaders)
fmt.Printf("\tQueueSize: %v\n", *config.Defaults.QueueSize)
fmt.Printf("\tChunkSize: %v\n", *config.Defaults.ChunkSize)
fmt.Printf("\tChunksPerBatch: %v\n", *config.Defaults.ChunksPerBatch)
fmt.Printf("\tTruncateTarget: %v\n", *config.Defaults.TruncateTarget)
fmt.Printf("\tTruncateMethod: %v\n", *config.Defaults.TruncateMethod)
fmt.Printf("\tRetry: %v\n", *config.Defaults.Retry)
fmt.Println("Jobs:")
for i, job := range config.Jobs {
fmt.Printf("Job Name: %v\n", job.Name)
fmt.Printf("\tEnabled: %v\n", job.Enabled)
fmt.Printf("\tSource: %v\n", job.Source)
fmt.Printf("\tTarget: %v\n", job.Target)
fmt.Printf("\tMaxExtractors: %v\n", *job.MaxExtractors)
fmt.Printf("\tMaxLoaders: %v\n", *job.MaxLoaders)
fmt.Printf("\tQueueSize: %v\n", *job.QueueSize)
fmt.Printf("\tChunkSize: %v\n", *job.ChunkSize)
fmt.Printf("\tChunksPerBatch: %v\n", *job.ChunksPerBatch)
fmt.Printf("\tTruncateTarget: %v\n", *job.TruncateTarget)
fmt.Printf("\tTruncateMethod: %v\n", *job.TruncateMethod)
fmt.Printf("\tRetry: %v\n", *job.Retry)
fmt.Printf("\tPreSQL: %v\n", job.PreSQL)
fmt.Printf("\tPostSQL: %v\n", job.PostSQL)
if i >= 2 {
fmt.Println("Skipping remaining jobs...")
}
}
}

View File

@@ -12,9 +12,9 @@ import (
)
const (
totalRows int = 2_000_000
chunkSize int = 5000
queueSize int = 8
totalRows int = 1_000_000
chunkSize int = 50_000
queueSize int = 4
)
func main() {
@@ -40,14 +40,6 @@ func main() {
seedManzanas(ctx, db)
})
// wgSeed.Go(func() {
// seedPuertos(ctx, db)
// })
// wgSeed.Go(func() {
// seedSiteHolderAttach(ctx, db)
// })
wgSeed.Wait()
}

View File

@@ -1,227 +0,0 @@
package main
import (
"bytes"
"context"
"database/sql"
"fmt"
"math/rand"
"sync"
"time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
)
var siteHolderAttachJob = MigrationJob{
Schema: "Infraestructura",
Table: "SITE_HOLDER__ATTACH",
}
func seedSiteHolderAttach(ctx context.Context, db *sql.DB) error {
maxOid, err := getMaxGDBArchiveOidForAttach(ctx, db)
if err != nil {
log.Fatal("Error getting max GDB_ARCHIVE_OID: ", err)
}
log.Infof("Starting SITE_HOLDER__ATTACH data generation from GDB_ARCHIVE_OID: %d", maxOid+1)
rowsChan := make(chan []UnknownRowValues, queueSize)
var wgRowGenerator sync.WaitGroup
wgRowGenerator.Go(func() {
generateSiteHolderAttachRows(ctx, maxOid, totalRows, chunkSize, rowsChan)
})
columns := []string{
"GDB_ARCHIVE_OID",
"REL_GLOBALID",
"CONTENT_TYPE",
"ATT_NAME",
"DATA_SIZE",
"DATA",
"GLOBALID",
"GDB_FROM_DATE",
"GDB_TO_DATE",
"ATTACHMENTID",
}
if err := loadRowsMssql(ctx, siteHolderAttachJob, columns, db, rowsChan); err != nil {
return fmt.Errorf("Error loading rows (SITE_HOLDER__ATTACH): %w", err)
}
log.Info("Data generation and loading completed successfully (SITE_HOLDER__ATTACH)")
wgRowGenerator.Wait()
return nil
}
func getMaxGDBArchiveOidForAttach(ctx context.Context, db *sql.DB) (int, error) {
var maxOid sql.NullInt64
query := fmt.Sprintf(`
SELECT ISNULL(MAX(GDB_ARCHIVE_OID), 0)
FROM [%s].[%s]
`, siteHolderAttachJob.Schema, siteHolderAttachJob.Table)
err := db.QueryRowContext(ctx, query).Scan(&maxOid)
if err != nil && err != sql.ErrNoRows {
return 0, err
}
if !maxOid.Valid {
return 0, nil
}
return int(maxOid.Int64), nil
}
func generateSiteHolderAttachRows(
ctx context.Context,
startOid int,
totalRows int,
chunkSize int,
out chan<- []UnknownRowValues,
) {
defer close(out)
rowsGenerated := 0
currentChunk := make([]UnknownRowValues, 0, chunkSize)
for i := range totalRows {
gdbArchiveOid := startOid + i + 1
row := generateSiteHolderAttachRow(gdbArchiveOid)
currentChunk = append(currentChunk, row)
rowsGenerated++
if len(currentChunk) == chunkSize {
select {
case out <- currentChunk:
log.Debugf("Sent SITE_HOLDER__ATTACH chunk with %d rows", len(currentChunk))
case <-ctx.Done():
log.Info("Context cancelled, stopping SITE_HOLDER__ATTACH row generation")
return
}
currentChunk = make([]UnknownRowValues, 0, chunkSize)
}
if rowsGenerated%100_000 == 0 {
logSiteHolderAttachSampleRow(rowsGenerated, row)
}
}
if len(currentChunk) > 0 {
select {
case out <- currentChunk:
log.Debugf("Sent final SITE_HOLDER__ATTACH chunk with %d rows", len(currentChunk))
case <-ctx.Done():
log.Info("Context cancelled, stopping SITE_HOLDER__ATTACH row generation")
}
}
log.Infof("Finished generating %d SITE_HOLDER__ATTACH rows", rowsGenerated)
}
func generateSiteHolderAttachRow(gdbArchiveOid int) UnknownRowValues {
dateLowerLimit, _ := time.Parse(time.RFC3339, "2020-12-31T23:59:59Z")
dateUpperLimit, _ := time.Parse(time.RFC3339, "2025-12-31T23:59:59Z")
relGlobalID, _ := uuid.New().MarshalBinary()
contentType := generateRandomContentType()
attName := generateRandomAttachmentName()
binaryData := generateRandomBinaryContent()
dataSize := len(binaryData)
globalID, _ := uuid.New().MarshalBinary()
gdbFromDate := generateRandomTimestamp(dateLowerLimit, dateUpperLimit)
gdbToDate, _ := time.Parse(time.RFC3339, "9999-12-31T23:59:59Z")
attachmentID := rand.Intn(10000) + 1
return UnknownRowValues{
gdbArchiveOid,
relGlobalID,
contentType,
attName,
dataSize,
binaryData,
globalID,
gdbFromDate,
gdbToDate,
attachmentID,
}
}
func generateRandomContentType() string {
contentTypes := []string{
"text/plain",
"application/pdf",
"image/jpeg",
"image/png",
"application/msword",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"text/csv",
"application/json",
}
return contentTypes[rand.Intn(len(contentTypes))]
}
func generateRandomAttachmentName() string {
extensions := []string{".txt", ".pdf", ".jpg", ".png", ".doc", ".docx", ".csv", ".json"}
baseName := generateRandomString(20)
extension := extensions[rand.Intn(len(extensions))]
return baseName + extension
}
func generateRandomBinaryContent() []byte {
sizeOptions := []int{100, 500, 1000, 5000, 10000, 50000, 100000}
size := sizeOptions[rand.Intn(len(sizeOptions))]
var buf bytes.Buffer
lineCount := rand.Intn(size/50) + 1
for range lineCount {
line := generateRandomString(rand.Intn(80) + 20)
buf.WriteString(line)
buf.WriteString("\n")
}
for buf.Len() < size {
randomText := generateRandomString(rand.Intn(100) + 50)
buf.WriteString(randomText)
buf.WriteString("\n")
}
result := buf.Bytes()
if len(result) > size {
result = result[:size]
}
return result
}
func logSiteHolderAttachSampleRow(id int, rowValues UnknownRowValues) {
dataBytes := rowValues[5].([]byte)
log.Infof(`
Sample SITE_HOLDER__ATTACH row #%d:
GDB_ARCHIVE_OID: %v
REL_GLOBALID: [binary UUID]
CONTENT_TYPE: %v
ATT_NAME: %v
DATA_SIZE: %v
DATA: [%d bytes of binary content]
GLOBALID: [binary UUID]
GDB_FROM_DATE: %v
GDB_TO_DATE: %v
ATTACHMENTID: %v
`,
id,
rowValues[0],
rowValues[2],
rowValues[3],
rowValues[4],
len(dataBytes),
rowValues[7],
rowValues[8],
rowValues[9],
)
}

View File

@@ -8,32 +8,13 @@ import (
"time"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
log "github.com/sirupsen/logrus"
)
func Connect(ctx context.Context, dbURL string) (*pgxpool.Pool, error) {
pool, err := pgxpool.New(ctx, dbURL)
if err != nil {
return nil, fmt.Errorf("unable to connect to database: %w", err)
}
if err := pool.Ping(ctx); err != nil {
pool.Close()
return nil, fmt.Errorf("unable to ping database: %w", err)
}
return pool, nil
}
func Close(pool *pgxpool.Pool) {
if pool != nil {
pool.Close()
}
}
func main() {
log.SetFormatter(&log.TextFormatter{
FullTimestamp: true,
@@ -46,8 +27,8 @@ func main() {
ctxSource, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
sourcePool, err := Connect(ctxSource, config.App.SourceDbUrl)
defer Close(sourcePool)
sourcePool, err := db.Connect(ctxSource, config.App.SourceDbUrl)
defer db.Close(sourcePool)
if err != nil {
log.Fatal(err)
}
@@ -56,8 +37,8 @@ func main() {
ctxTarget, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
targetPool, err := Connect(ctxTarget, config.App.TargetDbUrl)
defer Close(targetPool)
targetPool, err := db.Connect(ctxTarget, config.App.TargetDbUrl)
defer db.Close(targetPool)
if err != nil {
log.Fatal(err)
}