feat: enhance column type querying functions for improved error handling and SQL compatibility

This commit is contained in:
2026-04-04 10:17:20 -05:00
parent 46c08323ad
commit de0d4a5516
2 changed files with 104 additions and 29 deletions

View File

@@ -3,6 +3,7 @@ package main
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"sync" "sync"
"time" "time"
@@ -69,7 +70,7 @@ func connectToDatabases() (*sql.DB, *pgxpool.Pool, error) {
wg.Wait() wg.Wait()
if sourceDbErr != nil || targetDbErr != nil { if sourceDbErr != nil || targetDbErr != nil {
return nil, nil, fmt.Errorf("Unable to connect to databases: %w (source), %w (target)", sourceDbErr, targetDbErr) return nil, nil, errors.New("Unable to connect to databases")
} }
return sourceDb, targetDb, nil return sourceDb, targetDb, nil

View File

@@ -3,12 +3,12 @@ package main
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"sync" "sync"
"time" "time"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
_ "github.com/microsoft/go-mssqldb" _ "github.com/microsoft/go-mssqldb"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -52,53 +52,127 @@ func main() {
log.Info("Migration completed successfully!") log.Info("Migration completed successfully!")
} }
func querySourceColTypes(db *sql.DB, migrationJob MigrationJob) ([]sql.ColumnType, error) { func querySourceColTypes(db *sql.DB, migrationJob MigrationJob) ([]ColumnType, error) {
query := fmt.Sprintf(`SELECT * FROM [%s].[%s] WHERE 0 = 1`, migrationJob.Schema, migrationJob.Table) query := `
SELECT
c.name AS name,
t.name AS user_type,
CASE WHEN t.is_user_defined = 0 THEN t.name ELSE bt.name END AS system_type,
c.is_nullable AS nullable,
c.max_length AS max_length,
c.precision AS precision,
c.scale AS scale
FROM sys.columns c
JOIN sys.types t ON c.user_type_id = t.user_type_id
LEFT JOIN sys.types bt ON t.is_user_defined = 1 AND bt.user_type_id = t.system_type_id
JOIN sys.tables st ON c.object_id = st.object_id
JOIN sys.schemas s ON st.schema_id = s.schema_id
WHERE s.name = @schema AND st.name = @table
ORDER BY c.column_id;
`
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel() defer cancel()
rows, err := db.QueryContext(ctx, query) rows, err := db.QueryContext(ctx, query, sql.Named("schema", migrationJob.Schema), sql.Named("table", migrationJob.Table))
if err != nil { if err != nil {
return nil, fmt.Errorf("Error querying column types: %w", err) return nil, fmt.Errorf("Error querying column types: %w", err)
} }
defer rows.Close() defer rows.Close()
colTypesPointers, err := rows.ColumnTypes() var colTypes []ColumnType
if err != nil {
return nil, err for rows.Next() {
var column ColumnType
if err := rows.Scan(
&column.name,
&column.userType,
&column.systemType,
&column.nullable,
&column.maxLength,
&column.precision,
&column.scale,
); err != nil {
return nil, fmt.Errorf("Error scanning column type results: %W", err)
} }
colTypes := make([]sql.ColumnType, 0, len(colTypesPointers)) colTypes = append(colTypes, column)
for _, c := range colTypesPointers {
colTypes = append(colTypes, *c)
} }
return colTypes, nil return colTypes, nil
} }
func queryTargetColTypes(db *pgxpool.Pool, migrationJob MigrationJob) ([]pgconn.FieldDescription, error) { func queryTargetColTypes(db *pgxpool.Pool, migrationJob MigrationJob) ([]ColumnType, error) {
query := fmt.Sprintf(`SELECT * FROM "%s"."%s" WHERE 0 = 1`, migrationJob.Schema, migrationJob.Table) query := `
SELECT
c.column_name AS name,
c.data_type AS user_type,
c.udt_name AS system_type,
(CASE WHEN c.is_nullable = 'YES' THEN TRUE ELSE FALSE END) AS nullable,
c.character_maximum_length AS max_length,
c.numeric_precision AS precision,
c.numeric_scale AS scale
FROM information_schema.columns c
WHERE c.table_schema = $1 AND c.table_name = $2
ORDER BY c.ordinal_position;
`
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel() defer cancel()
rows, err := db.Query(ctx, query) rows, err := db.Query(ctx, query, migrationJob.Schema, migrationJob.Table)
if err != nil { if err != nil {
return nil, fmt.Errorf("Error querying column types: %w", err) return nil, fmt.Errorf("Error querying column types: %w", err)
} }
defer rows.Close() defer rows.Close()
colTypes := rows.FieldDescriptions() var colTypes []ColumnType
for rows.Next() {
var column ColumnType
var scanMaxLength *int64
var scanPrecision *int64
var scanScale *int64
if err := rows.Scan(
&column.name,
&column.userType,
&column.systemType,
&column.nullable,
&scanMaxLength,
&scanPrecision,
&scanScale,
); err != nil {
return nil, fmt.Errorf("Error scanning column type results: %w", err)
}
if scanMaxLength != nil {
column.maxLength = *scanMaxLength
column.hasMaxLength = true
} else {
column.maxLength = -1
}
if column.systemType == "decimal" {
if scanPrecision != nil && scanScale != nil {
column.precision = *scanPrecision
column.scale = *scanScale
column.hasPrecisionScale = true
}
}
colTypes = append(colTypes, column)
}
return colTypes, nil return colTypes, nil
} }
func queryColumnTypes(sourceDb *sql.DB, targetDb *pgxpool.Pool, migrationJob MigrationJob) ([]sql.ColumnType, []pgconn.FieldDescription, error) { func queryColumnTypes(sourceDb *sql.DB, targetDb *pgxpool.Pool, migrationJob MigrationJob) ([]ColumnType, []ColumnType, error) {
var sourceDbErr error var sourceDbErr error
var targetDbErr error var targetDbErr error
var sourceColTypes []sql.ColumnType var sourceColTypes []ColumnType
var targetColTypes []pgconn.FieldDescription var targetColTypes []ColumnType
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Go(func() { wg.Go(func() {
@@ -118,7 +192,7 @@ func queryColumnTypes(sourceDb *sql.DB, targetDb *pgxpool.Pool, migrationJob Mig
wg.Wait() wg.Wait()
if sourceDbErr != nil || targetDbErr != nil { if sourceDbErr != nil || targetDbErr != nil {
return nil, nil, fmt.Errorf("Unable to connect to databases: %w (source), %w (target)", sourceDbErr, targetDbErr) return nil, nil, errors.New("Error querying column types")
} }
return sourceColTypes, targetColTypes, nil return sourceColTypes, targetColTypes, nil