package table_analyzers

import (
	"context"
	"fmt"
	"strings"
	"time"

	"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
	dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper"
	"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
	"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
	"github.com/google/uuid"
)

// postgresQueryTimeout bounds every individual metadata/statistics query
// issued by PostgresTableAnalyzer.
const postgresQueryTimeout = time.Minute

// PostgresTableAnalyzer implements etl.TableAnalyzer for PostgreSQL sources.
// It inspects information_schema / pg_catalog to discover column metadata,
// row-count estimates and partition boundaries for a table.
type PostgresTableAnalyzer struct {
	db dbwrapper.DbWrapper
}

// NewPostgresTableAnalyzer returns a TableAnalyzer that issues its queries
// through db.
func NewPostgresTableAnalyzer(db dbwrapper.DbWrapper) etl.TableAnalyzer {
	return &PostgresTableAnalyzer{db: db}
}

// postgresColumnMetadataQuery lists the columns of table $2 in schema $1 in
// ordinal order. NULL length/precision/scale values are mapped to -1 so they
// can be scanned into plain int64 fields.
const postgresColumnMetadataQuery string = `
SELECT
	c.column_name AS name,
	c.data_type AS user_type,
	c.udt_name AS system_type,
	(CASE WHEN c.is_nullable = 'YES' THEN TRUE ELSE FALSE END) AS nullable,
	COALESCE(c.character_maximum_length, -1) AS max_length,
	COALESCE(c.numeric_precision, -1) AS precision,
	COALESCE(c.numeric_scale, -1) AS scale
FROM information_schema.columns c
WHERE c.table_schema = $1 AND c.table_name = $2
ORDER BY c.ordinal_position;`

// rawColumnPostgres is one row of postgresColumnMetadataQuery; -1 in a
// numeric field means the catalog reported NULL.
type rawColumnPostgres struct {
	name       string
	userType   string
	systemType string
	nullable   bool
	maxLength  int64
	precision  int64
	scale      int64
}

// quotePostgresIdentifier renders name as a double-quoted PostgreSQL
// identifier, doubling any embedded double quote so the name cannot break
// out of the quoting.
func quotePostgresIdentifier(name string) string {
	return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
}

// postgresQualifiedTable returns the quoted "schema"."table" reference for
// tableInfo.
func postgresQualifiedTable(tableInfo config.TableInfo) string {
	return quotePostgresIdentifier(tableInfo.Schema) + "." + quotePostgresIdentifier(tableInfo.Table)
}

// systemTypeToUnifiedType maps a database type name (case-insensitive) to the
// migrator's unified type vocabulary. The mapping also accepts some non-Postgres
// names (e.g. nvarchar, uniqueidentifier). It recognises both the udt_name
// spelling (bpchar, float8, timestamptz, ...) and the information_schema
// data_type spelling (character varying, double precision, ...). Unknown
// types are returned upper-cased.
func (ta *PostgresTableAnalyzer) systemTypeToUnifiedType(systemType string) string {
	systemType = strings.ToLower(systemType)

	switch systemType {
	case "varchar", "char", "bpchar", "nvarchar", "nchar", "text", "ntext",
		"character varying", "character":
		return "STRING"
	case "int", "int4", "integer", "smallint", "int2", "bigint", "int8", "tinyint":
		return "INTEGER"
	case "decimal", "numeric":
		return "DECIMAL"
	case "float", "float4", "float8", "real", "double precision":
		return "FLOAT"
	case "bit", "bool", "boolean":
		// NOTE(review): PostgreSQL "bit" is a bit string, not a boolean; kept
		// as BOOLEAN for compatibility with the existing mapping.
		return "BOOLEAN"
	case "date":
		return "DATE"
	case "time", "timetz", "time without time zone", "time with time zone":
		return "TIME"
	case "datetime", "datetime2", "timestamp", "timestamptz",
		"timestamp without time zone", "timestamp with time zone":
		return "TIMESTAMP"
	case "binary", "varbinary", "image", "bytea":
		return "BINARY"
	case "uniqueidentifier", "uuid":
		return "UUID"
	case "json", "jsonb":
		return "JSON"
	case "geometry", "geography":
		return "GEOMETRY"
	default:
		return strings.ToUpper(systemType)
	}
}

