package table_analyzers

import (
	"context"
	"fmt"
	"strings"
	"time"

	"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
	dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper"
	"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
	"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
)

// PostgresTableAnalyzer inspects PostgreSQL tables through a dbwrapper.DbWrapper
// and implements etl.TableAnalyzer.
type PostgresTableAnalyzer struct {
	db dbwrapper.DbWrapper
}

// NewPostgresTableAnalyzer returns an etl.TableAnalyzer that issues its
// metadata queries through db.
func NewPostgresTableAnalyzer(db dbwrapper.DbWrapper) etl.TableAnalyzer {
	return &PostgresTableAnalyzer{db: db}
}

// postgresColumnMetadataTimeout bounds how long QueryColumnTypes waits for the
// column-metadata query.
const postgresColumnMetadataTimeout = 20 * time.Second

// postgresColumnMetadataQuery lists the columns of table $2 in schema $1 in
// declaration order. NULL length/precision/scale values are reported as -1.
// Note that system_type is information_schema.columns.udt_name, i.e. the
// PostgreSQL internal type name (int4, float8, bool, bpchar, ...).
const postgresColumnMetadataQuery string = `
SELECT
	c.column_name AS name,
	c.data_type AS user_type,
	c.udt_name AS system_type,
	(CASE WHEN c.is_nullable = 'YES' THEN TRUE ELSE FALSE END) AS nullable,
	COALESCE(c.character_maximum_length, -1) AS max_length,
	COALESCE(c.numeric_precision, -1) AS precision,
	COALESCE(c.numeric_scale, -1) AS scale
FROM information_schema.columns c
WHERE c.table_schema = $1
	AND c.table_name = $2
ORDER BY c.ordinal_position;`

// rawColumnPostgres is one row of postgresColumnMetadataQuery.
// maxLength, precision and scale hold -1 when the database reported NULL.
type rawColumnPostgres struct {
	name       string
	userType   string
	systemType string
	nullable   bool
	maxLength  int64
	precision  int64
	scale      int64
}

// systemTypeToUnifiedType maps a column's system type name to the project's
// unified type vocabulary (STRING, INTEGER, DECIMAL, FLOAT, BOOLEAN, DATE,
// TIME, TIMESTAMP, BINARY, UUID, JSON, GEOMETRY). Matching is
// case-insensitive; an unrecognised type is returned upper-cased.
//
// The PostgreSQL udt_names produced by postgresColumnMetadataQuery (bpchar,
// float4, float8, bool, jsonb, ...) are listed alongside SQL Server names and
// PostgreSQL data_type spellings; the latter are kept so existing inputs keep
// mapping to the same result.
func (ta *PostgresTableAnalyzer) systemTypeToUnifiedType(systemType string) string {
	systemType = strings.ToLower(systemType)

	switch systemType {
	case "varchar", "char", "bpchar", "nvarchar", "nchar", "text", "ntext":
		return "STRING"
	case "int", "int4", "integer", "smallint", "int2", "bigint", "int8", "tinyint":
		return "INTEGER"
	case "decimal", "numeric":
		return "DECIMAL"
	case "float", "float4", "float8", "real", "double precision":
		return "FLOAT"
	case "bit", "bool", "boolean":
		// NOTE(review): in PostgreSQL "bit" is a bit-string type; treating it
		// as BOOLEAN is only accurate for bit(1) columns.
		return "BOOLEAN"
	case "date":
		return "DATE"
	case "time", "time without time zone":
		return "TIME"
	case "datetime", "datetime2", "timestamp", "timestamptz", "timestamp with time zone":
		return "TIMESTAMP"
	case "binary", "varbinary", "image", "bytea":
		return "BINARY"
	case "uniqueidentifier", "uuid":
		return "UUID"
	case "json", "jsonb":
		return "JSON"
	case "geometry", "geography":
		return "GEOMETRY"
	default:
		return strings.ToUpper(systemType)
	}
}

