package extractors

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"slices"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
	"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
	dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper"
	"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
	"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"

	"github.com/google/uuid"
)

// MssqlExtractor reads partitioned primary-key ranges of a SQL Server table
// and streams them downstream as batches.
type MssqlExtractor struct {
	db dbwrapper.DbWrapper
}

// NewMssqlExtractor returns an etl.Extractor backed by the given connection wrapper.
func NewMssqlExtractor(db dbwrapper.DbWrapper) etl.Extractor {
	return &MssqlExtractor{db: db}
}

// buildExtractQueryMssql assembles a SELECT over the requested columns (or *
// when none are given), optionally bounded to a primary-key range via the
// @min/@max named parameters, and always ordered by the primary key. GEOMETRY
// columns are converted with STAsBinary() so they can be scanned as raw bytes.
func buildExtractQueryMssql(
	tableInfo config.SourceTableInfo,
	columns []models.ColumnType,
	includeRange bool,
	isMinInclusive bool,
) string {
	var sbQuery strings.Builder
	sbQuery.WriteString("SELECT ")
	if len(columns) == 0 {
		sbQuery.WriteString("*")
	} else {
		for i, col := range columns {
			fmt.Fprintf(&sbQuery, "[%s]", col.Name())
			if col.Type() == "GEOMETRY" {
				fmt.Fprintf(&sbQuery, ".STAsBinary() AS [%s]", col.Name())
			}
			if i < len(columns)-1 {
				sbQuery.WriteString(", ")
			}
		}
	}
	fmt.Fprintf(&sbQuery, " FROM [%s].[%s]", tableInfo.Schema, tableInfo.Table)
	if includeRange {
		fmt.Fprintf(&sbQuery, " WHERE [%s]", tableInfo.PrimaryKey)
		if isMinInclusive {
			sbQuery.WriteString(" >=")
		} else {
			sbQuery.WriteString(" >")
		}
		fmt.Fprintf(&sbQuery, " @min AND [%s] <= @max", tableInfo.PrimaryKey)
	}
	fmt.Fprintf(&sbQuery, " ORDER BY [%s] ASC", tableInfo.PrimaryKey)
	return sbQuery.String()
}
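
// For illustration (the identifiers below are made-up examples, not values
// from this codebase): with Schema "dbo", Table "orders", PrimaryKey "id", a
// GEOMETRY column "shape", includeRange true, and an inclusive lower bound,
// the builder produces:
//
//	SELECT [id], [shape].STAsBinary() AS [shape] FROM [dbo].[orders]
//	WHERE [id] >= @min AND [id] <= @max ORDER BY [id] ASC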

// Extract streams one partition of the table into chBatchesOut in batches of
// batchSize rows. It returns the number of rows read; on a mid-partition
// failure it flushes what it has and reports the last emitted primary key so
// the caller can resume past it.
func (mssqlEx *MssqlExtractor) Extract(
	ctx context.Context,
	tableInfo config.SourceTableInfo,
	columns []models.ColumnType,
	batchSize int,
	partition models.Partition,
	indexPrimaryKey int,
	chBatchesOut chan<- models.Batch,
) (int64, error) {
	query := buildExtractQueryMssql(tableInfo, columns, partition.HasRange, partition.Range.IsMinInclusive)
	var queryArgs []any
	if partition.HasRange {
		queryArgs = append(queryArgs, sql.Named("min", partition.Range.Min), sql.Named("max", partition.Range.Max))
	}
	rows, err := mssqlEx.db.Query(ctx, query, queryArgs...)
	if err != nil {
		return 0, err
	}
	defer rows.Close()

	batchRows := make([]models.UnknownRowValues, 0, batchSize)
	var rowsRead int64
	rowValues := make([]any, len(columns))
	scanArgs := make([]any, len(columns))
	for i := range rowValues {
		scanArgs[i] = &rowValues[i]
	}
	for rows.Next() {
		if err := rows.Scan(scanArgs...); err != nil {
			if len(batchRows) == 0 {
				return rowsRead, err
			}
			// Flush what was read so far, then report the last delivered
			// primary key so the retry path can resume after it.
			if flushErr := flush(ctx, &partition, batchSize, batchRows, chBatchesOut); flushErr != nil {
				return rowsRead, flushErr
			}
			lastRow := batchRows[len(batchRows)-1]
			return rowsRead, errorFromLastPartitionRow(lastRow, indexPrimaryKey, partition, err)
		}
		rowsRead++
		// Scan reuses rowValues on every iteration, so each batched row needs
		// its own copy; appending rowValues directly would alias every entry
		// in the batch to the most recently scanned row.
		rowCopy := make(models.UnknownRowValues, len(rowValues))
		copy(rowCopy, rowValues)
		batchRows = append(batchRows, rowCopy)
		if len(batchRows) >= batchSize {
			if err := flush(ctx, &partition, batchSize, batchRows, chBatchesOut); err != nil {
				return rowsRead, err
			}
			// Start a fresh slice rather than truncating: flush cannot reset
			// the caller's slice, and the flushed rows may still be referenced
			// by the batch sent on chBatchesOut.
			batchRows = make([]models.UnknownRowValues, 0, batchSize)
		}
	}
	// Flush any remainder smaller than a full batch.
	if len(batchRows) > 0 {
		if err := flush(ctx, &partition, batchSize, batchRows, chBatchesOut); err != nil {
			return rowsRead, err
		}
	}
	return rowsRead, rows.Err()
}

// ExtractWithRetries runs Extract and, on retryable extractor errors, retries
// up to three times, resuming from the last primary key that was delivered
// whenever the failed attempt reported one.
func (mssqlEx *MssqlExtractor) ExtractWithRetries(
	ctx context.Context,
	tableInfo config.SourceTableInfo,
	columns []models.ColumnType,
	batchSize int,
	partition models.Partition,
	indexPrimaryKey int,
	chBatchesOut chan<- models.Batch,
) (int64, error) {
	var totalRowsRead int64
	delay := time.Second
	currentPartition := partition
	for {
		rowsRead, err := mssqlEx.Extract(
			ctx,
			tableInfo,
			columns,
			batchSize,
			currentPartition,
			indexPrimaryKey,
			chBatchesOut,
		)
		totalRowsRead += rowsRead
		if err == nil {
			return totalRowsRead, nil
		}
		var exError *custom_errors.ExtractorError
		if errors.As(err, &exError) {
			currentPartition.RetryCounter++
			if currentPartition.RetryCounter > 3 {
				return totalRowsRead, &custom_errors.JobError{
					Msg:  fmt.Sprintf("Partition %v reached max retries", exError.Partition.Id),
					Prev: err,
				}
			}
			if exError.HasLastId {
				// Resume past the last row that was successfully delivered:
				// re-read the failed partition as a child partition whose
				// exclusive lower bound is the last emitted primary key.
				currentPartition.ParentId = exError.Partition.Id
				currentPartition.Id = uuid.New()
				currentPartition.Range.Min = exError.LastId
				currentPartition.Range.IsMinInclusive = false
			}
			time.Sleep(delay)
			continue
		}
		return totalRowsRead, err
	}
}

// Consume drains partitions from chPartitionsIn until the channel closes or
// the context is cancelled, extracting each one and reporting failures on
// chErrorsOut. It marks every processed partition done on wgActivePartitions
// and adds the rows it read to the shared counter.
func (mssqlEx *MssqlExtractor) Consume(
	ctx context.Context,
	tableInfo config.SourceTableInfo,
	columns []models.ColumnType,
	batchSize int,
	chPartitionsIn <-chan models.Partition,
	chBatchesOut chan<- models.Batch,
	chErrorsOut chan<- custom_errors.JobError,
	wgActivePartitions *sync.WaitGroup,
	rowsRead *int64,
) {
	indexPrimaryKey := slices.IndexFunc(columns, func(col models.ColumnType) bool {
		return strings.EqualFold(col.Name(), tableInfo.PrimaryKey)
	})
	if indexPrimaryKey == -1 {
		select {
		case <-ctx.Done():
			return
		case chErrorsOut <- custom_errors.JobError{
			ShouldCancelJob: true,
			Msg:             "Primary key not found in provided columns",
		}:
		}
		return
	}
	for {
		select {
		case <-ctx.Done():
			return
		case partition, ok := <-chPartitionsIn:
			if !ok {
				return
			}
			rowsReadResult, err := mssqlEx.ExtractWithRetries(
				ctx,
				tableInfo,
				columns,
				batchSize,
				partition,
				indexPrimaryKey,
				chBatchesOut,
			)
			wgActivePartitions.Done()
			if rowsReadResult > 0 {
				atomic.AddInt64(rowsRead, rowsReadResult)
			}
			if err == nil {
				continue
			}
			var jobError *custom_errors.JobError
			if errors.As(err, &jobError) {
				select {
				case <-ctx.Done():
					return
				case chErrorsOut <- *jobError:
				}
			} else {
				select {
				case <-ctx.Done():
					return
				case chErrorsOut <- custom_errors.JobError{ShouldCancelJob: false, Msg: err.Error(), Prev: err}:
				}
			}
		}
	}
}
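
// Example wiring for Consume (an illustrative sketch only: the channel
// capacities, the batch size of 1000, and the ctx/tableInfo/columns/partitions
// values are assumptions, not part of this package):
//
//	ex := &MssqlExtractor{db: db}
//	chPartitions := make(chan models.Partition)
//	chBatches := make(chan models.Batch, 8)
//	chErrors := make(chan custom_errors.JobError, 1)
//	var wgActive sync.WaitGroup
//	var rowsRead int64
//	go ex.Consume(ctx, tableInfo, columns, 1000,
//		chPartitions, chBatches, chErrors, &wgActive, &rowsRead)
//	for _, p := range partitions {
//		wgActive.Add(1) // Consume calls Done once per partition it finishes
//		chPartitions <- p
//	}
//	close(chPartitions)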