// rawColumnToColumnType converts a catalog row into a models.ColumnType.
// Length is kept only for character types, and precision/scale only for
// exact numeric types. Everything else is reset to -1 so the "has length" and
// "has precision" flags are not set by catalog noise, such as the
// numeric_precision reported for integer and float columns.
func (ta *PostgresTableAnalyzer) rawColumnToColumnType(rawColumn rawColumnPostgres) models.ColumnType {
	const nullValue int64 = -1

	switch rawColumn.systemType {
	case "varchar", "bpchar", "char", "text":
		// bpchar is the udt_name PostgreSQL reports for character(n).
		rawColumn.precision, rawColumn.scale = nullValue, nullValue
	case "decimal", "numeric":
		rawColumn.maxLength = nullValue
	default:
		rawColumn.maxLength, rawColumn.precision, rawColumn.scale = nullValue, nullValue, nullValue
	}

	return models.NewColumnType(
		rawColumn.name,
		rawColumn.maxLength != nullValue,
		rawColumn.precision != nullValue || rawColumn.scale != nullValue,
		rawColumn.userType,
		rawColumn.systemType,
		ta.systemTypeToUnifiedType(rawColumn.systemType),
		rawColumn.nullable,
		rawColumn.maxLength,
		rawColumn.precision,
		rawColumn.scale,
	)
}

// QueryColumnTypes returns the column metadata of tableInfo in ordinal order.
// A table that is not visible to the current user yields an empty result and
// no error.
func (ta *PostgresTableAnalyzer) QueryColumnTypes(
	ctx context.Context,
	tableInfo config.TableInfo,
) ([]models.ColumnType, error) {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	rows, err := ta.db.Query(ctx, postgresColumnMetadataQuery, tableInfo.Schema, tableInfo.Table)
	if err != nil {
		return nil, fmt.Errorf("querying column metadata of %s.%s: %w", tableInfo.Schema, tableInfo.Table, err)
	}
	defer rows.Close()

	var colTypes []models.ColumnType
	for rows.Next() {
		var column rawColumnPostgres
		if err := rows.Scan(
			&column.name,
			&column.userType,
			&column.systemType,
			&column.nullable,
			&column.maxLength,
			&column.precision,
			&column.scale,
		); err != nil {
			return nil, fmt.Errorf("scanning column metadata of %s.%s: %w", tableInfo.Schema, tableInfo.Table, err)
		}
		colTypes = append(colTypes, ta.rawColumnToColumnType(column))
	}
	// Without this check an error during iteration (e.g. the timeout) would
	// silently return a truncated column list.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("reading column metadata of %s.%s: %w", tableInfo.Schema, tableInfo.Table, err)
	}

	return colTypes, nil
}

// EstimateTotalRows returns the planner's row estimate (pg_class.reltuples)
// for tableInfo. When no usable estimate exists, it falls back to an exact
// COUNT(*).
func (ta *PostgresTableAnalyzer) EstimateTotalRows(
	ctx context.Context,
	tableInfo config.TableInfo,
) (int64, error) {
	const estimateQuery = `
SELECT c.reltuples::bigint
FROM pg_class c
JOIN pg_namespace n ON n.oid = c.relnamespace
WHERE n.nspname = $1 AND c.relname = $2`

	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	var estimate int64
	if err := ta.db.QueryRow(ctx, estimateQuery, tableInfo.Schema, tableInfo.Table).Scan(&estimate); err != nil {
		return 0, fmt.Errorf("estimating rows of %s.%s: %w", tableInfo.Schema, tableInfo.Table, err)
	}

	// A never-analyzed table reports reltuples = -1 (PostgreSQL 14+) or 0
	// (older versions). In either case, count exactly; for a truly empty table
	// the count is cheap.
	if estimate <= 0 {
		countQuery := "SELECT COUNT(*) FROM " + postgresQualifiedTable(tableInfo)
		if err := ta.db.QueryRow(ctx, countQuery).Scan(&estimate); err != nil {
			return 0, fmt.Errorf("counting rows of %s.%s: %w", tableInfo.Schema, tableInfo.Table, err)
		}
	}

	return estimate, nil
}