// rawColumnToColumnType converts a metadata row into a models.ColumnType.
//
// Only character types keep their max length and only decimal types keep
// their precision/scale; every other attribute is reset to -1 so the
// "has length" and "has precision or scale" flags passed to
// models.NewColumnType are true only where they are meaningful.
func (ta *PostgresTableAnalyzer) rawColumnToColumnType(rawColumn rawColumnPostgres) models.ColumnType {
	const nullValue int64 = -1

	switch strings.ToLower(rawColumn.systemType) {
	case "varchar", "char", "bpchar", "text":
		// bpchar is the udt_name PostgreSQL reports for char(n)/character(n).
		rawColumn.precision, rawColumn.scale = nullValue, nullValue
	case "decimal", "numeric":
		rawColumn.maxLength = nullValue
	default:
		rawColumn.maxLength, rawColumn.precision, rawColumn.scale = nullValue, nullValue, nullValue
	}

	return models.NewColumnType(
		rawColumn.name,
		rawColumn.maxLength != nullValue,
		rawColumn.precision != nullValue || rawColumn.scale != nullValue,
		rawColumn.userType,
		rawColumn.systemType,
		ta.systemTypeToUnifiedType(rawColumn.systemType),
		rawColumn.nullable,
		rawColumn.maxLength,
		rawColumn.precision,
		rawColumn.scale,
	)
}

// QueryColumnTypes returns the columns of tableInfo.Schema.tableInfo.Table in
// declaration order. The metadata query is bounded by
// postgresColumnMetadataTimeout in addition to ctx.
//
// A table with no visible columns (including a table that does not exist)
// yields a nil slice and a nil error.
func (ta *PostgresTableAnalyzer) QueryColumnTypes(
	ctx context.Context,
	tableInfo config.TableInfo,
) ([]models.ColumnType, error) {
	localCtx, cancel := context.WithTimeout(ctx, postgresColumnMetadataTimeout)
	defer cancel()

	rows, err := ta.db.Query(localCtx, postgresColumnMetadataQuery, tableInfo.Schema, tableInfo.Table)
	if err != nil {
		return nil, fmt.Errorf("querying column metadata for %s.%s: %w", tableInfo.Schema, tableInfo.Table, err)
	}
	defer rows.Close()

	var colTypes []models.ColumnType
	for rows.Next() {
		var column rawColumnPostgres
		if err := rows.Scan(
			&column.name,
			&column.userType,
			&column.systemType,
			&column.nullable,
			&column.maxLength,
			&column.precision,
			&column.scale,
		); err != nil {
			return nil, fmt.Errorf("scanning column metadata for %s.%s: %w", tableInfo.Schema, tableInfo.Table, err)
		}
		colTypes = append(colTypes, ta.rawColumnToColumnType(column))
	}
	// rows.Next() returns false both at the end of the result set and on
	// failure (e.g. the timeout above firing mid-read); only Err tells them
	// apart, so without this check a partial column list could be returned.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("reading column metadata for %s.%s: %w", tableInfo.Schema, tableInfo.Table, err)
	}

	return colTypes, nil
}

// EstimateTotalRows is not implemented for PostgreSQL yet: it always reports
// 0 rows and a nil error.
// TODO: implement (e.g. from pg_class.reltuples or COUNT(*)).
func (ta *PostgresTableAnalyzer) EstimateTotalRows(
	ctx context.Context,
	tableInfo config.TableInfo,
) (int64, error) {
	return 0, nil
}

// CalculatePartitionRanges is not implemented for PostgreSQL yet: it ignores
// partitionColumn and maxPartitions and always returns an empty, non-nil
// slice and a nil error.
// TODO: implement partition range calculation.
func (ta *PostgresTableAnalyzer) CalculatePartitionRanges(
	ctx context.Context,
	tableInfo config.TableInfo,
	partitionColumn string,
	maxPartitions int64,
) ([]models.Partition, error) {
	return []models.Partition{}, nil
}