294 lines
7.5 KiB
Go
294 lines
7.5 KiB
Go
package table_analyzers
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
|
|
dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper"
|
|
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
|
|
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
type PostgresTableAnalyzer struct {
|
|
db dbwrapper.DbWrapper
|
|
}
|
|
|
|
func NewPostgresTableAnalyzer(db dbwrapper.DbWrapper) etl.TableAnalyzer {
|
|
return &PostgresTableAnalyzer{db: db}
|
|
}
|
|
|
|
// postgresColumnMetadataQuery returns one row per column of the table
// identified by schema ($1) and table name ($2), in ordinal position order.
// NULL length, precision and scale values are coalesced to -1 so they can be
// scanned into plain integers.
const postgresColumnMetadataQuery string = `
SELECT
	c.column_name AS name,
	c.data_type AS user_type,
	c.udt_name AS system_type,
	(CASE WHEN c.is_nullable = 'YES' THEN TRUE ELSE FALSE END) AS nullable,
	COALESCE(c.character_maximum_length, -1) AS max_length,
	COALESCE(c.numeric_precision, -1) AS precision,
	COALESCE(c.numeric_scale, -1) AS scale
FROM information_schema.columns c
WHERE c.table_schema = $1 AND c.table_name = $2
ORDER BY c.ordinal_position;`
|
|
|
|
// rawColumnPostgres holds one row of column metadata as scanned from
// postgresColumnMetadataQuery, before it is converted to a models.ColumnType.
// The numeric fields use -1 where the catalog value was NULL.
type rawColumnPostgres struct {
	name       string // column_name
	userType   string // data_type
	systemType string // udt_name
	nullable   bool   // true when is_nullable = 'YES'
	maxLength  int64  // character_maximum_length, or -1
	precision  int64  // numeric_precision, or -1
	scale      int64  // numeric_scale, or -1
}
|
|
|
|
func (ta *PostgresTableAnalyzer) systemTypeToUnifiedType(systemType string) string {
|
|
systemType = strings.ToLower(systemType)
|
|
|
|
if systemType == "varchar" || systemType == "char" || systemType == "nvarchar" || systemType == "nchar" || systemType == "text" || systemType == "ntext" {
|
|
return "STRING"
|
|
}
|
|
|
|
if systemType == "int" || systemType == "int4" || systemType == "integer" || systemType == "smallint" || systemType == "int2" || systemType == "bigint" || systemType == "int8" || systemType == "tinyint" {
|
|
return "INTEGER"
|
|
}
|
|
|
|
if systemType == "decimal" || systemType == "numeric" {
|
|
return "DECIMAL"
|
|
}
|
|
|
|
if systemType == "float" || systemType == "real" || systemType == "double precision" {
|
|
return "FLOAT"
|
|
}
|
|
|
|
if systemType == "bit" || systemType == "boolean" {
|
|
return "BOOLEAN"
|
|
}
|
|
|
|
if systemType == "date" {
|
|
return "DATE"
|
|
}
|
|
if systemType == "time" || systemType == "time without time zone" {
|
|
return "TIME"
|
|
}
|
|
if systemType == "datetime" || systemType == "datetime2" || systemType == "timestamp" || systemType == "timestamptz" || systemType == "timestamp with time zone" {
|
|
return "TIMESTAMP"
|
|
}
|
|
|
|
if systemType == "binary" || systemType == "varbinary" || systemType == "image" || systemType == "bytea" {
|
|
return "BINARY"
|
|
}
|
|
|
|
if systemType == "uniqueidentifier" || systemType == "uuid" {
|
|
return "UUID"
|
|
}
|
|
|
|
if systemType == "json" {
|
|
return "JSON"
|
|
}
|
|
|
|
if systemType == "geometry" || systemType == "geography" {
|
|
return "GEOMETRY"
|
|
}
|
|
|
|
return strings.ToUpper(systemType)
|
|
}
|
|
|
|
func (ta *PostgresTableAnalyzer) rawColumnToColumnType(rawColumn rawColumnPostgres) models.ColumnType {
|
|
const nullValue int64 = -1
|
|
stringTypes := map[string]bool{"varchar": true, "char": true, "text": true}
|
|
decimalTypes := map[string]bool{"decimal": true, "numeric": true}
|
|
|
|
if stringTypes[rawColumn.systemType] {
|
|
rawColumn.precision, rawColumn.scale = nullValue, nullValue
|
|
} else if decimalTypes[rawColumn.systemType] {
|
|
rawColumn.maxLength = nullValue
|
|
} else {
|
|
rawColumn.maxLength, rawColumn.precision, rawColumn.scale = nullValue, nullValue, nullValue
|
|
}
|
|
|
|
return models.NewColumnType(
|
|
rawColumn.name,
|
|
rawColumn.maxLength != nullValue,
|
|
rawColumn.precision != nullValue || rawColumn.scale != nullValue,
|
|
rawColumn.userType,
|
|
rawColumn.systemType,
|
|
ta.systemTypeToUnifiedType(rawColumn.systemType),
|
|
rawColumn.nullable,
|
|
rawColumn.maxLength,
|
|
rawColumn.precision,
|
|
rawColumn.scale,
|
|
)
|
|
}
|
|
|
|
func (ta *PostgresTableAnalyzer) QueryColumnTypes(
|
|
ctx context.Context,
|
|
tableInfo config.TableInfo,
|
|
) ([]models.ColumnType, error) {
|
|
localCtx, cancel := context.WithTimeout(ctx, 1*time.Minute)
|
|
defer cancel()
|
|
|
|
rows, err := ta.db.Query(localCtx, postgresColumnMetadataQuery, tableInfo.Schema, tableInfo.Table)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var colTypes []models.ColumnType
|
|
|
|
for rows.Next() {
|
|
var column rawColumnPostgres
|
|
|
|
if err := rows.Scan(
|
|
&column.name,
|
|
&column.userType,
|
|
&column.systemType,
|
|
&column.nullable,
|
|
&column.maxLength,
|
|
&column.precision,
|
|
&column.scale,
|
|
); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
colTypes = append(colTypes, ta.rawColumnToColumnType(column))
|
|
}
|
|
|
|
return colTypes, nil
|
|
}
|
|
|
|
func (ta *PostgresTableAnalyzer) EstimateTotalRows(
|
|
ctx context.Context,
|
|
tableInfo config.TableInfo,
|
|
) (int64, error) {
|
|
query := `
|
|
SELECT reltuples::bigint
|
|
FROM pg_class
|
|
JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace
|
|
WHERE pg_namespace.nspname = $1 AND pg_class.relname = $2`
|
|
|
|
ctxTimeout, cancel := context.WithTimeout(ctx, 1*time.Minute)
|
|
defer cancel()
|
|
|
|
var estimate int64
|
|
err := ta.db.QueryRow(ctxTimeout, query, tableInfo.Schema, tableInfo.Table).Scan(&estimate)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
if estimate < 0 {
|
|
countQuery := fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."%s"`, tableInfo.Schema, tableInfo.Table)
|
|
err = ta.db.QueryRow(ctxTimeout, countQuery).Scan(&estimate)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
|
|
return estimate, nil
|
|
}
|
|
|
|
func (ta *PostgresTableAnalyzer) QueryMaxMinFromColumn(
|
|
ctx context.Context,
|
|
tableInfo config.TableInfo,
|
|
columnName string,
|
|
) (etl.MaxMinColumnResult, error) {
|
|
query := fmt.Sprintf(`SELECT MIN("%s"), MAX("%s") FROM "%s"."%s"`,
|
|
columnName, columnName, tableInfo.Schema, tableInfo.Table)
|
|
|
|
ctxTimeout, cancel := context.WithTimeout(ctx, 1*time.Minute)
|
|
defer cancel()
|
|
|
|
result := etl.MaxMinColumnResult{}
|
|
err := ta.db.QueryRow(ctxTimeout, query).Scan(&result.Min, &result.Max)
|
|
if err != nil {
|
|
return etl.MaxMinColumnResult{}, err
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func (ta *PostgresTableAnalyzer) CalculatePartitionRanges(
|
|
ctx context.Context,
|
|
tableInfo config.TableInfo,
|
|
partitionColumn string,
|
|
maxPartitions int64,
|
|
rangeConstraint config.RangeConfig,
|
|
) ([]models.Partition, error) {
|
|
whereClause := ""
|
|
args := []any{maxPartitions}
|
|
|
|
if rangeConstraint.Min != nil || rangeConstraint.Max != nil {
|
|
var conditions []string
|
|
if rangeConstraint.Min != nil {
|
|
minOp := ">"
|
|
if rangeConstraint.IsMinInclusive {
|
|
minOp = ">="
|
|
}
|
|
args = append(args, *rangeConstraint.Min)
|
|
conditions = append(conditions, fmt.Sprintf(`"%s" %s $%d`, partitionColumn, minOp, len(args)))
|
|
}
|
|
if rangeConstraint.Max != nil {
|
|
maxOp := "<"
|
|
if rangeConstraint.IsMaxInclusive {
|
|
maxOp = "<="
|
|
}
|
|
args = append(args, *rangeConstraint.Max)
|
|
conditions = append(conditions, fmt.Sprintf(`"%s" %s $%d`, partitionColumn, maxOp, len(args)))
|
|
}
|
|
whereClause = "WHERE " + strings.Join(conditions, " AND ")
|
|
}
|
|
|
|
query := fmt.Sprintf(`
|
|
SELECT MIN("%s") AS lower_limit, MAX("%s") AS upper_limit
|
|
FROM (
|
|
SELECT "%s", NTILE($1) OVER (ORDER BY "%s") AS batch_id
|
|
FROM "%s"."%s" %s
|
|
) AS t
|
|
GROUP BY batch_id
|
|
ORDER BY batch_id`,
|
|
partitionColumn,
|
|
partitionColumn,
|
|
partitionColumn,
|
|
partitionColumn,
|
|
tableInfo.Schema,
|
|
tableInfo.Table,
|
|
whereClause)
|
|
|
|
ctxTimeout, cancel := context.WithTimeout(ctx, 1*time.Minute)
|
|
defer cancel()
|
|
|
|
rows, err := ta.db.Query(ctxTimeout, query, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
partitions := make([]models.Partition, 0, maxPartitions)
|
|
|
|
for rows.Next() {
|
|
partition := models.Partition{
|
|
Id: uuid.New(),
|
|
HasRange: true,
|
|
RetryCounter: 0,
|
|
Range: models.PartitionRange{
|
|
IsMinInclusive: true,
|
|
IsMaxInclusive: true,
|
|
},
|
|
}
|
|
|
|
if err := rows.Scan(&partition.Range.Min, &partition.Range.Max); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
partitions = append(partitions, partition)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return partitions, nil
|
|
}
|