// Files
// go-migrate/internal/app/etl/extractors/mssql.go
//
// 127 lines
// 3.0 KiB
// Go

package extractors
import (
"context"
"database/sql"
"fmt"
"strings"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
)
// MssqlExtractor runs extraction queries against a Microsoft SQL Server
// source through a dbwrapper.DbWrapper. Instances are created with
// NewMssqlExtractor.
type MssqlExtractor struct {
// db is the database handle that Exec issues its query on.
db dbwrapper.DbWrapper
}
// NewMssqlExtractor wraps db in a MssqlExtractor and returns it as an
// etl.Extractor. The handle is stored as given; no connection check or query
// is performed here.
func NewMssqlExtractor(db dbwrapper.DbWrapper) etl.Extractor {
	extractor := new(MssqlExtractor)
	extractor.db = db
	return extractor
}
// buildExtractQueryMssql builds the T-SQL SELECT statement used to read one
// table, optionally restricted to a primary-key range.
//
// Parameters:
//   - tableInfo: supplies the schema, table and primary-key names.
//   - columns: the columns to select, in order. When empty the query selects
//     "*". A column whose Type() is exactly "GEOMETRY" is selected as
//     [col].STAsBinary() AS [col] so it is returned in binary form under its
//     original name.
//   - includeRange: when true, a WHERE clause bounds the primary key between
//     the named parameters @min and @max (the upper bound is always
//     inclusive).
//   - isMinInclusive: selects ">=" instead of ">" for the lower bound; only
//     relevant when includeRange is true.
//
// The result is always ordered by the primary key ascending. Every identifier
// is emitted as a bracket-delimited identifier with any embedded "]" doubled,
// so names containing a closing bracket cannot terminate the quoting early.
func buildExtractQueryMssql(
	tableInfo config.SourceTableInfo,
	columns []models.ColumnType,
	includeRange bool,
	isMinInclusive bool,
) string {
	// quote renders ident the same way the %s verb would and wraps it in
	// brackets, escaping "]" as "]]" per the T-SQL delimited-identifier rule.
	// It takes any so that the identifier fields keep compiling whatever
	// string-like type they are declared with.
	quote := func(ident any) string {
		return "[" + strings.ReplaceAll(fmt.Sprintf("%s", ident), "]", "]]") + "]"
	}

	var sbQuery strings.Builder
	sbQuery.WriteString("SELECT ")

	if len(columns) == 0 {
		sbQuery.WriteString("*")
	}
	for i, col := range columns {
		if i > 0 {
			sbQuery.WriteString(", ")
		}
		quotedName := quote(col.Name())
		sbQuery.WriteString(quotedName)
		if col.Type() == "GEOMETRY" {
			// Convert the spatial value to binary and alias it back to the
			// column's own name.
			sbQuery.WriteString(".STAsBinary() AS ")
			sbQuery.WriteString(quotedName)
		}
	}

	sbQuery.WriteString(" FROM ")
	sbQuery.WriteString(quote(tableInfo.Schema))
	sbQuery.WriteString(".")
	sbQuery.WriteString(quote(tableInfo.Table))

	quotedKey := quote(tableInfo.PrimaryKey)
	if includeRange {
		lowerOp := ">"
		if isMinInclusive {
			lowerOp = ">="
		}
		fmt.Fprintf(&sbQuery, " WHERE %s %s @min AND %s <= @max", quotedKey, lowerOp, quotedKey)
	}

	sbQuery.WriteString(" ORDER BY ")
	sbQuery.WriteString(quotedKey)
	sbQuery.WriteString(" ASC")

	return sbQuery.String()
}
// Exec reads the rows of one partition of a SQL Server table and hands them
// to flush in batches of up to batchSize rows.
//
// The query comes from buildExtractQueryMssql. When partition.HasRange is set,
// partition.Range.Min and partition.Range.Max are bound to the named
// parameters "min" and "max".
//
// Parameters:
//   - ctx: passed to the query and to every flush call.
//   - tableInfo, columns: describe what to select. The scan buffer is sized
//     from len(columns).
//     NOTE(review): with an empty columns slice the query is "SELECT *" while
//     the scan buffer has length zero — confirm callers always pass columns.
//   - batchSize: number of accumulated rows that triggers a flush.
//   - partition: the range to read; its address is passed to flush.
//   - indexPrimaryKey: forwarded to errorFromLastPartitionRow when a scan
//     fails part-way through a batch.
//   - chBatchesOut: forwarded to flush (presumably where batches are sent;
//     flush is defined outside this block — verify).
//
// It returns the number of rows scanned successfully and the first error
// encountered: from the query, from a scan, from flush, or from rows.Err()
// once iteration ends.
func (mssqlEx *MssqlExtractor) Exec(
	ctx context.Context,
	tableInfo config.SourceTableInfo,
	columns []models.ColumnType,
	batchSize int,
	partition models.Partition,
	indexPrimaryKey int,
	chBatchesOut chan<- models.Batch,
) (int64, error) {
	query := buildExtractQueryMssql(tableInfo, columns, partition.HasRange, partition.Range.IsMinInclusive)

	var queryArgs []any
	if partition.HasRange {
		queryArgs = append(queryArgs, sql.Named("min", partition.Range.Min), sql.Named("max", partition.Range.Max))
	}

	rows, err := mssqlEx.db.Query(ctx, query, queryArgs...)
	if err != nil {
		return 0, err
	}
	defer rows.Close()

	batchRows := make([]models.UnknownRowValues, 0, batchSize)
	var rowsRead int64

	// rowValues is a scratch buffer that Scan overwrites on every row;
	// scanArgs holds the pointers into it. Each row is copied out of the
	// buffer before it is stored.
	rowValues := make([]any, len(columns))
	scanArgs := make([]any, len(columns))
	for i := range rowValues {
		scanArgs[i] = &rowValues[i]
	}

	for rows.Next() {
		if err := rows.Scan(scanArgs...); err != nil {
			// Nothing pending: report the scan error as-is.
			if len(batchRows) == 0 {
				return rowsRead, err
			}
			// Hand off the rows read so far, then report the failure
			// together with the last row that was read successfully.
			if err := flush(ctx, &partition, batchSize, batchRows, chBatchesOut); err != nil {
				return rowsRead, err
			}
			lastRow := batchRows[len(batchRows)-1]
			return rowsRead, errorFromLastPartitionRow(lastRow, indexPrimaryKey, partition, err)
		}
		rowsRead++

		// Store a copy: appending rowValues itself would make every entry in
		// the batch share one backing array, so all rows would show the
		// values of whichever row was scanned last.
		row := make([]any, len(rowValues))
		copy(row, rowValues)
		batchRows = append(batchRows, row)

		if len(batchRows) >= batchSize {
			if err := flush(ctx, &partition, batchSize, batchRows, chBatchesOut); err != nil {
				return rowsRead, err
			}
			// flush receives the slice by value and cannot shorten ours, so
			// start a new batch here; otherwise every later row would flush
			// the already-delivered rows again. A fresh backing array is
			// used instead of batchRows[:0] because the flushed batch may
			// still be referenced by whoever received it.
			batchRows = make([]models.UnknownRowValues, 0, batchSize)
		}
	}

	// Flush whatever is left. This call is also reached with an empty batch
	// (for example when the query returns no rows).
	if err := flush(ctx, &partition, batchSize, batchRows, chBatchesOut); err != nil {
		return rowsRead, err
	}

	return rowsRead, rows.Err()
}