package table_analyzers import ( "context" "database/sql" "fmt" "strings" "time" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config" dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models" "github.com/google/uuid" ) type MssqlTableAnalyzer struct { db dbwrapper.DbWrapper } func NewMssqlTableAnalyzer(db dbwrapper.DbWrapper) etl.TableAnalyzer { return &MssqlTableAnalyzer{db: db} } const mssqlColumnMetadataQuery string = ` SELECT c.name AS name, t.name AS user_type, CASE WHEN t.is_user_defined = 0 THEN t.name ELSE bt.name END AS system_type, c.is_nullable AS nullable, c.max_length AS max_length, c.precision AS precision, c.scale AS scale FROM sys.columns c JOIN sys.types t ON c.user_type_id = t.user_type_id LEFT JOIN sys.types bt ON t.is_user_defined = 1 AND bt.user_type_id = t.system_type_id JOIN sys.tables st ON c.object_id = st.object_id JOIN sys.schemas s ON st.schema_id = s.schema_id WHERE s.name = @schema AND st.name = @table AND c.name NOT LIKE 'graph_id%' AND c.name NOT LIKE 'node_id%' AND c.name NOT LIKE 'edge_id%' ORDER BY c.column_id;` type rawColumnMssql struct { name string userType string systemType string nullable bool maxLength int64 precision int64 scale int64 } func (ta *MssqlTableAnalyzer) systemTypeToUnifiedType(systemType string) string { systemType = strings.ToLower(systemType) if systemType == "varchar" || systemType == "char" || systemType == "nvarchar" || systemType == "nchar" || systemType == "text" || systemType == "ntext" { return "STRING" } if systemType == "int" || systemType == "int4" || systemType == "integer" || systemType == "smallint" || systemType == "int2" || systemType == "bigint" || systemType == "int8" || systemType == "tinyint" { return "INTEGER" } if systemType == "decimal" || systemType == "numeric" { return "DECIMAL" } if systemType == "float" || systemType == "real" || systemType == "double precision" { return "FLOAT" } if systemType == "bit" || systemType == "boolean" { return "BOOLEAN" } if systemType == "date" { return "DATE" } if systemType == "time" || systemType == "time without time zone" { return "TIME" } if systemType == "datetime" || systemType == "datetime2" || systemType == "timestamp" || systemType == "timestamptz" || systemType == "timestamp with time zone" { return "TIMESTAMP" } if systemType == "binary" || systemType == "varbinary" || systemType == "image" || systemType == "bytea" { return "BINARY" } if systemType == "uniqueidentifier" || systemType == "uuid" { return "UUID" } if systemType == "json" { return "JSON" } if systemType == "geometry" || systemType == "geography" { return "GEOMETRY" } return strings.ToUpper(systemType) } func (ta *MssqlTableAnalyzer) rawColumnToColumnType(rawColumn rawColumnMssql) models.ColumnType { const nullValue int64 = -1 stringTypes := map[string]bool{"varchar": true, "char": true, "nvarchar": true, "nchar": true, "text": true, "ntext": true} decimalTypes := map[string]bool{"decimal": true, "numeric": true} if stringTypes[rawColumn.systemType] { if rawColumn.systemType == "nvarchar" || rawColumn.systemType == "nchar" { if rawColumn.maxLength > 0 { rawColumn.maxLength = rawColumn.maxLength / 2 } } rawColumn.precision, rawColumn.scale = nullValue, nullValue } else if decimalTypes[rawColumn.systemType] { rawColumn.maxLength = nullValue } else { rawColumn.maxLength, rawColumn.precision, rawColumn.scale = nullValue, nullValue, nullValue } columnType := models.NewColumnType( rawColumn.name, rawColumn.maxLength != nullValue, rawColumn.precision != nullValue || rawColumn.scale != nullValue, rawColumn.userType, rawColumn.systemType, ta.systemTypeToUnifiedType(rawColumn.systemType), rawColumn.nullable, rawColumn.maxLength, rawColumn.precision, rawColumn.scale, ) return columnType } func (ta *MssqlTableAnalyzer) QueryColumnTypes( ctx context.Context, tableInfo config.TableInfo, ) ([]models.ColumnType, error) { localCtx, cancel := context.WithTimeout(ctx, 20*time.Second) defer cancel() rows, err := ta.db.Query(localCtx, mssqlColumnMetadataQuery, sql.Named("schema", tableInfo.Schema), sql.Named("table", tableInfo.Table)) if err != nil { return nil, err } defer rows.Close() var columnTypes []models.ColumnType for rows.Next() { var rawColumn rawColumnMssql if err := rows.Scan( &rawColumn.name, &rawColumn.userType, &rawColumn.systemType, &rawColumn.nullable, &rawColumn.maxLength, &rawColumn.precision, &rawColumn.scale, ); err != nil { return nil, err } columnTypes = append(columnTypes, ta.rawColumnToColumnType(rawColumn)) } return columnTypes, nil } func (ta *MssqlTableAnalyzer) EstimateTotalRows( ctx context.Context, tableInfo config.TableInfo, ) (int64, error) { query := ` SELECT SUM(p.rows) AS count FROM sys.tables t JOIN sys.schemas s ON t.schema_id = s.schema_id JOIN sys.partitions p ON t.object_id = p.object_id WHERE s.name = @schema AND t.name = @table AND p.index_id IN (0, 1) GROUP BY t.name` ctxTimeout, cancel := context.WithTimeout(ctx, time.Second*20) defer cancel() var rowsCount int64 err := ta.db.QueryRow(ctxTimeout, query, sql.Named("schema", tableInfo.Schema), sql.Named("table", tableInfo.Table)).Scan(&rowsCount) if err != nil { return 0, err } return rowsCount, nil } func (ta *MssqlTableAnalyzer) CalculatePartitionRanges( ctx context.Context, tableInfo config.TableInfo, partitionColumn string, maxPartitions int64, ) ([]models.Partition, error) { query := fmt.Sprintf(` SELECT MIN([%s]) AS lower_limit, MAX([%s]) AS upper_limit FROM (SELECT [%s], NTILE(@maxPartitions) OVER (ORDER BY [%s]) AS batch_id FROM [%s].[%s]) AS T GROUP BY batch_id ORDER BY batch_id`, partitionColumn, partitionColumn, partitionColumn, partitionColumn, tableInfo.Schema, tableInfo.Table) ctxTimeout, cancel := context.WithTimeout(ctx, time.Second*20) defer cancel() rows, err := ta.db.Query(ctxTimeout, query, sql.Named("maxPartitions", maxPartitions)) if err != nil { return nil, err } defer rows.Close() partitions := make([]models.Partition, 0, maxPartitions) for rows.Next() { partition := models.Partition{ Id: uuid.New(), HasRange: true, RetryCounter: 0, Range: models.PartitionRange{ IsMinInclusive: true, }, } if err := rows.Scan(&partition.Range.Min, &partition.Range.Max); err != nil { return nil, err } partitions = append(partitions, partition) } if err := rows.Err(); err != nil { return nil, err } return partitions, nil }