feat: implement Postgres table analyzer with column type querying and metadata retrieval
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
|
||||
@@ -20,11 +21,154 @@ func NewMssqlTableAnalyzer(db *sql.DB) etl.TableAnalyzer {
|
||||
return &MssqlTableAnalyzer{db: db}
|
||||
}
|
||||
|
||||
// mssqlColumnMetadataQuery returns one row per column of the table identified
// by the @schema and @table named parameters, ordered by ordinal position.
// system_type resolves user-defined alias types back to their underlying
// built-in type via the self-join on sys.types; for built-in types it is the
// type's own name. max_length is in bytes (so nvarchar/nchar report twice
// their character length; -1 means MAX).
const mssqlColumnMetadataQuery string = `
SELECT
	c.name AS name,
	t.name AS user_type,
	CASE WHEN t.is_user_defined = 0 THEN t.name ELSE bt.name END AS system_type,
	c.is_nullable AS nullable,
	c.max_length AS max_length,
	c.precision AS precision,
	c.scale AS scale
FROM sys.columns c
JOIN sys.types t ON c.user_type_id = t.user_type_id
LEFT JOIN sys.types bt ON t.is_user_defined = 1 AND bt.user_type_id = t.system_type_id
JOIN sys.tables st ON c.object_id = st.object_id
JOIN sys.schemas s ON st.schema_id = s.schema_id
WHERE s.name = @schema AND st.name = @table
ORDER BY c.column_id;`
|
||||
|
||||
// rawColumnMssql is the scan target for one row of mssqlColumnMetadataQuery,
// before normalization into a models.ColumnType.
type rawColumnMssql struct {
	name       string // column name
	userType   string // declared type name (may be a user-defined alias)
	systemType string // underlying built-in type name
	nullable   bool
	maxLength  int64 // bytes as reported by sys.columns.max_length (-1 = MAX)
	precision  int64
	scale      int64
}
|
||||
|
||||
func (ta *MssqlTableAnalyzer) systemTypeToUnifiedType(systemType string) string {
|
||||
systemType = strings.ToLower(systemType)
|
||||
|
||||
if systemType == "varchar" || systemType == "char" || systemType == "nvarchar" || systemType == "nchar" || systemType == "text" || systemType == "ntext" {
|
||||
return "STRING"
|
||||
}
|
||||
|
||||
if systemType == "int" || systemType == "int4" || systemType == "integer" || systemType == "smallint" || systemType == "int2" || systemType == "bigint" || systemType == "int8" || systemType == "tinyint" {
|
||||
return "INTEGER"
|
||||
}
|
||||
|
||||
if systemType == "decimal" || systemType == "numeric" {
|
||||
return "DECIMAL"
|
||||
}
|
||||
|
||||
if systemType == "float" || systemType == "real" || systemType == "double precision" {
|
||||
return "FLOAT"
|
||||
}
|
||||
|
||||
if systemType == "bit" || systemType == "boolean" {
|
||||
return "BOOLEAN"
|
||||
}
|
||||
|
||||
if systemType == "date" {
|
||||
return "DATE"
|
||||
}
|
||||
if systemType == "time" || systemType == "time without time zone" {
|
||||
return "TIME"
|
||||
}
|
||||
if systemType == "datetime" || systemType == "datetime2" || systemType == "timestamp" || systemType == "timestamptz" || systemType == "timestamp with time zone" {
|
||||
return "TIMESTAMP"
|
||||
}
|
||||
|
||||
if systemType == "binary" || systemType == "varbinary" || systemType == "image" || systemType == "bytea" {
|
||||
return "BINARY"
|
||||
}
|
||||
|
||||
if systemType == "uniqueidentifier" || systemType == "uuid" {
|
||||
return "UUID"
|
||||
}
|
||||
|
||||
if systemType == "json" {
|
||||
return "JSON"
|
||||
}
|
||||
|
||||
if systemType == "geometry" || systemType == "geography" {
|
||||
return "GEOMETRY"
|
||||
}
|
||||
|
||||
return strings.ToUpper(systemType)
|
||||
}
|
||||
|
||||
func (ta *MssqlTableAnalyzer) rawColumnToColumnType(rawColumn rawColumnMssql) models.ColumnType {
|
||||
const nullValue int64 = -1
|
||||
stringTypes := map[string]bool{"varchar": true, "char": true, "nvarchar": true, "nchar": true, "text": true, "ntext": true}
|
||||
decimalTypes := map[string]bool{"decimal": true, "numeric": true}
|
||||
|
||||
if stringTypes[rawColumn.systemType] {
|
||||
if rawColumn.systemType == "nvarchar" || rawColumn.systemType == "nchar" {
|
||||
if rawColumn.maxLength > 0 {
|
||||
rawColumn.maxLength = rawColumn.maxLength / 2
|
||||
}
|
||||
}
|
||||
|
||||
rawColumn.precision, rawColumn.scale = nullValue, nullValue
|
||||
} else if decimalTypes[rawColumn.systemType] {
|
||||
rawColumn.maxLength = nullValue
|
||||
} else {
|
||||
rawColumn.maxLength, rawColumn.precision, rawColumn.scale = nullValue, nullValue, nullValue
|
||||
}
|
||||
|
||||
columnType := models.NewColumnType(
|
||||
rawColumn.name,
|
||||
rawColumn.maxLength != nullValue,
|
||||
rawColumn.precision != nullValue || rawColumn.scale != nullValue,
|
||||
rawColumn.userType,
|
||||
rawColumn.systemType,
|
||||
ta.systemTypeToUnifiedType(rawColumn.systemType),
|
||||
rawColumn.nullable,
|
||||
rawColumn.maxLength,
|
||||
rawColumn.precision,
|
||||
rawColumn.scale,
|
||||
)
|
||||
|
||||
return columnType
|
||||
}
|
||||
|
||||
func (ta *MssqlTableAnalyzer) QueryColumnTypes(
|
||||
ctx context.Context,
|
||||
tableInfo config.TableInfo,
|
||||
) ([]models.ColumnType, error) {
|
||||
return []models.ColumnType{}, nil
|
||||
localCtx, cancel := context.WithTimeout(ctx, 20*time.Second)
|
||||
defer cancel()
|
||||
|
||||
rows, err := ta.db.QueryContext(localCtx, mssqlColumnMetadataQuery, sql.Named("schema", tableInfo.Schema), sql.Named("table", tableInfo.Table))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var columnTypes []models.ColumnType
|
||||
|
||||
for rows.Next() {
|
||||
var rawColumn rawColumnMssql
|
||||
|
||||
if err := rows.Scan(
|
||||
&rawColumn.name,
|
||||
&rawColumn.userType,
|
||||
&rawColumn.systemType,
|
||||
&rawColumn.nullable,
|
||||
&rawColumn.maxLength,
|
||||
&rawColumn.precision,
|
||||
&rawColumn.scale,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columnTypes = append(columnTypes, ta.rawColumnToColumnType(rawColumn))
|
||||
}
|
||||
|
||||
return columnTypes, nil
|
||||
}
|
||||
|
||||
func (ta *MssqlTableAnalyzer) EstimateTotalRows(
|
||||
|
||||
174
internal/app/etl/table_analyzers/postgres.go
Normal file
174
internal/app/etl/table_analyzers/postgres.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package table_analyzers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
|
||||
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
|
||||
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// PostgresTableAnalyzer implements etl.TableAnalyzer for PostgreSQL
// sources, querying column metadata through a pgx connection pool.
type PostgresTableAnalyzer struct {
	db *pgxpool.Pool // shared pool supplied by the caller
}
|
||||
|
||||
func NewPostgresTableAnalyzer(db *pgxpool.Pool) etl.TableAnalyzer {
|
||||
return &PostgresTableAnalyzer{db: db}
|
||||
}
|
||||
|
||||
// postgresColumnMetadataQuery returns one row per column of the table
// identified by $1 (schema) and $2 (table), ordered by ordinal position.
// user_type is the SQL-standard data_type spelling; system_type is the
// internal udt_name (e.g. "int4", "bpchar", "bool"). NULL length/precision/
// scale are coalesced to -1 so the row scans into plain int64 fields.
const postgresColumnMetadataQuery string = `
SELECT
	c.column_name AS name,
	c.data_type AS user_type,
	c.udt_name AS system_type,
	(CASE WHEN c.is_nullable = 'YES' THEN TRUE ELSE FALSE END) AS nullable,
	COALESCE(c.character_maximum_length, -1) AS max_length,
	COALESCE(c.numeric_precision, -1) AS precision,
	COALESCE(c.numeric_scale, -1) AS scale
FROM information_schema.columns c
WHERE c.table_schema = $1 AND c.table_name = $2
ORDER BY c.ordinal_position;`
|
||||
|
||||
// rawColumnPostgres is the scan target for one row of
// postgresColumnMetadataQuery, before normalization into a models.ColumnType.
type rawColumnPostgres struct {
	name       string // column name
	userType   string // information_schema data_type (e.g. "character varying")
	systemType string // internal udt_name (e.g. "varchar", "int4", "bpchar")
	nullable   bool
	maxLength  int64 // character_maximum_length, -1 when not applicable
	precision  int64
	scale      int64
}
|
||||
|
||||
func (ta *PostgresTableAnalyzer) systemTypeToUnifiedType(systemType string) string {
|
||||
systemType = strings.ToLower(systemType)
|
||||
|
||||
if systemType == "varchar" || systemType == "char" || systemType == "nvarchar" || systemType == "nchar" || systemType == "text" || systemType == "ntext" {
|
||||
return "STRING"
|
||||
}
|
||||
|
||||
if systemType == "int" || systemType == "int4" || systemType == "integer" || systemType == "smallint" || systemType == "int2" || systemType == "bigint" || systemType == "int8" || systemType == "tinyint" {
|
||||
return "INTEGER"
|
||||
}
|
||||
|
||||
if systemType == "decimal" || systemType == "numeric" {
|
||||
return "DECIMAL"
|
||||
}
|
||||
|
||||
if systemType == "float" || systemType == "real" || systemType == "double precision" {
|
||||
return "FLOAT"
|
||||
}
|
||||
|
||||
if systemType == "bit" || systemType == "boolean" {
|
||||
return "BOOLEAN"
|
||||
}
|
||||
|
||||
if systemType == "date" {
|
||||
return "DATE"
|
||||
}
|
||||
if systemType == "time" || systemType == "time without time zone" {
|
||||
return "TIME"
|
||||
}
|
||||
if systemType == "datetime" || systemType == "datetime2" || systemType == "timestamp" || systemType == "timestamptz" || systemType == "timestamp with time zone" {
|
||||
return "TIMESTAMP"
|
||||
}
|
||||
|
||||
if systemType == "binary" || systemType == "varbinary" || systemType == "image" || systemType == "bytea" {
|
||||
return "BINARY"
|
||||
}
|
||||
|
||||
if systemType == "uniqueidentifier" || systemType == "uuid" {
|
||||
return "UUID"
|
||||
}
|
||||
|
||||
if systemType == "json" {
|
||||
return "JSON"
|
||||
}
|
||||
|
||||
if systemType == "geometry" || systemType == "geography" {
|
||||
return "GEOMETRY"
|
||||
}
|
||||
|
||||
return strings.ToUpper(systemType)
|
||||
}
|
||||
|
||||
func (ta *PostgresTableAnalyzer) rawColumnToColumnType(rawColumn rawColumnPostgres) models.ColumnType {
|
||||
const nullValue int64 = -1
|
||||
stringTypes := map[string]bool{"varchar": true, "char": true, "text": true}
|
||||
decimalTypes := map[string]bool{"decimal": true, "numeric": true}
|
||||
|
||||
if stringTypes[rawColumn.systemType] {
|
||||
rawColumn.precision, rawColumn.scale = nullValue, nullValue
|
||||
} else if decimalTypes[rawColumn.systemType] {
|
||||
rawColumn.maxLength = nullValue
|
||||
} else {
|
||||
rawColumn.maxLength, rawColumn.precision, rawColumn.scale = nullValue, nullValue, nullValue
|
||||
}
|
||||
|
||||
return models.NewColumnType(
|
||||
rawColumn.name,
|
||||
rawColumn.maxLength != nullValue,
|
||||
rawColumn.precision != nullValue || rawColumn.scale != nullValue,
|
||||
rawColumn.userType,
|
||||
rawColumn.systemType,
|
||||
ta.systemTypeToUnifiedType(rawColumn.systemType),
|
||||
rawColumn.nullable,
|
||||
rawColumn.maxLength,
|
||||
rawColumn.precision,
|
||||
rawColumn.scale,
|
||||
)
|
||||
}
|
||||
|
||||
func (ta *PostgresTableAnalyzer) QueryColumnTypes(
|
||||
ctx context.Context,
|
||||
tableInfo config.TableInfo,
|
||||
) ([]models.ColumnType, error) {
|
||||
localCtx, cancel := context.WithTimeout(ctx, 20*time.Second)
|
||||
defer cancel()
|
||||
|
||||
rows, err := ta.db.Query(localCtx, postgresColumnMetadataQuery, tableInfo.Schema, tableInfo.Table)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var colTypes []models.ColumnType
|
||||
|
||||
for rows.Next() {
|
||||
var column rawColumnPostgres
|
||||
|
||||
if err := rows.Scan(
|
||||
&column.name,
|
||||
&column.userType,
|
||||
&column.systemType,
|
||||
&column.nullable,
|
||||
&column.maxLength,
|
||||
&column.precision,
|
||||
&column.scale,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
colTypes = append(colTypes, ta.rawColumnToColumnType(column))
|
||||
}
|
||||
|
||||
return colTypes, nil
|
||||
}
|
||||
|
||||
// EstimateTotalRows reports an approximate row count for the given table.
//
// NOTE(review): stub — always returns 0 with no error; callers cannot yet
// distinguish "empty table" from "not implemented".
func (ta *PostgresTableAnalyzer) EstimateTotalRows(
	ctx context.Context,
	tableInfo config.TableInfo,
) (int64, error) {
	return 0, nil
}
|
||||
|
||||
// CalculatePartitionRanges splits the table into at most maxPartitions
// ranges over partitionColumn for parallel extraction.
//
// NOTE(review): stub — always returns an empty slice with no error, so the
// table will not be partitioned until this is implemented.
func (ta *PostgresTableAnalyzer) CalculatePartitionRanges(
	ctx context.Context,
	tableInfo config.TableInfo,
	partitionColumn string,
	maxPartitions int64,
) ([]models.Partition, error) {
	return []models.Partition{}, nil
}
|
||||
Reference in New Issue
Block a user