feat: refactor column type handling and streamline database column queries

This commit is contained in:
2026-04-04 13:30:07 -05:00
parent de0d4a5516
commit 828fc57121
3 changed files with 299 additions and 163 deletions

View File

@@ -3,12 +3,12 @@ package main
type ColumnType struct { type ColumnType struct {
name string name string
hasNullable bool
hasMaxLength bool hasMaxLength bool
hasPrecisionScale bool hasPrecisionScale bool
userType string userType string
systemType string systemType string
unifiedType string
nullable bool nullable bool
maxLength int64 maxLength int64
precision int64 precision int64
@@ -35,6 +35,10 @@ func (c *ColumnType) DecimalSize() (precision, scale int64, ok bool) {
return c.precision, c.scale, c.hasPrecisionScale return c.precision, c.scale, c.hasPrecisionScale
} }
func (c *ColumnType) Nullable() (nullable, ok bool) { func (c *ColumnType) Nullable() bool {
return c.nullable, c.hasNullable return c.nullable
}
func (c *ColumnType) Type() string {
return c.unifiedType
} }

View File

@@ -0,0 +1,279 @@
package main
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"sync"
"time"
"github.com/jackc/pgx/v5/pgxpool"
_ "github.com/microsoft/go-mssqldb"
log "github.com/sirupsen/logrus"
)
func GetUnifiedType(systemType string) string {
systemType = strings.ToLower(systemType)
if systemType == "varchar" || systemType == "char" || systemType == "nvarchar" || systemType == "nchar" || systemType == "text" || systemType == "ntext" {
return "STRING"
}
if systemType == "int" || systemType == "int4" || systemType == "integer" || systemType == "smallint" || systemType == "int2" || systemType == "bigint" || systemType == "int8" || systemType == "tinyint" {
return "INTEGER"
}
if systemType == "decimal" || systemType == "numeric" {
return "DECIMAL"
}
if systemType == "float" || systemType == "real" || systemType == "double precision" {
return "FLOAT"
}
if systemType == "bit" || systemType == "boolean" {
return "BOOLEAN"
}
if systemType == "date" {
return "DATE"
}
if systemType == "time" || systemType == "time without time zone" {
return "TIME"
}
if systemType == "datetime" || systemType == "datetime2" || systemType == "timestamp" || systemType == "timestamptz" || systemType == "timestamp with time zone" {
return "TIMESTAMP"
}
if systemType == "binary" || systemType == "varbinary" || systemType == "image" || systemType == "bytea" {
return "BINARY"
}
if systemType == "uniqueidentifier" || systemType == "uuid" {
return "UUID"
}
if systemType == "json" {
return "JSON"
}
if systemType == "geometry" || systemType == "geography" {
return "GEOMETRY"
}
return strings.ToUpper(systemType)
}
func MapPostgresColumn(column ColumnType, maxLength *int64, precision *int64, scale *int64) ColumnType {
stringTypes := map[string]bool{
"varchar": true, "char": true, "character": true, "text": true, "character varying": true,
}
decimalTypes := map[string]bool{
"decimal": true, "numeric": true,
}
if stringTypes[column.systemType] {
if maxLength != nil {
column.maxLength = *maxLength
column.hasMaxLength = true
} else {
column.maxLength = -1
column.hasMaxLength = false
}
column.hasPrecisionScale = false
column.precision = -1
column.scale = -1
} else if decimalTypes[column.systemType] {
column.hasMaxLength = false
column.maxLength = -1
if precision != nil && scale != nil {
column.precision = *precision
column.scale = *scale
column.hasPrecisionScale = true
} else {
column.precision = -1
column.scale = -1
column.hasPrecisionScale = false
}
} else {
column.hasMaxLength = false
column.maxLength = -1
column.hasPrecisionScale = false
column.precision = -1
column.scale = -1
}
column.unifiedType = GetUnifiedType(column.systemType)
return column
}
func GetColumnTypesPostgres(db *pgxpool.Pool, migrationJob MigrationJob) ([]ColumnType, error) {
query := `
SELECT
c.column_name AS name,
c.data_type AS user_type,
c.udt_name AS system_type,
(CASE WHEN c.is_nullable = 'YES' THEN TRUE ELSE FALSE END) AS nullable,
c.character_maximum_length AS max_length,
c.numeric_precision AS precision,
c.numeric_scale AS scale
FROM information_schema.columns c
WHERE c.table_schema = $1 AND c.table_name = $2
ORDER BY c.ordinal_position;
`
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
rows, err := db.Query(ctx, query, migrationJob.Schema, migrationJob.Table)
if err != nil {
return nil, fmt.Errorf("Error querying column types: %w", err)
}
defer rows.Close()
var colTypes []ColumnType
for rows.Next() {
var column ColumnType
var scanMaxLength *int64
var scanPrecision *int64
var scanScale *int64
if err := rows.Scan(
&column.name,
&column.userType,
&column.systemType,
&column.nullable,
&scanMaxLength,
&scanPrecision,
&scanScale,
); err != nil {
return nil, fmt.Errorf("Error scanning column type results: %w", err)
}
colTypes = append(colTypes, MapPostgresColumn(column, scanMaxLength, scanPrecision, scanScale))
}
return colTypes, nil
}
func MapMssqlColumn(column ColumnType) ColumnType {
stringTypes := map[string]bool{
"varchar": true, "char": true, "nvarchar": true, "nchar": true, "text": true, "ntext": true,
}
decimalTypes := map[string]bool{
"decimal": true, "numeric": true,
}
if stringTypes[column.systemType] {
column.hasMaxLength = true
if column.systemType == "nvarchar" || column.systemType == "nchar" {
if column.maxLength > 0 {
column.maxLength = column.maxLength / 2
}
}
column.hasPrecisionScale = false
column.precision = -1
column.scale = -1
} else if decimalTypes[column.systemType] {
column.hasMaxLength = false
column.maxLength = -1
column.hasPrecisionScale = true
} else {
column.hasMaxLength = false
column.maxLength = -1
column.hasPrecisionScale = false
column.precision = -1
column.scale = -1
}
column.unifiedType = GetUnifiedType(column.systemType)
return column
}
func GetColumnTypesMssql(db *sql.DB, migrationJob MigrationJob) ([]ColumnType, error) {
query := `
SELECT
c.name AS name,
t.name AS user_type,
CASE WHEN t.is_user_defined = 0 THEN t.name ELSE bt.name END AS system_type,
c.is_nullable AS nullable,
c.max_length AS max_length,
c.precision AS precision,
c.scale AS scale
FROM sys.columns c
JOIN sys.types t ON c.user_type_id = t.user_type_id
LEFT JOIN sys.types bt ON t.is_user_defined = 1 AND bt.user_type_id = t.system_type_id
JOIN sys.tables st ON c.object_id = st.object_id
JOIN sys.schemas s ON st.schema_id = s.schema_id
WHERE s.name = @schema AND st.name = @table
ORDER BY c.column_id;
`
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
rows, err := db.QueryContext(ctx, query, sql.Named("schema", migrationJob.Schema), sql.Named("table", migrationJob.Table))
if err != nil {
return nil, fmt.Errorf("Error querying column types: %w", err)
}
defer rows.Close()
var colTypes []ColumnType
for rows.Next() {
var column ColumnType
if err := rows.Scan(
&column.name,
&column.userType,
&column.systemType,
&column.nullable,
&column.maxLength,
&column.precision,
&column.scale,
); err != nil {
return nil, fmt.Errorf("Error scanning column type results: %W", err)
}
colTypes = append(colTypes, MapMssqlColumn(column))
}
return colTypes, nil
}
func GetColumnTypes(sourceDb *sql.DB, targetDb *pgxpool.Pool, migrationJob MigrationJob) ([]ColumnType, []ColumnType, error) {
var sourceDbErr error
var targetDbErr error
var sourceColTypes []ColumnType
var targetColTypes []ColumnType
var wg sync.WaitGroup
wg.Go(func() {
sourceColTypes, sourceDbErr = GetColumnTypesMssql(sourceDb, migrationJob)
if sourceDbErr != nil {
log.Error("Error (sourceDb): ", sourceDbErr)
}
})
wg.Go(func() {
targetColTypes, targetDbErr = GetColumnTypesPostgres(targetDb, migrationJob)
if targetDbErr != nil {
log.Error("Error (targetDb): ", targetDbErr)
}
})
wg.Wait()
if sourceDbErr != nil || targetDbErr != nil {
return nil, nil, errors.New("Error querying column types")
}
return sourceColTypes, targetColTypes, nil
}

