diff --git a/cmd/go_migrate/colum-type.go b/cmd/go_migrate/colum-type.go index 1288e15..cfd76a4 100644 --- a/cmd/go_migrate/colum-type.go +++ b/cmd/go_migrate/colum-type.go @@ -3,16 +3,16 @@ package main type ColumnType struct { name string - hasNullable bool hasMaxLength bool hasPrecisionScale bool - userType string - systemType string - nullable bool - maxLength int64 - precision int64 - scale int64 + userType string + systemType string + unifiedType string + nullable bool + maxLength int64 + precision int64 + scale int64 } func (c *ColumnType) Name() string { @@ -35,6 +35,10 @@ func (c *ColumnType) DecimalSize() (precision, scale int64, ok bool) { return c.precision, c.scale, c.hasPrecisionScale } -func (c *ColumnType) Nullable() (nullable, ok bool) { - return c.nullable, c.hasNullable +func (c *ColumnType) Nullable() bool { + return c.nullable +} + +func (c *ColumnType) Type() string { + return c.unifiedType } diff --git a/cmd/go_migrate/inspect-columns.go b/cmd/go_migrate/inspect-columns.go new file mode 100644 index 0000000..b97043c --- /dev/null +++ b/cmd/go_migrate/inspect-columns.go @@ -0,0 +1,279 @@ +package main + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "sync" + "time" + + "github.com/jackc/pgx/v5/pgxpool" + _ "github.com/microsoft/go-mssqldb" + log "github.com/sirupsen/logrus" +) + +func GetUnifiedType(systemType string) string { + systemType = strings.ToLower(systemType) + + if systemType == "varchar" || systemType == "char" || systemType == "nvarchar" || systemType == "nchar" || systemType == "text" || systemType == "ntext" { + return "STRING" + } + + if systemType == "int" || systemType == "int4" || systemType == "integer" || systemType == "smallint" || systemType == "int2" || systemType == "bigint" || systemType == "int8" || systemType == "tinyint" { + return "INTEGER" + } + + if systemType == "decimal" || systemType == "numeric" { + return "DECIMAL" + } + + if systemType == "float" || systemType == "real" || systemType == "double precision" { + return "FLOAT" + } + + if systemType == "bit" || systemType == "boolean" { + return "BOOLEAN" + } + + if systemType == "date" { + return "DATE" + } + if systemType == "time" || systemType == "time without time zone" { + return "TIME" + } + if systemType == "datetime" || systemType == "datetime2" || systemType == "timestamp" || systemType == "timestamptz" || systemType == "timestamp with time zone" { + return "TIMESTAMP" + } + + if systemType == "binary" || systemType == "varbinary" || systemType == "image" || systemType == "bytea" { + return "BINARY" + } + + if systemType == "uniqueidentifier" || systemType == "uuid" { + return "UUID" + } + + if systemType == "json" { + return "JSON" + } + + if systemType == "geometry" || systemType == "geography" { + return "GEOMETRY" + } + + return strings.ToUpper(systemType) +} + +func MapPostgresColumn(column ColumnType, maxLength *int64, precision *int64, scale *int64) ColumnType { + stringTypes := map[string]bool{ + "varchar": true, "char": true, "character": true, "text": true, "character varying": true, + } + + decimalTypes := map[string]bool{ + "decimal": true, "numeric": true, + } + + if stringTypes[column.systemType] { + if maxLength != nil { + column.maxLength = *maxLength + column.hasMaxLength = true + } else { + column.maxLength = -1 + column.hasMaxLength = false + } + column.hasPrecisionScale = false + column.precision = -1 + column.scale = -1 + } else if decimalTypes[column.systemType] { + column.hasMaxLength = false + column.maxLength = -1 + if precision != nil && scale != nil { + column.precision = *precision + column.scale = *scale + column.hasPrecisionScale = true + } else { + column.precision = -1 + column.scale = -1 + column.hasPrecisionScale = false + } + } else { + column.hasMaxLength = false + column.maxLength = -1 + column.hasPrecisionScale = false + column.precision = -1 + column.scale = -1 + } + + column.unifiedType = GetUnifiedType(column.systemType) + + return column +} + +func GetColumnTypesPostgres(db *pgxpool.Pool, migrationJob MigrationJob) ([]ColumnType, error) { + query := ` +SELECT + c.column_name AS name, + c.data_type AS user_type, + c.udt_name AS system_type, + (CASE WHEN c.is_nullable = 'YES' THEN TRUE ELSE FALSE END) AS nullable, + c.character_maximum_length AS max_length, + c.numeric_precision AS precision, + c.numeric_scale AS scale +FROM information_schema.columns c +WHERE c.table_schema = $1 AND c.table_name = $2 +ORDER BY c.ordinal_position; +` + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + rows, err := db.Query(ctx, query, migrationJob.Schema, migrationJob.Table) + if err != nil { + return nil, fmt.Errorf("Error querying column types: %w", err) + } + defer rows.Close() + + var colTypes []ColumnType + + for rows.Next() { + var column ColumnType + var scanMaxLength *int64 + var scanPrecision *int64 + var scanScale *int64 + + if err := rows.Scan( + &column.name, + &column.userType, + &column.systemType, + &column.nullable, + &scanMaxLength, + &scanPrecision, + &scanScale, + ); err != nil { + return nil, fmt.Errorf("Error scanning column type results: %w", err) + } + + colTypes = append(colTypes, MapPostgresColumn(column, scanMaxLength, scanPrecision, scanScale)) + } + + return colTypes, nil +} + +func MapMssqlColumn(column ColumnType) ColumnType { + stringTypes := map[string]bool{ + "varchar": true, "char": true, "nvarchar": true, "nchar": true, "text": true, "ntext": true, + } + + decimalTypes := map[string]bool{ + "decimal": true, "numeric": true, + } + + if stringTypes[column.systemType] { + column.hasMaxLength = true + if column.systemType == "nvarchar" || column.systemType == "nchar" { + if column.maxLength > 0 { + column.maxLength = column.maxLength / 2 + } + } + column.hasPrecisionScale = false + column.precision = -1 + column.scale = -1 + } else if decimalTypes[column.systemType] { + column.hasMaxLength = false + column.maxLength = -1 + column.hasPrecisionScale = true + } else { + column.hasMaxLength = false + column.maxLength = -1 + column.hasPrecisionScale = false + column.precision = -1 + column.scale = -1 + } + + column.unifiedType = GetUnifiedType(column.systemType) + + return column +} + +func GetColumnTypesMssql(db *sql.DB, migrationJob MigrationJob) ([]ColumnType, error) { + query := ` +SELECT + c.name AS name, + t.name AS user_type, + CASE WHEN t.is_user_defined = 0 THEN t.name ELSE bt.name END AS system_type, + c.is_nullable AS nullable, + c.max_length AS max_length, + c.precision AS precision, + c.scale AS scale +FROM sys.columns c +JOIN sys.types t ON c.user_type_id = t.user_type_id +LEFT JOIN sys.types bt ON t.is_user_defined = 1 AND bt.user_type_id = t.system_type_id +JOIN sys.tables st ON c.object_id = st.object_id +JOIN sys.schemas s ON st.schema_id = s.schema_id +WHERE s.name = @schema AND st.name = @table +ORDER BY c.column_id; +` + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + rows, err := db.QueryContext(ctx, query, sql.Named("schema", migrationJob.Schema), sql.Named("table", migrationJob.Table)) + if err != nil { + return nil, fmt.Errorf("Error querying column types: %w", err) + } + defer rows.Close() + + var colTypes []ColumnType + + for rows.Next() { + var column ColumnType + + if err := rows.Scan( + &column.name, + &column.userType, + &column.systemType, + &column.nullable, + &column.maxLength, + &column.precision, + &column.scale, + ); err != nil { + return nil, fmt.Errorf("Error scanning column type results: %W", err) + } + + colTypes = append(colTypes, MapMssqlColumn(column)) + } + + return colTypes, nil +} + +func GetColumnTypes(sourceDb *sql.DB, targetDb *pgxpool.Pool, migrationJob MigrationJob) ([]ColumnType, []ColumnType, error) { + var sourceDbErr error + var targetDbErr error + var sourceColTypes []ColumnType + var targetColTypes []ColumnType + var wg sync.WaitGroup + + wg.Go(func() { + sourceColTypes, sourceDbErr = GetColumnTypesMssql(sourceDb, migrationJob) + if sourceDbErr != nil { + log.Error("Error (sourceDb): ", sourceDbErr) + } + }) + + wg.Go(func() { + targetColTypes, targetDbErr = GetColumnTypesPostgres(targetDb, migrationJob) + if targetDbErr != nil { + log.Error("Error (targetDb): ", targetDbErr) + } + }) + + wg.Wait() + + if sourceDbErr != nil || targetDbErr != nil { + return nil, nil, errors.New("Error querying column types") + } + + return sourceColTypes, targetColTypes, nil +} diff --git a/cmd/go_migrate/main.go b/cmd/go_migrate/main.go index 2d47bfc..35151db 100644 --- a/cmd/go_migrate/main.go +++ b/cmd/go_migrate/main.go @@ -1,15 +1,6 @@ package main import ( - "context" - "database/sql" - "errors" - "fmt" - - "sync" - "time" - - "github.com/jackc/pgx/v5/pgxpool" _ "github.com/microsoft/go-mssqldb" log "github.com/sirupsen/logrus" ) @@ -40,160 +31,22 @@ func main() { defer targetDb.Close() for _, job := range migrationJobs { - sourceColTypes, targetColTypes, err := queryColumnTypes(sourceDb, targetDb, job) + sourceColTypes, targetColTypes, err := GetColumnTypes(sourceDb, targetDb, job) if err != nil { log.Fatal("Unexpected error: ", err) } - log.Debugf("Source col types: %+v", sourceColTypes) - log.Debugf("Target col types: %+v", targetColTypes) + logColumnTypes(sourceColTypes, "Source col types") + logColumnTypes(targetColTypes, "Target col types") } log.Info("Migration completed successfully!") } -func querySourceColTypes(db *sql.DB, migrationJob MigrationJob) ([]ColumnType, error) { - query := ` -SELECT - c.name AS name, - t.name AS user_type, - CASE WHEN t.is_user_defined = 0 THEN t.name ELSE bt.name END AS system_type, - c.is_nullable AS nullable, - c.max_length AS max_length, - c.precision AS precision, - c.scale AS scale -FROM sys.columns c -JOIN sys.types t ON c.user_type_id = t.user_type_id -LEFT JOIN sys.types bt ON t.is_user_defined = 1 AND bt.user_type_id = t.system_type_id -JOIN sys.tables st ON c.object_id = st.object_id -JOIN sys.schemas s ON st.schema_id = s.schema_id -WHERE s.name = @schema AND st.name = @table -ORDER BY c.column_id; -` +func logColumnTypes(columnTypes []ColumnType, label string) { + log.Info(label) - ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) - defer cancel() - - rows, err := db.QueryContext(ctx, query, sql.Named("schema", migrationJob.Schema), sql.Named("table", migrationJob.Table)) - if err != nil { - return nil, fmt.Errorf("Error querying column types: %w", err) + for _, col := range columnTypes { + log.Infof("%+v", col) } - defer rows.Close() - - var colTypes []ColumnType - - for rows.Next() { - var column ColumnType - - if err := rows.Scan( - &column.name, - &column.userType, - &column.systemType, - &column.nullable, - &column.maxLength, - &column.precision, - &column.scale, - ); err != nil { - return nil, fmt.Errorf("Error scanning column type results: %W", err) - } - - colTypes = append(colTypes, column) - } - - return colTypes, nil -} - -func queryTargetColTypes(db *pgxpool.Pool, migrationJob MigrationJob) ([]ColumnType, error) { - query := ` -SELECT - c.column_name AS name, - c.data_type AS user_type, - c.udt_name AS system_type, - (CASE WHEN c.is_nullable = 'YES' THEN TRUE ELSE FALSE END) AS nullable, - c.character_maximum_length AS max_length, - c.numeric_precision AS precision, - c.numeric_scale AS scale -FROM information_schema.columns c -WHERE c.table_schema = $1 AND c.table_name = $2 -ORDER BY c.ordinal_position; -` - - ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) - defer cancel() - - rows, err := db.Query(ctx, query, migrationJob.Schema, migrationJob.Table) - if err != nil { - return nil, fmt.Errorf("Error querying column types: %w", err) - } - defer rows.Close() - - var colTypes []ColumnType - - for rows.Next() { - var column ColumnType - var scanMaxLength *int64 - var scanPrecision *int64 - var scanScale *int64 - - if err := rows.Scan( - &column.name, - &column.userType, - &column.systemType, - &column.nullable, - &scanMaxLength, - &scanPrecision, - &scanScale, - ); err != nil { - return nil, fmt.Errorf("Error scanning column type results: %w", err) - } - - if scanMaxLength != nil { - column.maxLength = *scanMaxLength - column.hasMaxLength = true - } else { - column.maxLength = -1 - } - - if column.systemType == "decimal" { - if scanPrecision != nil && scanScale != nil { - column.precision = *scanPrecision - column.scale = *scanScale - column.hasPrecisionScale = true - } - } - - colTypes = append(colTypes, column) - } - - return colTypes, nil -} - -func queryColumnTypes(sourceDb *sql.DB, targetDb *pgxpool.Pool, migrationJob MigrationJob) ([]ColumnType, []ColumnType, error) { - var sourceDbErr error - var targetDbErr error - var sourceColTypes []ColumnType - var targetColTypes []ColumnType - var wg sync.WaitGroup - - wg.Go(func() { - sourceColTypes, sourceDbErr = querySourceColTypes(sourceDb, migrationJob) - if sourceDbErr != nil { - log.Error("Error (sourceDb): ", sourceDbErr) - } - }) - - wg.Go(func() { - targetColTypes, targetDbErr = queryTargetColTypes(targetDb, migrationJob) - if targetDbErr != nil { - log.Error("Error (targetDb): ", targetDbErr) - } - }) - - wg.Wait() - - if sourceDbErr != nil || targetDbErr != nil { - return nil, nil, errors.New("Error querying column types") - } - - return sourceColTypes, targetColTypes, nil }