253 lines
6.6 KiB
Go
253 lines
6.6 KiB
Go
package table_analyzers
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
|
|
dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper"
|
|
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
|
|
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
type MssqlTableAnalyzer struct {
|
|
db dbwrapper.DbWrapper
|
|
}
|
|
|
|
func NewMssqlTableAnalyzer(db dbwrapper.DbWrapper) etl.TableAnalyzer {
|
|
return &MssqlTableAnalyzer{db: db}
|
|
}
|
|
|
|
// mssqlColumnMetadataQuery lists the columns of the table named by the
// @schema and @table parameters, in column_id order.
//
// For every column it returns:
//   - user_type: the declared type name (possibly a user-defined alias type);
//   - system_type: the underlying system type name (for alias types, the base
//     type they were created from).
//
// The base type is found with a LEFT JOIN on t.system_type_id. CLR
// user-defined types report system_type_id 240, which has no matching
// user_type_id row in sys.types, so the join yields NULL for them. COALESCE
// falls back to the declared name so system_type is never NULL and can always
// be scanned into a Go string.
//
// Hidden columns are skipped, except hidden graph columns whose names start
// with '$'.
const mssqlColumnMetadataQuery string = `
SELECT
	c.name AS name,
	t.name AS user_type,
	CASE WHEN t.is_user_defined = 0 THEN t.name ELSE COALESCE(bt.name, t.name) END AS system_type,
	c.is_nullable AS nullable,
	c.max_length AS max_length,
	c.precision AS precision,
	c.scale AS scale
FROM sys.columns c
JOIN sys.types t ON c.user_type_id = t.user_type_id
LEFT JOIN sys.types bt ON t.is_user_defined = 1 AND bt.user_type_id = t.system_type_id
JOIN sys.tables st ON c.object_id = st.object_id
JOIN sys.schemas s ON st.schema_id = s.schema_id
WHERE s.name = @schema AND st.name = @table AND (c.is_hidden = 0 OR (c.graph_type IS NOT NULL AND c.name LIKE '$%'))
ORDER BY c.column_id;`
|
|
|
|
// rawColumnMssql holds one row of mssqlColumnMetadataQuery exactly as it was
// scanned, before rawColumnToColumnType normalizes the size attributes.
type rawColumnMssql struct {
	name       string // column name
	userType   string // declared type name (may be a user-defined alias)
	systemType string // underlying system type name
	nullable   bool   // sys.columns.is_nullable

	// Raw size attributes from sys.columns. maxLength is in bytes, and SQL
	// Server reports -1 for MAX types.
	maxLength, precision, scale int64
}
|
|
|
|
// systemTypeToUnifiedType maps a system type name onto the engine-neutral
// type name (STRING, INTEGER, DECIMAL, ...) that rawColumnToColumnType stores
// in models.ColumnType. Matching is case-insensitive. A few non-SQL-Server
// spellings (int4, boolean, bytea, uuid, timestamptz, ...) are accepted too.
// Unrecognized names are returned upper-cased.
//
// In SQL Server, the type sys.types calls "timestamp" is rowversion: an
// 8-byte binary counter with no date/time meaning. It therefore maps to
// BINARY, not TIMESTAMP.
func (ta *MssqlTableAnalyzer) systemTypeToUnifiedType(systemType string) string {
	systemType = strings.ToLower(systemType)

	switch systemType {
	case "varchar", "char", "nvarchar", "nchar", "text", "ntext":
		return "STRING"
	case "int", "int4", "integer", "smallint", "int2", "bigint", "int8", "tinyint":
		return "INTEGER"
	case "decimal", "numeric":
		return "DECIMAL"
	case "float", "real", "double precision":
		return "FLOAT"
	case "bit", "boolean":
		return "BOOLEAN"
	case "date":
		return "DATE"
	case "time", "time without time zone":
		return "TIME"
	case "datetime", "datetime2", "smalldatetime", "datetimeoffset", "timestamptz", "timestamp with time zone":
		return "TIMESTAMP"
	case "binary", "varbinary", "image", "bytea", "timestamp", "rowversion":
		return "BINARY"
	case "uniqueidentifier", "uuid":
		return "UUID"
	case "json":
		return "JSON"
	case "geometry", "geography":
		return "GEOMETRY"
	default:
		return strings.ToUpper(systemType)
	}
}
|
|
|
|
// rawColumnToColumnType converts a scanned metadata row into a
// models.ColumnType. It keeps only the size attributes that are meaningful
// for the column's system type:
//   - character types keep maxLength (halved for nchar/nvarchar) and drop
//     precision and scale;
//   - decimal/numeric keep precision and scale and drop maxLength;
//   - every other type drops all three.
//
// Dropped attributes are set to -1, and the "has max length" and "has
// precision/scale" flags passed to models.NewColumnType are derived from
// that sentinel. SQL Server itself reports max_length -1 for (n)varchar(max),
// so those columns also end up without a max length. NOTE(review): confirm
// that consumers treat this as "unbounded".
//
// Type classification is case-insensitive, matching systemTypeToUnifiedType.
// The original systemType spelling is stored unchanged.
func (ta *MssqlTableAnalyzer) rawColumnToColumnType(rawColumn rawColumnMssql) models.ColumnType {
	const nullValue int64 = -1

	switch strings.ToLower(rawColumn.systemType) {
	case "nvarchar", "nchar":
		// max_length is reported in bytes; these types use two bytes per
		// character. Non-positive values (e.g. -1 for MAX) are left as-is.
		if rawColumn.maxLength > 0 {
			rawColumn.maxLength /= 2
		}
		rawColumn.precision, rawColumn.scale = nullValue, nullValue
	case "varchar", "char", "text", "ntext":
		rawColumn.precision, rawColumn.scale = nullValue, nullValue
	case "decimal", "numeric":
		rawColumn.maxLength = nullValue
	default:
		rawColumn.maxLength, rawColumn.precision, rawColumn.scale = nullValue, nullValue, nullValue
	}

	return models.NewColumnType(
		rawColumn.name,
		rawColumn.maxLength != nullValue,
		rawColumn.precision != nullValue || rawColumn.scale != nullValue,
		rawColumn.userType,
		rawColumn.systemType,
		ta.systemTypeToUnifiedType(rawColumn.systemType),
		rawColumn.nullable,
		rawColumn.maxLength,
		rawColumn.precision,
		rawColumn.scale,
	)
}
|
|
|
|
// QueryColumnTypes returns the column metadata of tableInfo.Schema.tableInfo.Table
// in column_id order. Each row of mssqlColumnMetadataQuery is converted with
// rawColumnToColumnType. The catalog query runs under a 20-second timeout
// derived from ctx.
//
// If the query matches no columns (for example, because the table does not
// exist), the result is a nil slice and a nil error.
func (ta *MssqlTableAnalyzer) QueryColumnTypes(
	ctx context.Context,
	tableInfo config.TableInfo,
) ([]models.ColumnType, error) {
	localCtx, cancel := context.WithTimeout(ctx, 20*time.Second)
	defer cancel()

	rows, err := ta.db.Query(localCtx, mssqlColumnMetadataQuery, sql.Named("schema", tableInfo.Schema), sql.Named("table", tableInfo.Table))
	if err != nil {
		return nil, fmt.Errorf("query column metadata for %s.%s: %w", tableInfo.Schema, tableInfo.Table, err)
	}
	defer rows.Close()

	var columnTypes []models.ColumnType
	for rows.Next() {
		var rawColumn rawColumnMssql
		if err := rows.Scan(
			&rawColumn.name,
			&rawColumn.userType,
			&rawColumn.systemType,
			&rawColumn.nullable,
			&rawColumn.maxLength,
			&rawColumn.precision,
			&rawColumn.scale,
		); err != nil {
			return nil, fmt.Errorf("scan column metadata for %s.%s: %w", tableInfo.Schema, tableInfo.Table, err)
		}

		columnTypes = append(columnTypes, ta.rawColumnToColumnType(rawColumn))
	}

	// rows.Next returns false both at the end of the result set and when
	// iteration fails (driver/network error, timeout). Only rows.Err tells the
	// two apart; without this check, a failure would return a truncated
	// column list as if it were complete.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("read column metadata for %s.%s: %w", tableInfo.Schema, tableInfo.Table, err)
	}

	return columnTypes, nil
}
|
|
|
|
// EstimateTotalRows returns the row count that SQL Server keeps in
// sys.partitions for tableInfo.Schema.tableInfo.Table. The count is summed
// over the partitions of the heap or clustered index (index_id 0 or 1), so
// the table data is not scanned. The catalog query runs under a 20-second
// timeout derived from ctx.
//
// Errors are wrapped with the table name; errors.Is still matches the
// underlying error. NOTE(review): for a table that does not exist, the query
// yields no row, which presumably surfaces as sql.ErrNoRows from Scan. Confirm
// this against the dbwrapper implementation.
func (ta *MssqlTableAnalyzer) EstimateTotalRows(
	ctx context.Context,
	tableInfo config.TableInfo,
) (int64, error) {
	// The GROUP BY makes a missing table produce no row instead of a single
	// NULL SUM, which could not be scanned into an int64 anyway.
	query := `
SELECT SUM(p.rows) AS count
FROM sys.tables t
JOIN sys.schemas s ON t.schema_id = s.schema_id
JOIN sys.partitions p ON t.object_id = p.object_id
WHERE s.name = @schema AND t.name = @table AND p.index_id IN (0, 1)
GROUP BY t.name`

	ctxTimeout, cancel := context.WithTimeout(ctx, time.Second*20)
	defer cancel()

	var rowsCount int64
	err := ta.db.QueryRow(ctxTimeout, query, sql.Named("schema", tableInfo.Schema), sql.Named("table", tableInfo.Table)).Scan(&rowsCount)
	if err != nil {
		return 0, fmt.Errorf("estimate row count for %s.%s: %w", tableInfo.Schema, tableInfo.Table, err)
	}

	return rowsCount, nil
}
|
|
|
|
func (ta *MssqlTableAnalyzer) CalculatePartitionRanges(
|
|
ctx context.Context,
|
|
tableInfo config.TableInfo,
|
|
partitionColumn string,
|
|
maxPartitions int64,
|
|
) ([]models.Partition, error) {
|
|
query := fmt.Sprintf(`
|
|
SELECT
|
|
MIN([%s]) AS lower_limit,
|
|
MAX([%s]) AS upper_limit
|
|
FROM (SELECT [%s], NTILE(@maxPartitions) OVER (ORDER BY [%s]) AS batch_id FROM [%s].[%s]) AS T
|
|
GROUP BY batch_id
|
|
ORDER BY batch_id`,
|
|
partitionColumn,
|
|
partitionColumn,
|
|
partitionColumn,
|
|
partitionColumn,
|
|
tableInfo.Schema,
|
|
tableInfo.Table)
|
|
|
|
ctxTimeout, cancel := context.WithTimeout(ctx, time.Second*20)
|
|
defer cancel()
|
|
|
|
rows, err := ta.db.Query(ctxTimeout, query, sql.Named("maxPartitions", maxPartitions))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
partitions := make([]models.Partition, 0, maxPartitions)
|
|
|
|
for rows.Next() {
|
|
partition := models.Partition{
|
|
Id: uuid.New(),
|
|
HasRange: true,
|
|
RetryCounter: 0,
|
|
Range: models.PartitionRange{
|
|
IsMinInclusive: true,
|
|
},
|
|
}
|
|
|
|
if err := rows.Scan(&partition.Range.Min, &partition.Range.Max); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
partitions = append(partitions, partition)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return partitions, nil
|
|
}
|