View File

@@ -1,15 +1,6 @@
package main package main
import ( import (
"context"
"database/sql"
"errors"
"fmt"
"sync"
"time"
"github.com/jackc/pgx/v5/pgxpool"
_ "github.com/microsoft/go-mssqldb" _ "github.com/microsoft/go-mssqldb"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@@ -40,160 +31,22 @@ func main() {
defer targetDb.Close() defer targetDb.Close()
for _, job := range migrationJobs { for _, job := range migrationJobs {
sourceColTypes, targetColTypes, err := queryColumnTypes(sourceDb, targetDb, job) sourceColTypes, targetColTypes, err := GetColumnTypes(sourceDb, targetDb, job)
if err != nil { if err != nil {
log.Fatal("Unexpected error: ", err) log.Fatal("Unexpected error: ", err)
} }
log.Debugf("Source col types: %+v", sourceColTypes) logColumnTypes(sourceColTypes, "Source col types")
log.Debugf("Target col types: %+v", targetColTypes) logColumnTypes(targetColTypes, "Target col types")
} }
log.Info("Migration completed successfully!") log.Info("Migration completed successfully!")
} }
func querySourceColTypes(db *sql.DB, migrationJob MigrationJob) ([]ColumnType, error) { func logColumnTypes(columnTypes []ColumnType, label string) {
query := ` log.Info(label)
SELECT
c.name AS name,
t.name AS user_type,
CASE WHEN t.is_user_defined = 0 THEN t.name ELSE bt.name END AS system_type,
c.is_nullable AS nullable,
c.max_length AS max_length,
c.precision AS precision,
c.scale AS scale
FROM sys.columns c
JOIN sys.types t ON c.user_type_id = t.user_type_id
LEFT JOIN sys.types bt ON t.is_user_defined = 1 AND bt.user_type_id = t.system_type_id
JOIN sys.tables st ON c.object_id = st.object_id
JOIN sys.schemas s ON st.schema_id = s.schema_id
WHERE s.name = @schema AND st.name = @table
ORDER BY c.column_id;
`
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) for _, col := range columnTypes {
defer cancel() log.Infof("%+v", col)
rows, err := db.QueryContext(ctx, query, sql.Named("schema", migrationJob.Schema), sql.Named("table", migrationJob.Table))
if err != nil {
return nil, fmt.Errorf("Error querying column types: %w", err)
} }
defer rows.Close()
var colTypes []ColumnType
for rows.Next() {
var column ColumnType
if err := rows.Scan(
&column.name,
&column.userType,
&column.systemType,
&column.nullable,
&column.maxLength,
&column.precision,
&column.scale,
); err != nil {
return nil, fmt.Errorf("Error scanning column type results: %W", err)
}
colTypes = append(colTypes, column)
}
return colTypes, nil
}
func queryTargetColTypes(db *pgxpool.Pool, migrationJob MigrationJob) ([]ColumnType, error) {
query := `
SELECT
c.column_name AS name,
c.data_type AS user_type,
c.udt_name AS system_type,
(CASE WHEN c.is_nullable = 'YES' THEN TRUE ELSE FALSE END) AS nullable,
c.character_maximum_length AS max_length,
c.numeric_precision AS precision,
c.numeric_scale AS scale
FROM information_schema.columns c
WHERE c.table_schema = $1 AND c.table_name = $2
ORDER BY c.ordinal_position;
`
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
rows, err := db.Query(ctx, query, migrationJob.Schema, migrationJob.Table)
if err != nil {
return nil, fmt.Errorf("Error querying column types: %w", err)
}
defer rows.Close()
var colTypes []ColumnType
for rows.Next() {
var column ColumnType
var scanMaxLength *int64
var scanPrecision *int64
var scanScale *int64
if err := rows.Scan(
&column.name,
&column.userType,
&column.systemType,
&column.nullable,
&scanMaxLength,
&scanPrecision,
&scanScale,
); err != nil {
return nil, fmt.Errorf("Error scanning column type results: %w", err)
}
if scanMaxLength != nil {
column.maxLength = *scanMaxLength
column.hasMaxLength = true
} else {
column.maxLength = -1
}
if column.systemType == "decimal" {
if scanPrecision != nil && scanScale != nil {
column.precision = *scanPrecision
column.scale = *scanScale
column.hasPrecisionScale = true
}
}
colTypes = append(colTypes, column)
}
return colTypes, nil
}
func queryColumnTypes(sourceDb *sql.DB, targetDb *pgxpool.Pool, migrationJob MigrationJob) ([]ColumnType, []ColumnType, error) {
var sourceDbErr error
var targetDbErr error
var sourceColTypes []ColumnType
var targetColTypes []ColumnType
var wg sync.WaitGroup
wg.Go(func() {
sourceColTypes, sourceDbErr = querySourceColTypes(sourceDb, migrationJob)
if sourceDbErr != nil {
log.Error("Error (sourceDb): ", sourceDbErr)
}
})
wg.Go(func() {
targetColTypes, targetDbErr = queryTargetColTypes(targetDb, migrationJob)
if targetDbErr != nil {
log.Error("Error (targetDb): ", targetDbErr)
}
})
wg.Wait()
if sourceDbErr != nil || targetDbErr != nil {
return nil, nil, errors.New("Error querying column types")
}
return sourceColTypes, targetColTypes, nil
} }