// QueryMaxMinFromColumn returns MIN and MAX of columnName over the whole
// table.
func (ta *PostgresTableAnalyzer) QueryMaxMinFromColumn(
	ctx context.Context,
	tableInfo config.TableInfo,
	columnName string,
) (etl.MaxMinColumnResult, error) {
	column := quotePostgresIdentifier(columnName)
	query := fmt.Sprintf("SELECT MIN(%[1]s), MAX(%[1]s) FROM %[2]s", column, postgresQualifiedTable(tableInfo))

	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	var result etl.MaxMinColumnResult
	if err := ta.db.QueryRow(ctx, query).Scan(&result.Min, &result.Max); err != nil {
		return etl.MaxMinColumnResult{}, fmt.Errorf(
			"querying min/max of %s on %s.%s: %w", columnName, tableInfo.Schema, tableInfo.Table, err)
	}

	return result, nil
}

// CalculatePartitionRanges splits the rows of tableInfo into at most
// maxPartitions roughly equal-sized, ordered ranges of partitionColumn.
// It optionally considers only rows within rangeConstraint. Each returned
// partition has inclusive bounds.
//
// Rows are tiled with NTILE and then every distinct value is moved into the
// first tile it appears in. As a result, no value is shared by two partitions.
// Without that step, duplicates straddling a tile boundary would fall inside
// two inclusive ranges and be copied twice. Consequently, fewer than
// maxPartitions ranges may be returned.
//
// Returns an error if maxPartitions is not positive.
func (ta *PostgresTableAnalyzer) CalculatePartitionRanges(
	ctx context.Context,
	tableInfo config.TableInfo,
	partitionColumn string,
	maxPartitions int64,
	rangeConstraint config.RangeConfig,
) ([]models.Partition, error) {
	if maxPartitions <= 0 {
		return nil, fmt.Errorf(
			"calculating partition ranges of %s.%s: maxPartitions must be positive, got %d",
			tableInfo.Schema, tableInfo.Table, maxPartitions)
	}

	column := quotePostgresIdentifier(partitionColumn)

	// $1 is the tile count; range bounds (if any) follow as $2, $3.
	args := []any{maxPartitions}
	var conditions []string
	if rangeConstraint.Min != nil {
		minOp := ">"
		if rangeConstraint.IsMinInclusive {
			minOp = ">="
		}
		args = append(args, *rangeConstraint.Min)
		conditions = append(conditions, fmt.Sprintf("%s %s $%d", column, minOp, len(args)))
	}
	if rangeConstraint.Max != nil {
		maxOp := "<"
		if rangeConstraint.IsMaxInclusive {
			maxOp = "<="
		}
		args = append(args, *rangeConstraint.Max)
		conditions = append(conditions, fmt.Sprintf("%s %s $%d", column, maxOp, len(args)))
	}

	whereClause := ""
	if len(conditions) > 0 {
		whereClause = "WHERE " + strings.Join(conditions, " AND ")
	}

	// Aliases are prefixed so they cannot collide with the partition column.
	query := fmt.Sprintf(`
SELECT MIN(%[1]s) AS lower_limit, MAX(%[1]s) AS upper_limit
FROM (
	SELECT %[1]s, MIN(go_migrate_tile) OVER (PARTITION BY %[1]s) AS go_migrate_batch
	FROM (
		SELECT %[1]s, NTILE($1) OVER (ORDER BY %[1]s) AS go_migrate_tile
		FROM %[2]s
		%[3]s
	) AS tiled
) AS batched
GROUP BY go_migrate_batch
ORDER BY go_migrate_batch`, column, postgresQualifiedTable(tableInfo), whereClause)

	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	rows, err := ta.db.Query(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("calculating partition ranges of %s.%s: %w", tableInfo.Schema, tableInfo.Table, err)
	}
	defer rows.Close()

	partitions := make([]models.Partition, 0, maxPartitions)
	for rows.Next() {
		partition := models.Partition{
			Id:           uuid.New(),
			HasRange:     true,
			RetryCounter: 0,
			Range: models.PartitionRange{
				IsMinInclusive: true,
				IsMaxInclusive: true,
			},
		}
		// NOTE(review): a batch made up only of NULL partition values yields
		// nil Min/Max here; confirm downstream handles that.
		if err := rows.Scan(&partition.Range.Min, &partition.Range.Max); err != nil {
			return nil, fmt.Errorf("scanning partition range of %s.%s: %w", tableInfo.Schema, tableInfo.Table, err)
		}
		partitions = append(partitions, partition)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("reading partition ranges of %s.%s: %w", tableInfo.Schema, tableInfo.Table, err)
	}

	return partitions, nil
}