package extractors import ( "context" "database/sql" "errors" "fmt" "strings" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/convert" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors" dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models" "github.com/google/uuid" ) type MssqlExtractor struct { db dbwrapper.DbWrapper } func NewMssqlExtractor(db dbwrapper.DbWrapper) etl.Extractor { return &MssqlExtractor{db: db} } func buildExtractQueryMssql( tableInfo config.SourceTableInfo, columns []models.ColumnType, includeRange bool, isMinInclusive bool, isMaxInclusive bool, ) string { var sbQuery strings.Builder sbQuery.WriteString("SELECT ") if len(columns) == 0 { sbQuery.WriteString("*") } else { for i, col := range columns { fmt.Fprintf(&sbQuery, "[%s]", col.Name()) if col.Type() == "GEOMETRY" { fmt.Fprintf(&sbQuery, ".STAsBinary() AS [%s]", col.Name()) } if i < len(columns)-1 { sbQuery.WriteString(", ") } } } fmt.Fprintf(&sbQuery, " FROM [%s].[%s]", tableInfo.Schema, tableInfo.Table) if includeRange { fmt.Fprintf(&sbQuery, " WHERE [%s]", tableInfo.PrimaryKey) if isMinInclusive { sbQuery.WriteString(" >=") } else { sbQuery.WriteString(" >") } sbQuery.WriteString(" @min AND ") fmt.Fprintf(&sbQuery, "[%s]", tableInfo.PrimaryKey) if isMaxInclusive { sbQuery.WriteString(" <=") } else { sbQuery.WriteString(" <") } sbQuery.WriteString(" @max") } fmt.Fprintf(&sbQuery, " ORDER BY [%s] ASC", tableInfo.PrimaryKey) return sbQuery.String() } func errorFromLastRow( lastRow models.UnknownRowValues, indexPrimaryKey int, partition models.Partition, previousError error, ) *custom_errors.ExtractorError { lastIdRawValue := lastRow[indexPrimaryKey] lastId, ok := convert.ToInt64(lastIdRawValue) if !ok { currentPartition := partition currentPartition.RetryCounter = 3 return &custom_errors.ExtractorError{ Partition: currentPartition, HasLastId: true, Msg: fmt.Sprintf("Couldn't cast last id value as int: %s", previousError.Error()), } } return &custom_errors.ExtractorError{ Partition: partition, HasLastId: true, LastId: lastId, Msg: previousError.Error(), } } func (mssqlEx *MssqlExtractor) Exec( ctx context.Context, tableInfo config.SourceTableInfo, columns []models.ColumnType, batchSize int, partition models.Partition, indexPrimaryKey int, chBatchesOut chan<- models.Batch, ) (int, error) { query := buildExtractQueryMssql(tableInfo, columns, partition.HasRange, partition.Range.IsMinInclusive, partition.Range.IsMaxInclusive) var queryArgs []any if partition.HasRange { queryArgs = append(queryArgs, sql.Named("min", partition.Range.Min), sql.Named("max", partition.Range.Max), ) } rowsRead := 0 rows, err := mssqlEx.db.Query(ctx, query, queryArgs...) if err != nil { return rowsRead, &custom_errors.ExtractorError{Partition: partition, HasLastId: false, Msg: err.Error()} } defer rows.Close() batchRows := make([]models.UnknownRowValues, 0, batchSize) for rows.Next() { rowValues := make([]any, len(columns)) scanArgs := make([]any, len(columns)) for i := range rowValues { scanArgs[i] = &rowValues[i] } if err := rows.Scan(scanArgs...); err != nil { if len(batchRows) == 0 { return rowsRead, &custom_errors.ExtractorError{Partition: partition, HasLastId: false, Msg: err.Error()} } lastRow := batchRows[len(batchRows)-1] select { case chBatchesOut <- models.Batch{Id: uuid.New(), PartitionId: partition.Id, Rows: batchRows, RetryCounter: 0}: case <-ctx.Done(): return rowsRead, ctx.Err() } return rowsRead, errorFromLastRow(lastRow, indexPrimaryKey, partition, err) } rowsRead++ batchRows = append(batchRows, rowValues) if len(batchRows) >= batchSize { select { case chBatchesOut <- models.Batch{Id: uuid.New(), PartitionId: partition.Id, Rows: batchRows, RetryCounter: 0}: case <-ctx.Done(): return rowsRead, ctx.Err() } batchRows = make([]models.UnknownRowValues, 0, batchSize) } } if err := rows.Err(); err != nil { if errors.Is(err, ctx.Err()) { return rowsRead, ctx.Err() } if len(batchRows) > 0 { lastRow := batchRows[len(batchRows)-1] return rowsRead, errorFromLastRow(lastRow, indexPrimaryKey, partition, err) } return rowsRead, &custom_errors.ExtractorError{Partition: partition, HasLastId: false, Msg: err.Error()} } if len(batchRows) > 0 { select { case chBatchesOut <- models.Batch{Id: uuid.New(), PartitionId: partition.Id, Rows: batchRows, RetryCounter: 0}: case <-ctx.Done(): return rowsRead, ctx.Err() } } return rowsRead, nil }