feat: refactor column type handling and streamline database column queries

This commit is contained in:
2026-04-04 13:30:07 -05:00
parent de0d4a5516
commit 828fc57121
3 changed files with 299 additions and 163 deletions

View File

@@ -3,16 +3,16 @@ package main
type ColumnType struct {
name string
hasNullable bool
hasMaxLength bool
hasPrecisionScale bool
userType string
systemType string
nullable bool
maxLength int64
precision int64
scale int64
userType string
systemType string
unifiedType string
nullable bool
maxLength int64
precision int64
scale int64
}
func (c *ColumnType) Name() string {
@@ -35,6 +35,10 @@ func (c *ColumnType) DecimalSize() (precision, scale int64, ok bool) {
return c.precision, c.scale, c.hasPrecisionScale
}
func (c *ColumnType) Nullable() (nullable, ok bool) {
return c.nullable, c.hasNullable
func (c *ColumnType) Nullable() bool {
return c.nullable
}
func (c *ColumnType) Type() string {
return c.unifiedType
}

View File

@@ -0,0 +1,279 @@
package main
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"sync"
"time"
"github.com/jackc/pgx/v5/pgxpool"
_ "github.com/microsoft/go-mssqldb"
log "github.com/sirupsen/logrus"
)
func GetUnifiedType(systemType string) string {
systemType = strings.ToLower(systemType)
if systemType == "varchar" || systemType == "char" || systemType == "nvarchar" || systemType == "nchar" || systemType == "text" || systemType == "ntext" {
return "STRING"
}
if systemType == "int" || systemType == "int4" || systemType == "integer" || systemType == "smallint" || systemType == "int2" || systemType == "bigint" || systemType == "int8" || systemType == "tinyint" {
return "INTEGER"
}
if systemType == "decimal" || systemType == "numeric" {
return "DECIMAL"
}
if systemType == "float" || systemType == "real" || systemType == "double precision" {
return "FLOAT"
}
if systemType == "bit" || systemType == "boolean" {
return "BOOLEAN"
}
if systemType == "date" {
return "DATE"
}
if systemType == "time" || systemType == "time without time zone" {
return "TIME"
}
if systemType == "datetime" || systemType == "datetime2" || systemType == "timestamp" || systemType == "timestamptz" || systemType == "timestamp with time zone" {
return "TIMESTAMP"
}
if systemType == "binary" || systemType == "varbinary" || systemType == "image" || systemType == "bytea" {
return "BINARY"
}
if systemType == "uniqueidentifier" || systemType == "uuid" {
return "UUID"
}
if systemType == "json" {
return "JSON"
}
if systemType == "geometry" || systemType == "geography" {
return "GEOMETRY"
}
return strings.ToUpper(systemType)
}
func MapPostgresColumn(column ColumnType, maxLength *int64, precision *int64, scale *int64) ColumnType {
stringTypes := map[string]bool{
"varchar": true, "char": true, "character": true, "text": true, "character varying": true,
}
decimalTypes := map[string]bool{
"decimal": true, "numeric": true,
}
if stringTypes[column.systemType] {
if maxLength != nil {
column.maxLength = *maxLength
column.hasMaxLength = true
} else {
column.maxLength = -1
column.hasMaxLength = false
}
column.hasPrecisionScale = false
column.precision = -1
column.scale = -1
} else if decimalTypes[column.systemType] {
column.hasMaxLength = false
column.maxLength = -1
if precision != nil && scale != nil {
column.precision = *precision
column.scale = *scale
column.hasPrecisionScale = true
} else {
column.precision = -1
column.scale = -1
column.hasPrecisionScale = false
}
} else {
column.hasMaxLength = false
column.maxLength = -1
column.hasPrecisionScale = false
column.precision = -1
column.scale = -1
}
column.unifiedType = GetUnifiedType(column.systemType)
return column
}
func GetColumnTypesPostgres(db *pgxpool.Pool, migrationJob MigrationJob) ([]ColumnType, error) {
query := `
SELECT
c.column_name AS name,
c.data_type AS user_type,
c.udt_name AS system_type,
(CASE WHEN c.is_nullable = 'YES' THEN TRUE ELSE FALSE END) AS nullable,
c.character_maximum_length AS max_length,
c.numeric_precision AS precision,
c.numeric_scale AS scale
FROM information_schema.columns c
WHERE c.table_schema = $1 AND c.table_name = $2
ORDER BY c.ordinal_position;
`
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
rows, err := db.Query(ctx, query, migrationJob.Schema, migrationJob.Table)
if err != nil {
return nil, fmt.Errorf("Error querying column types: %w", err)
}
defer rows.Close()
var colTypes []ColumnType
for rows.Next() {
var column ColumnType
var scanMaxLength *int64
var scanPrecision *int64
var scanScale *int64
if err := rows.Scan(
&column.name,
&column.userType,
&column.systemType,
&column.nullable,
&scanMaxLength,
&scanPrecision,
&scanScale,
); err != nil {
return nil, fmt.Errorf("Error scanning column type results: %w", err)
}
colTypes = append(colTypes, MapPostgresColumn(column, scanMaxLength, scanPrecision, scanScale))
}
return colTypes, nil
}
func MapMssqlColumn(column ColumnType) ColumnType {
stringTypes := map[string]bool{
"varchar": true, "char": true, "nvarchar": true, "nchar": true, "text": true, "ntext": true,
}
decimalTypes := map[string]bool{
"decimal": true, "numeric": true,
}
if stringTypes[column.systemType] {
column.hasMaxLength = true
if column.systemType == "nvarchar" || column.systemType == "nchar" {
if column.maxLength > 0 {
column.maxLength = column.maxLength / 2
}
}
column.hasPrecisionScale = false
column.precision = -1
column.scale = -1
} else if decimalTypes[column.systemType] {
column.hasMaxLength = false
column.maxLength = -1
column.hasPrecisionScale = true
} else {
column.hasMaxLength = false
column.maxLength = -1
column.hasPrecisionScale = false
column.precision = -1
column.scale = -1
}
column.unifiedType = GetUnifiedType(column.systemType)
return column
}
func GetColumnTypesMssql(db *sql.DB, migrationJob MigrationJob) ([]ColumnType, error) {
query := `
SELECT
c.name AS name,
t.name AS user_type,
CASE WHEN t.is_user_defined = 0 THEN t.name ELSE bt.name END AS system_type,
c.is_nullable AS nullable,
c.max_length AS max_length,
c.precision AS precision,
c.scale AS scale
FROM sys.columns c
JOIN sys.types t ON c.user_type_id = t.user_type_id
LEFT JOIN sys.types bt ON t.is_user_defined = 1 AND bt.user_type_id = t.system_type_id
JOIN sys.tables st ON c.object_id = st.object_id
JOIN sys.schemas s ON st.schema_id = s.schema_id
WHERE s.name = @schema AND st.name = @table
ORDER BY c.column_id;
`
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
rows, err := db.QueryContext(ctx, query, sql.Named("schema", migrationJob.Schema), sql.Named("table", migrationJob.Table))
if err != nil {
return nil, fmt.Errorf("Error querying column types: %w", err)
}
defer rows.Close()
var colTypes []ColumnType
for rows.Next() {
var column ColumnType
if err := rows.Scan(
&column.name,
&column.userType,
&column.systemType,
&column.nullable,
&column.maxLength,
&column.precision,
&column.scale,
); err != nil {
return nil, fmt.Errorf("Error scanning column type results: %W", err)
}
colTypes = append(colTypes, MapMssqlColumn(column))
}
return colTypes, nil
}
func GetColumnTypes(sourceDb *sql.DB, targetDb *pgxpool.Pool, migrationJob MigrationJob) ([]ColumnType, []ColumnType, error) {
var sourceDbErr error
var targetDbErr error
var sourceColTypes []ColumnType
var targetColTypes []ColumnType
var wg sync.WaitGroup
wg.Go(func() {
sourceColTypes, sourceDbErr = GetColumnTypesMssql(sourceDb, migrationJob)
if sourceDbErr != nil {
log.Error("Error (sourceDb): ", sourceDbErr)
}
})
wg.Go(func() {
targetColTypes, targetDbErr = GetColumnTypesPostgres(targetDb, migrationJob)
if targetDbErr != nil {
log.Error("Error (targetDb): ", targetDbErr)
}
})
wg.Wait()
if sourceDbErr != nil || targetDbErr != nil {
return nil, nil, errors.New("Error querying column types")
}
return sourceColTypes, targetColTypes, nil
}

View File

@@ -1,15 +1,6 @@
package main
import (
"context"
"database/sql"
"errors"
"fmt"
"sync"
"time"
"github.com/jackc/pgx/v5/pgxpool"
_ "github.com/microsoft/go-mssqldb"
log "github.com/sirupsen/logrus"
)
@@ -40,160 +31,22 @@ func main() {
defer targetDb.Close()
for _, job := range migrationJobs {
sourceColTypes, targetColTypes, err := queryColumnTypes(sourceDb, targetDb, job)
sourceColTypes, targetColTypes, err := GetColumnTypes(sourceDb, targetDb, job)
if err != nil {
log.Fatal("Unexpected error: ", err)
}
log.Debugf("Source col types: %+v", sourceColTypes)
log.Debugf("Target col types: %+v", targetColTypes)
logColumnTypes(sourceColTypes, "Source col types")
logColumnTypes(targetColTypes, "Target col types")
}
log.Info("Migration completed successfully!")
}
func querySourceColTypes(db *sql.DB, migrationJob MigrationJob) ([]ColumnType, error) {
query := `
SELECT
c.name AS name,
t.name AS user_type,
CASE WHEN t.is_user_defined = 0 THEN t.name ELSE bt.name END AS system_type,
c.is_nullable AS nullable,
c.max_length AS max_length,
c.precision AS precision,
c.scale AS scale
FROM sys.columns c
JOIN sys.types t ON c.user_type_id = t.user_type_id
LEFT JOIN sys.types bt ON t.is_user_defined = 1 AND bt.user_type_id = t.system_type_id
JOIN sys.tables st ON c.object_id = st.object_id
JOIN sys.schemas s ON st.schema_id = s.schema_id
WHERE s.name = @schema AND st.name = @table
ORDER BY c.column_id;
`
func logColumnTypes(columnTypes []ColumnType, label string) {
log.Info(label)
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
rows, err := db.QueryContext(ctx, query, sql.Named("schema", migrationJob.Schema), sql.Named("table", migrationJob.Table))
if err != nil {
return nil, fmt.Errorf("Error querying column types: %w", err)
for _, col := range columnTypes {
log.Infof("%+v", col)
}
defer rows.Close()
var colTypes []ColumnType
for rows.Next() {
var column ColumnType
if err := rows.Scan(
&column.name,
&column.userType,
&column.systemType,
&column.nullable,
&column.maxLength,
&column.precision,
&column.scale,
); err != nil {
return nil, fmt.Errorf("Error scanning column type results: %W", err)
}
colTypes = append(colTypes, column)
}
return colTypes, nil
}
func queryTargetColTypes(db *pgxpool.Pool, migrationJob MigrationJob) ([]ColumnType, error) {
query := `
SELECT
c.column_name AS name,
c.data_type AS user_type,
c.udt_name AS system_type,
(CASE WHEN c.is_nullable = 'YES' THEN TRUE ELSE FALSE END) AS nullable,
c.character_maximum_length AS max_length,
c.numeric_precision AS precision,
c.numeric_scale AS scale
FROM information_schema.columns c
WHERE c.table_schema = $1 AND c.table_name = $2
ORDER BY c.ordinal_position;
`
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
rows, err := db.Query(ctx, query, migrationJob.Schema, migrationJob.Table)
if err != nil {
return nil, fmt.Errorf("Error querying column types: %w", err)
}
defer rows.Close()
var colTypes []ColumnType
for rows.Next() {
var column ColumnType
var scanMaxLength *int64
var scanPrecision *int64
var scanScale *int64
if err := rows.Scan(
&column.name,
&column.userType,
&column.systemType,
&column.nullable,
&scanMaxLength,
&scanPrecision,
&scanScale,
); err != nil {
return nil, fmt.Errorf("Error scanning column type results: %w", err)
}
if scanMaxLength != nil {
column.maxLength = *scanMaxLength
column.hasMaxLength = true
} else {
column.maxLength = -1
}
if column.systemType == "decimal" {
if scanPrecision != nil && scanScale != nil {
column.precision = *scanPrecision
column.scale = *scanScale
column.hasPrecisionScale = true
}
}
colTypes = append(colTypes, column)
}
return colTypes, nil
}
func queryColumnTypes(sourceDb *sql.DB, targetDb *pgxpool.Pool, migrationJob MigrationJob) ([]ColumnType, []ColumnType, error) {
var sourceDbErr error
var targetDbErr error
var sourceColTypes []ColumnType
var targetColTypes []ColumnType
var wg sync.WaitGroup
wg.Go(func() {
sourceColTypes, sourceDbErr = querySourceColTypes(sourceDb, migrationJob)
if sourceDbErr != nil {
log.Error("Error (sourceDb): ", sourceDbErr)
}
})
wg.Go(func() {
targetColTypes, targetDbErr = queryTargetColTypes(targetDb, migrationJob)
if targetDbErr != nil {
log.Error("Error (targetDb): ", targetDbErr)
}
})
wg.Wait()
if sourceDbErr != nil || targetDbErr != nil {
return nil, nil, errors.New("Error querying column types")
}
return sourceColTypes, targetColTypes, nil
}