feat: refactor column type handling and streamline database column queries
This commit is contained in:
@@ -3,12 +3,12 @@ package main
|
|||||||
type ColumnType struct {
|
type ColumnType struct {
|
||||||
name string
|
name string
|
||||||
|
|
||||||
hasNullable bool
|
|
||||||
hasMaxLength bool
|
hasMaxLength bool
|
||||||
hasPrecisionScale bool
|
hasPrecisionScale bool
|
||||||
|
|
||||||
userType string
|
userType string
|
||||||
systemType string
|
systemType string
|
||||||
|
unifiedType string
|
||||||
nullable bool
|
nullable bool
|
||||||
maxLength int64
|
maxLength int64
|
||||||
precision int64
|
precision int64
|
||||||
@@ -35,6 +35,10 @@ func (c *ColumnType) DecimalSize() (precision, scale int64, ok bool) {
|
|||||||
return c.precision, c.scale, c.hasPrecisionScale
|
return c.precision, c.scale, c.hasPrecisionScale
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ColumnType) Nullable() (nullable, ok bool) {
|
func (c *ColumnType) Nullable() bool {
|
||||||
return c.nullable, c.hasNullable
|
return c.nullable
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ColumnType) Type() string {
|
||||||
|
return c.unifiedType
|
||||||
}
|
}
|
||||||
|
|||||||
279
cmd/go_migrate/inspect-columns.go
Normal file
279
cmd/go_migrate/inspect-columns.go
Normal file
@@ -0,0 +1,279 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
|
_ "github.com/microsoft/go-mssqldb"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetUnifiedType(systemType string) string {
|
||||||
|
systemType = strings.ToLower(systemType)
|
||||||
|
|
||||||
|
if systemType == "varchar" || systemType == "char" || systemType == "nvarchar" || systemType == "nchar" || systemType == "text" || systemType == "ntext" {
|
||||||
|
return "STRING"
|
||||||
|
}
|
||||||
|
|
||||||
|
if systemType == "int" || systemType == "int4" || systemType == "integer" || systemType == "smallint" || systemType == "int2" || systemType == "bigint" || systemType == "int8" || systemType == "tinyint" {
|
||||||
|
return "INTEGER"
|
||||||
|
}
|
||||||
|
|
||||||
|
if systemType == "decimal" || systemType == "numeric" {
|
||||||
|
return "DECIMAL"
|
||||||
|
}
|
||||||
|
|
||||||
|
if systemType == "float" || systemType == "real" || systemType == "double precision" {
|
||||||
|
return "FLOAT"
|
||||||
|
}
|
||||||
|
|
||||||
|
if systemType == "bit" || systemType == "boolean" {
|
||||||
|
return "BOOLEAN"
|
||||||
|
}
|
||||||
|
|
||||||
|
if systemType == "date" {
|
||||||
|
return "DATE"
|
||||||
|
}
|
||||||
|
if systemType == "time" || systemType == "time without time zone" {
|
||||||
|
return "TIME"
|
||||||
|
}
|
||||||
|
if systemType == "datetime" || systemType == "datetime2" || systemType == "timestamp" || systemType == "timestamptz" || systemType == "timestamp with time zone" {
|
||||||
|
return "TIMESTAMP"
|
||||||
|
}
|
||||||
|
|
||||||
|
if systemType == "binary" || systemType == "varbinary" || systemType == "image" || systemType == "bytea" {
|
||||||
|
return "BINARY"
|
||||||
|
}
|
||||||
|
|
||||||
|
if systemType == "uniqueidentifier" || systemType == "uuid" {
|
||||||
|
return "UUID"
|
||||||
|
}
|
||||||
|
|
||||||
|
if systemType == "json" {
|
||||||
|
return "JSON"
|
||||||
|
}
|
||||||
|
|
||||||
|
if systemType == "geometry" || systemType == "geography" {
|
||||||
|
return "GEOMETRY"
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.ToUpper(systemType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func MapPostgresColumn(column ColumnType, maxLength *int64, precision *int64, scale *int64) ColumnType {
|
||||||
|
stringTypes := map[string]bool{
|
||||||
|
"varchar": true, "char": true, "character": true, "text": true, "character varying": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
decimalTypes := map[string]bool{
|
||||||
|
"decimal": true, "numeric": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if stringTypes[column.systemType] {
|
||||||
|
if maxLength != nil {
|
||||||
|
column.maxLength = *maxLength
|
||||||
|
column.hasMaxLength = true
|
||||||
|
} else {
|
||||||
|
column.maxLength = -1
|
||||||
|
column.hasMaxLength = false
|
||||||
|
}
|
||||||
|
column.hasPrecisionScale = false
|
||||||
|
column.precision = -1
|
||||||
|
column.scale = -1
|
||||||
|
} else if decimalTypes[column.systemType] {
|
||||||
|
column.hasMaxLength = false
|
||||||
|
column.maxLength = -1
|
||||||
|
if precision != nil && scale != nil {
|
||||||
|
column.precision = *precision
|
||||||
|
column.scale = *scale
|
||||||
|
column.hasPrecisionScale = true
|
||||||
|
} else {
|
||||||
|
column.precision = -1
|
||||||
|
column.scale = -1
|
||||||
|
column.hasPrecisionScale = false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
column.hasMaxLength = false
|
||||||
|
column.maxLength = -1
|
||||||
|
column.hasPrecisionScale = false
|
||||||
|
column.precision = -1
|
||||||
|
column.scale = -1
|
||||||
|
}
|
||||||
|
|
||||||
|
column.unifiedType = GetUnifiedType(column.systemType)
|
||||||
|
|
||||||
|
return column
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetColumnTypesPostgres(db *pgxpool.Pool, migrationJob MigrationJob) ([]ColumnType, error) {
|
||||||
|
query := `
|
||||||
|
SELECT
|
||||||
|
c.column_name AS name,
|
||||||
|
c.data_type AS user_type,
|
||||||
|
c.udt_name AS system_type,
|
||||||
|
(CASE WHEN c.is_nullable = 'YES' THEN TRUE ELSE FALSE END) AS nullable,
|
||||||
|
c.character_maximum_length AS max_length,
|
||||||
|
c.numeric_precision AS precision,
|
||||||
|
c.numeric_scale AS scale
|
||||||
|
FROM information_schema.columns c
|
||||||
|
WHERE c.table_schema = $1 AND c.table_name = $2
|
||||||
|
ORDER BY c.ordinal_position;
|
||||||
|
`
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
rows, err := db.Query(ctx, query, migrationJob.Schema, migrationJob.Table)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Error querying column types: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var colTypes []ColumnType
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
var column ColumnType
|
||||||
|
var scanMaxLength *int64
|
||||||
|
var scanPrecision *int64
|
||||||
|
var scanScale *int64
|
||||||
|
|
||||||
|
if err := rows.Scan(
|
||||||
|
&column.name,
|
||||||
|
&column.userType,
|
||||||
|
&column.systemType,
|
||||||
|
&column.nullable,
|
||||||
|
&scanMaxLength,
|
||||||
|
&scanPrecision,
|
||||||
|
&scanScale,
|
||||||
|
); err != nil {
|
||||||
|
return nil, fmt.Errorf("Error scanning column type results: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
colTypes = append(colTypes, MapPostgresColumn(column, scanMaxLength, scanPrecision, scanScale))
|
||||||
|
}
|
||||||
|
|
||||||
|
return colTypes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func MapMssqlColumn(column ColumnType) ColumnType {
|
||||||
|
stringTypes := map[string]bool{
|
||||||
|
"varchar": true, "char": true, "nvarchar": true, "nchar": true, "text": true, "ntext": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
decimalTypes := map[string]bool{
|
||||||
|
"decimal": true, "numeric": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if stringTypes[column.systemType] {
|
||||||
|
column.hasMaxLength = true
|
||||||
|
if column.systemType == "nvarchar" || column.systemType == "nchar" {
|
||||||
|
if column.maxLength > 0 {
|
||||||
|
column.maxLength = column.maxLength / 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
column.hasPrecisionScale = false
|
||||||
|
column.precision = -1
|
||||||
|
column.scale = -1
|
||||||
|
} else if decimalTypes[column.systemType] {
|
||||||
|
column.hasMaxLength = false
|
||||||
|
column.maxLength = -1
|
||||||
|
column.hasPrecisionScale = true
|
||||||
|
} else {
|
||||||
|
column.hasMaxLength = false
|
||||||
|
column.maxLength = -1
|
||||||
|
column.hasPrecisionScale = false
|
||||||
|
column.precision = -1
|
||||||
|
column.scale = -1
|
||||||
|
}
|
||||||
|
|
||||||
|
column.unifiedType = GetUnifiedType(column.systemType)
|
||||||
|
|
||||||
|
return column
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetColumnTypesMssql(db *sql.DB, migrationJob MigrationJob) ([]ColumnType, error) {
|
||||||
|
query := `
|
||||||
|
SELECT
|
||||||
|
c.name AS name,
|
||||||
|
t.name AS user_type,
|
||||||
|
CASE WHEN t.is_user_defined = 0 THEN t.name ELSE bt.name END AS system_type,
|
||||||
|
c.is_nullable AS nullable,
|
||||||
|
c.max_length AS max_length,
|
||||||
|
c.precision AS precision,
|
||||||
|
c.scale AS scale
|
||||||
|
FROM sys.columns c
|
||||||
|
JOIN sys.types t ON c.user_type_id = t.user_type_id
|
||||||
|
LEFT JOIN sys.types bt ON t.is_user_defined = 1 AND bt.user_type_id = t.system_type_id
|
||||||
|
JOIN sys.tables st ON c.object_id = st.object_id
|
||||||
|
JOIN sys.schemas s ON st.schema_id = s.schema_id
|
||||||
|
WHERE s.name = @schema AND st.name = @table
|
||||||
|
ORDER BY c.column_id;
|
||||||
|
`
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
rows, err := db.QueryContext(ctx, query, sql.Named("schema", migrationJob.Schema), sql.Named("table", migrationJob.Table))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Error querying column types: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var colTypes []ColumnType
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
var column ColumnType
|
||||||
|
|
||||||
|
if err := rows.Scan(
|
||||||
|
&column.name,
|
||||||
|
&column.userType,
|
||||||
|
&column.systemType,
|
||||||
|
&column.nullable,
|
||||||
|
&column.maxLength,
|
||||||
|
&column.precision,
|
||||||
|
&column.scale,
|
||||||
|
); err != nil {
|
||||||
|
return nil, fmt.Errorf("Error scanning column type results: %W", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
colTypes = append(colTypes, MapMssqlColumn(column))
|
||||||
|
}
|
||||||
|
|
||||||
|
return colTypes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetColumnTypes(sourceDb *sql.DB, targetDb *pgxpool.Pool, migrationJob MigrationJob) ([]ColumnType, []ColumnType, error) {
|
||||||
|
var sourceDbErr error
|
||||||
|
var targetDbErr error
|
||||||
|
var sourceColTypes []ColumnType
|
||||||
|
var targetColTypes []ColumnType
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
wg.Go(func() {
|
||||||
|
sourceColTypes, sourceDbErr = GetColumnTypesMssql(sourceDb, migrationJob)
|
||||||
|
if sourceDbErr != nil {
|
||||||
|
log.Error("Error (sourceDb): ", sourceDbErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
wg.Go(func() {
|
||||||
|
targetColTypes, targetDbErr = GetColumnTypesPostgres(targetDb, migrationJob)
|
||||||
|
if targetDbErr != nil {
|
||||||
|
log.Error("Error (targetDb): ", targetDbErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
if sourceDbErr != nil || targetDbErr != nil {
|
||||||
|
return nil, nil, errors.New("Error querying column types")
|
||||||
|
}
|
||||||
|
|
||||||
|
return sourceColTypes, targetColTypes, nil
|
||||||
|
}
|
||||||
@@ -1,15 +1,6 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5/pgxpool"
|
|
||||||
_ "github.com/microsoft/go-mssqldb"
|
_ "github.com/microsoft/go-mssqldb"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
@@ -40,160 +31,22 @@ func main() {
|
|||||||
defer targetDb.Close()
|
defer targetDb.Close()
|
||||||
|
|
||||||
for _, job := range migrationJobs {
|
for _, job := range migrationJobs {
|
||||||
sourceColTypes, targetColTypes, err := queryColumnTypes(sourceDb, targetDb, job)
|
sourceColTypes, targetColTypes, err := GetColumnTypes(sourceDb, targetDb, job)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("Unexpected error: ", err)
|
log.Fatal("Unexpected error: ", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("Source col types: %+v", sourceColTypes)
|
logColumnTypes(sourceColTypes, "Source col types")
|
||||||
log.Debugf("Target col types: %+v", targetColTypes)
|
logColumnTypes(targetColTypes, "Target col types")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("Migration completed successfully!")
|
log.Info("Migration completed successfully!")
|
||||||
}
|
}
|
||||||
|
|
||||||
func querySourceColTypes(db *sql.DB, migrationJob MigrationJob) ([]ColumnType, error) {
|
func logColumnTypes(columnTypes []ColumnType, label string) {
|
||||||
query := `
|
log.Info(label)
|
||||||
SELECT
|
|
||||||
c.name AS name,
|
|
||||||
t.name AS user_type,
|
|
||||||
CASE WHEN t.is_user_defined = 0 THEN t.name ELSE bt.name END AS system_type,
|
|
||||||
c.is_nullable AS nullable,
|
|
||||||
c.max_length AS max_length,
|
|
||||||
c.precision AS precision,
|
|
||||||
c.scale AS scale
|
|
||||||
FROM sys.columns c
|
|
||||||
JOIN sys.types t ON c.user_type_id = t.user_type_id
|
|
||||||
LEFT JOIN sys.types bt ON t.is_user_defined = 1 AND bt.user_type_id = t.system_type_id
|
|
||||||
JOIN sys.tables st ON c.object_id = st.object_id
|
|
||||||
JOIN sys.schemas s ON st.schema_id = s.schema_id
|
|
||||||
WHERE s.name = @schema AND st.name = @table
|
|
||||||
ORDER BY c.column_id;
|
|
||||||
`
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
|
for _, col := range columnTypes {
|
||||||
defer cancel()
|
log.Infof("%+v", col)
|
||||||
|
|
||||||
rows, err := db.QueryContext(ctx, query, sql.Named("schema", migrationJob.Schema), sql.Named("table", migrationJob.Table))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("Error querying column types: %w", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var colTypes []ColumnType
|
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
var column ColumnType
|
|
||||||
|
|
||||||
if err := rows.Scan(
|
|
||||||
&column.name,
|
|
||||||
&column.userType,
|
|
||||||
&column.systemType,
|
|
||||||
&column.nullable,
|
|
||||||
&column.maxLength,
|
|
||||||
&column.precision,
|
|
||||||
&column.scale,
|
|
||||||
); err != nil {
|
|
||||||
return nil, fmt.Errorf("Error scanning column type results: %W", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
colTypes = append(colTypes, column)
|
|
||||||
}
|
|
||||||
|
|
||||||
return colTypes, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func queryTargetColTypes(db *pgxpool.Pool, migrationJob MigrationJob) ([]ColumnType, error) {
|
|
||||||
query := `
|
|
||||||
SELECT
|
|
||||||
c.column_name AS name,
|
|
||||||
c.data_type AS user_type,
|
|
||||||
c.udt_name AS system_type,
|
|
||||||
(CASE WHEN c.is_nullable = 'YES' THEN TRUE ELSE FALSE END) AS nullable,
|
|
||||||
c.character_maximum_length AS max_length,
|
|
||||||
c.numeric_precision AS precision,
|
|
||||||
c.numeric_scale AS scale
|
|
||||||
FROM information_schema.columns c
|
|
||||||
WHERE c.table_schema = $1 AND c.table_name = $2
|
|
||||||
ORDER BY c.ordinal_position;
|
|
||||||
`
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
rows, err := db.Query(ctx, query, migrationJob.Schema, migrationJob.Table)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("Error querying column types: %w", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var colTypes []ColumnType
|
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
var column ColumnType
|
|
||||||
var scanMaxLength *int64
|
|
||||||
var scanPrecision *int64
|
|
||||||
var scanScale *int64
|
|
||||||
|
|
||||||
if err := rows.Scan(
|
|
||||||
&column.name,
|
|
||||||
&column.userType,
|
|
||||||
&column.systemType,
|
|
||||||
&column.nullable,
|
|
||||||
&scanMaxLength,
|
|
||||||
&scanPrecision,
|
|
||||||
&scanScale,
|
|
||||||
); err != nil {
|
|
||||||
return nil, fmt.Errorf("Error scanning column type results: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if scanMaxLength != nil {
|
|
||||||
column.maxLength = *scanMaxLength
|
|
||||||
column.hasMaxLength = true
|
|
||||||
} else {
|
|
||||||
column.maxLength = -1
|
|
||||||
}
|
|
||||||
|
|
||||||
if column.systemType == "decimal" {
|
|
||||||
if scanPrecision != nil && scanScale != nil {
|
|
||||||
column.precision = *scanPrecision
|
|
||||||
column.scale = *scanScale
|
|
||||||
column.hasPrecisionScale = true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
colTypes = append(colTypes, column)
|
|
||||||
}
|
|
||||||
|
|
||||||
return colTypes, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func queryColumnTypes(sourceDb *sql.DB, targetDb *pgxpool.Pool, migrationJob MigrationJob) ([]ColumnType, []ColumnType, error) {
|
|
||||||
var sourceDbErr error
|
|
||||||
var targetDbErr error
|
|
||||||
var sourceColTypes []ColumnType
|
|
||||||
var targetColTypes []ColumnType
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
|
|
||||||
wg.Go(func() {
|
|
||||||
sourceColTypes, sourceDbErr = querySourceColTypes(sourceDb, migrationJob)
|
|
||||||
if sourceDbErr != nil {
|
|
||||||
log.Error("Error (sourceDb): ", sourceDbErr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
wg.Go(func() {
|
|
||||||
targetColTypes, targetDbErr = queryTargetColTypes(targetDb, migrationJob)
|
|
||||||
if targetDbErr != nil {
|
|
||||||
log.Error("Error (targetDb): ", targetDbErr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
if sourceDbErr != nil || targetDbErr != nil {
|
|
||||||
return nil, nil, errors.New("Error querying column types")
|
|
||||||
}
|
|
||||||
|
|
||||||
return sourceColTypes, targetColTypes, nil
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user