diff --git a/cmd/go_migrate/connect.go b/cmd/go_migrate/connect.go index eb31411..40e417f 100644 --- a/cmd/go_migrate/connect.go +++ b/cmd/go_migrate/connect.go @@ -3,6 +3,7 @@ package main import ( "context" "database/sql" + "errors" "fmt" "sync" "time" @@ -51,25 +52,25 @@ func connectToDatabases() (*sql.DB, *pgxpool.Pool, error) { var sourceDb *sql.DB var targetDb *pgxpool.Pool var wg sync.WaitGroup - + wg.Go(func() { - sourceDb, sourceDbErr = connectToSqlServer() - if sourceDbErr != nil { - log.Error("Unable to connect to source db: ", sourceDbErr) - } + sourceDb, sourceDbErr = connectToSqlServer() + if sourceDbErr != nil { + log.Error("Unable to connect to source db: ", sourceDbErr) + } }) wg.Go(func() { - targetDb, targetDbErr = connectToPostgres() - if targetDbErr != nil { - log.Error("Unable to connect to target db: ", targetDbErr) - } + targetDb, targetDbErr = connectToPostgres() + if targetDbErr != nil { + log.Error("Unable to connect to target db: ", targetDbErr) + } }) wg.Wait() if sourceDbErr != nil || targetDbErr != nil { - return nil, nil, fmt.Errorf("Unable to connect to databases: %w (source), %w (target)", sourceDbErr, targetDbErr) + return nil, nil, errors.New("Unable to connect to databases") } return sourceDb, targetDb, nil diff --git a/cmd/go_migrate/main.go b/cmd/go_migrate/main.go index c4734a6..2d47bfc 100644 --- a/cmd/go_migrate/main.go +++ b/cmd/go_migrate/main.go @@ -3,12 +3,12 @@ package main import ( "context" "database/sql" + "errors" "fmt" "sync" "time" - "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" _ "github.com/microsoft/go-mssqldb" log "github.com/sirupsen/logrus" @@ -52,53 +52,127 @@ func main() { log.Info("Migration completed successfully!") } -func querySourceColTypes(db *sql.DB, migrationJob MigrationJob) ([]sql.ColumnType, error) { - query := fmt.Sprintf(`SELECT * FROM [%s].[%s] WHERE 0 = 1`, migrationJob.Schema, migrationJob.Table) +func querySourceColTypes(db *sql.DB, migrationJob MigrationJob) ([]ColumnType, error) { + query := ` +SELECT + c.name AS name, + t.name AS user_type, + CASE WHEN t.is_user_defined = 0 THEN t.name ELSE bt.name END AS system_type, + c.is_nullable AS nullable, + c.max_length AS max_length, + c.precision AS precision, + c.scale AS scale +FROM sys.columns c +JOIN sys.types t ON c.user_type_id = t.user_type_id +LEFT JOIN sys.types bt ON t.is_user_defined = 1 AND bt.user_type_id = t.system_type_id +JOIN sys.tables st ON c.object_id = st.object_id +JOIN sys.schemas s ON st.schema_id = s.schema_id +WHERE s.name = @schema AND st.name = @table +ORDER BY c.column_id; +` ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) defer cancel() - rows, err := db.QueryContext(ctx, query) + rows, err := db.QueryContext(ctx, query, sql.Named("schema", migrationJob.Schema), sql.Named("table", migrationJob.Table)) if err != nil { return nil, fmt.Errorf("Error querying column types: %w", err) } defer rows.Close() - colTypesPointers, err := rows.ColumnTypes() - if err != nil { - return nil, err - } + var colTypes []ColumnType - colTypes := make([]sql.ColumnType, 0, len(colTypesPointers)) - for _, c := range colTypesPointers { - colTypes = append(colTypes, *c) + for rows.Next() { + var column ColumnType + + if err := rows.Scan( + &column.name, + &column.userType, + &column.systemType, + &column.nullable, + &column.maxLength, + &column.precision, + &column.scale, + ); err != nil { + return nil, fmt.Errorf("Error scanning column type results: %W", err) + } + + colTypes = append(colTypes, column) } return colTypes, nil } -func queryTargetColTypes(db *pgxpool.Pool, migrationJob MigrationJob) ([]pgconn.FieldDescription, error) { - query := fmt.Sprintf(`SELECT * FROM "%s"."%s" WHERE 0 = 1`, migrationJob.Schema, migrationJob.Table) +func queryTargetColTypes(db *pgxpool.Pool, migrationJob MigrationJob) ([]ColumnType, error) { + query := ` +SELECT + c.column_name AS name, + c.data_type AS user_type, + c.udt_name AS system_type, + (CASE WHEN c.is_nullable = 'YES' THEN TRUE ELSE FALSE END) AS nullable, + c.character_maximum_length AS max_length, + c.numeric_precision AS precision, + c.numeric_scale AS scale +FROM information_schema.columns c +WHERE c.table_schema = $1 AND c.table_name = $2 +ORDER BY c.ordinal_position; +` ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) defer cancel() - rows, err := db.Query(ctx, query) + rows, err := db.Query(ctx, query, migrationJob.Schema, migrationJob.Table) if err != nil { return nil, fmt.Errorf("Error querying column types: %w", err) } defer rows.Close() - colTypes := rows.FieldDescriptions() + var colTypes []ColumnType + + for rows.Next() { + var column ColumnType + var scanMaxLength *int64 + var scanPrecision *int64 + var scanScale *int64 + + if err := rows.Scan( + &column.name, + &column.userType, + &column.systemType, + &column.nullable, + &scanMaxLength, + &scanPrecision, + &scanScale, + ); err != nil { + return nil, fmt.Errorf("Error scanning column type results: %w", err) + } + + if scanMaxLength != nil { + column.maxLength = *scanMaxLength + column.hasMaxLength = true + } else { + column.maxLength = -1 + } + + if column.systemType == "decimal" { + if scanPrecision != nil && scanScale != nil { + column.precision = *scanPrecision + column.scale = *scanScale + column.hasPrecisionScale = true + } + } + + colTypes = append(colTypes, column) + } return colTypes, nil } -func queryColumnTypes(sourceDb *sql.DB, targetDb *pgxpool.Pool, migrationJob MigrationJob) ([]sql.ColumnType, []pgconn.FieldDescription, error) { +func queryColumnTypes(sourceDb *sql.DB, targetDb *pgxpool.Pool, migrationJob MigrationJob) ([]ColumnType, []ColumnType, error) { var sourceDbErr error var targetDbErr error - var sourceColTypes []sql.ColumnType - var targetColTypes []pgconn.FieldDescription + var sourceColTypes []ColumnType + var targetColTypes []ColumnType var wg sync.WaitGroup wg.Go(func() { @@ -118,7 +192,7 @@ func queryColumnTypes(sourceDb *sql.DB, targetDb *pgxpool.Pool, migrationJob Mig wg.Wait() if sourceDbErr != nil || targetDbErr != nil { - return nil, nil, fmt.Errorf("Unable to connect to databases: %w (source), %w (target)", sourceDbErr, targetDbErr) + return nil, nil, errors.New("Error querying column types") } return sourceColTypes, targetColTypes, nil