package main import ( "context" "database/sql" "errors" "fmt" "strings" "sync" "time" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models" "github.com/jackc/pgx/v5/pgxpool" _ "github.com/microsoft/go-mssqldb" log "github.com/sirupsen/logrus" ) func GetUnifiedType(systemType string) string { systemType = strings.ToLower(systemType) if systemType == "varchar" || systemType == "char" || systemType == "nvarchar" || systemType == "nchar" || systemType == "text" || systemType == "ntext" { return "STRING" } if systemType == "int" || systemType == "int4" || systemType == "integer" || systemType == "smallint" || systemType == "int2" || systemType == "bigint" || systemType == "int8" || systemType == "tinyint" { return "INTEGER" } if systemType == "decimal" || systemType == "numeric" { return "DECIMAL" } if systemType == "float" || systemType == "real" || systemType == "double precision" { return "FLOAT" } if systemType == "bit" || systemType == "boolean" { return "BOOLEAN" } if systemType == "date" { return "DATE" } if systemType == "time" || systemType == "time without time zone" { return "TIME" } if systemType == "datetime" || systemType == "datetime2" || systemType == "timestamp" || systemType == "timestamptz" || systemType == "timestamp with time zone" { return "TIMESTAMP" } if systemType == "binary" || systemType == "varbinary" || systemType == "image" || systemType == "bytea" { return "BINARY" } if systemType == "uniqueidentifier" || systemType == "uuid" { return "UUID" } if systemType == "json" { return "JSON" } if systemType == "geometry" || systemType == "geography" { return "GEOMETRY" } return strings.ToUpper(systemType) } func MapPostgresColumn(column ColumnType, maxLength *int64, precision *int64, scale *int64) models.ColumnType { stringTypes := map[string]bool{ "varchar": true, "char": true, "character": true, "text": true, "character varying": true, } decimalTypes := map[string]bool{ "decimal": true, "numeric": true, } if stringTypes[column.systemType] { if maxLength != nil { column.maxLength = *maxLength column.hasMaxLength = true } else { column.maxLength = -1 column.hasMaxLength = false } column.hasPrecisionScale = false column.precision = -1 column.scale = -1 } else if decimalTypes[column.systemType] { column.hasMaxLength = false column.maxLength = -1 if precision != nil && scale != nil { column.precision = *precision column.scale = *scale column.hasPrecisionScale = true } else { column.precision = -1 column.scale = -1 column.hasPrecisionScale = false } } else { column.hasMaxLength = false column.maxLength = -1 column.hasPrecisionScale = false column.precision = -1 column.scale = -1 } column.unifiedType = GetUnifiedType(column.systemType) colType := models.NewColumnType( column.name, column.hasMaxLength, column.hasPrecisionScale, column.userType, column.systemType, column.unifiedType, column.nullable, column.maxLength, column.precision, column.scale, ) return colType } func GetColumnTypesPostgres(db *pgxpool.Pool, tableInfo config.TargetTableInfo) ([]models.ColumnType, error) { query := ` SELECT c.column_name AS name, c.data_type AS user_type, c.udt_name AS system_type, (CASE WHEN c.is_nullable = 'YES' THEN TRUE ELSE FALSE END) AS nullable, c.character_maximum_length AS max_length, c.numeric_precision AS precision, c.numeric_scale AS scale FROM information_schema.columns c WHERE c.table_schema = $1 AND c.table_name = $2 ORDER BY c.ordinal_position; ` ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) defer cancel() rows, err := db.Query(ctx, query, tableInfo.Schema, tableInfo.Table) if err != nil { return nil, fmt.Errorf("Error querying column types: %w", err) } defer rows.Close() var colTypes []models.ColumnType for rows.Next() { var column ColumnType var scanMaxLength *int64 var scanPrecision *int64 var scanScale *int64 if err := rows.Scan( &column.name, &column.userType, &column.systemType, &column.nullable, &scanMaxLength, &scanPrecision, &scanScale, ); err != nil { return nil, fmt.Errorf("Error scanning column type results: %w", err) } colTypes = append(colTypes, MapPostgresColumn(column, scanMaxLength, scanPrecision, scanScale)) } return colTypes, nil } func MapMssqlColumn(column ColumnType) models.ColumnType { stringTypes := map[string]bool{ "varchar": true, "char": true, "nvarchar": true, "nchar": true, "text": true, "ntext": true, } decimalTypes := map[string]bool{ "decimal": true, "numeric": true, } if stringTypes[column.systemType] { column.hasMaxLength = true if column.systemType == "nvarchar" || column.systemType == "nchar" { if column.maxLength > 0 { column.maxLength = column.maxLength / 2 } } column.hasPrecisionScale = false column.precision = -1 column.scale = -1 } else if decimalTypes[column.systemType] { column.hasMaxLength = false column.maxLength = -1 column.hasPrecisionScale = true } else { column.hasMaxLength = false column.maxLength = -1 column.hasPrecisionScale = false column.precision = -1 column.scale = -1 } column.unifiedType = GetUnifiedType(column.systemType) colType := models.NewColumnType( column.name, column.hasMaxLength, column.hasPrecisionScale, column.userType, column.systemType, column.unifiedType, column.nullable, column.maxLength, column.precision, column.scale, ) return colType } func GetColumnTypesMssql(db *sql.DB, tableInfo config.SourceTableInfo) ([]models.ColumnType, error) { query := ` SELECT c.name AS name, t.name AS user_type, CASE WHEN t.is_user_defined = 0 THEN t.name ELSE bt.name END AS system_type, c.is_nullable AS nullable, c.max_length AS max_length, c.precision AS precision, c.scale AS scale FROM sys.columns c JOIN sys.types t ON c.user_type_id = t.user_type_id LEFT JOIN sys.types bt ON t.is_user_defined = 1 AND bt.user_type_id = t.system_type_id JOIN sys.tables st ON c.object_id = st.object_id JOIN sys.schemas s ON st.schema_id = s.schema_id WHERE s.name = @schema AND st.name = @table ORDER BY c.column_id; ` ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) defer cancel() rows, err := db.QueryContext(ctx, query, sql.Named("schema", tableInfo.Schema), sql.Named("table", tableInfo.Table)) if err != nil { return nil, fmt.Errorf("Error querying column types: %w", err) } defer rows.Close() var colTypes []models.ColumnType for rows.Next() { var column ColumnType if err := rows.Scan( &column.name, &column.userType, &column.systemType, &column.nullable, &column.maxLength, &column.precision, &column.scale, ); err != nil { return nil, fmt.Errorf("Error scanning column type results: %W", err) } if strings.HasPrefix(column.name, "graph_id") && column.systemType == "bigint" { continue } colTypes = append(colTypes, MapMssqlColumn(column)) } return colTypes, nil } func GetColumnTypes( sourceDb *sql.DB, targetDb *pgxpool.Pool, sourceTable config.SourceTableInfo, targetTable config.TargetTableInfo, ) ([]models.ColumnType, []models.ColumnType, error) { var sourceDbErr error var targetDbErr error var sourceColTypes []models.ColumnType var targetColTypes []models.ColumnType var wg sync.WaitGroup wg.Go(func() { sourceColTypes, sourceDbErr = GetColumnTypesMssql(sourceDb, sourceTable) if sourceDbErr != nil { log.Error("Error (sourceDb): ", sourceDbErr) } }) wg.Go(func() { targetColTypes, targetDbErr = GetColumnTypesPostgres(targetDb, targetTable) if targetDbErr != nil { log.Error("Error (targetDb): ", targetDbErr) } }) wg.Wait() if sourceDbErr != nil || targetDbErr != nil { return nil, nil, errors.New("Error querying column types") } return sourceColTypes, targetColTypes, nil }