package extractors

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"slices"
	"strings"
	"sync"
	"sync/atomic"

	"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
	"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/convert"
	"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors"
	"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl"
	"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
	"github.com/google/uuid"
)

// MssqlExtractor reads rows from a SQL Server database through db and emits
// them as chunks.
type MssqlExtractor struct {
	db *sql.DB
}

// NewMssqlExtractor returns an etl.Extractor backed by the given connection
// pool.
func NewMssqlExtractor(db *sql.DB) etl.Extractor {
	return &MssqlExtractor{db: db}
}

// quoteIdentMssql wraps name in square brackets, doubling any embedded "]" so
// the result is always a single delimited identifier (the same escaping rule
// T-SQL's QUOTENAME applies).
func quoteIdentMssql(name string) string {
	return "[" + strings.ReplaceAll(name, "]", "]]") + "]"
}

// buildExtractQueryMssql builds the SELECT statement used to read one
// partition of tableInfo.
//
// Every entry of columns is selected by name; columns whose Type() is
// "GEOMETRY" are selected through .STAsBinary() and aliased back to their own
// name. An empty columns slice selects "*". When includeRange is true the
// query is restricted to "@min < pk <= @max" (or "@min <= pk" when
// isMinInclusive is true), where @min and @max are named parameters the
// caller must supply. Rows are always ordered by the primary key ascending.
func buildExtractQueryMssql(
	tableInfo config.SourceTableInfo,
	columns []models.ColumnType,
	includeRange bool,
	isMinInclusive bool,
) string {
	primaryKey := quoteIdentMssql(tableInfo.PrimaryKey)

	var sbQuery strings.Builder
	sbQuery.WriteString("SELECT ")

	if len(columns) == 0 {
		sbQuery.WriteString("*")
	}
	for i, col := range columns {
		if i > 0 {
			sbQuery.WriteString(", ")
		}
		name := quoteIdentMssql(col.Name())
		sbQuery.WriteString(name)
		if col.Type() == "GEOMETRY" {
			sbQuery.WriteString(".STAsBinary() AS ")
			sbQuery.WriteString(name)
		}
	}

	fmt.Fprintf(&sbQuery, " FROM %s.%s", quoteIdentMssql(tableInfo.Schema), quoteIdentMssql(tableInfo.Table))

	if includeRange {
		lowerOp := ">"
		if isMinInclusive {
			lowerOp = ">="
		}
		fmt.Fprintf(&sbQuery, " WHERE %s %s @min AND %s <= @max", primaryKey, lowerOp, primaryKey)
	}

	fmt.Fprintf(&sbQuery, " ORDER BY %s ASC", primaryKey)

	return sbQuery.String()
}

// extractorErrorFromLastRowMssql converts previousError into an
// ExtractorError that records the primary-key value of lastRow, the last row
// read before the failure.
//
// If the primary-key value cannot be converted to int64 there is no usable
// resume point: the returned error has HasLastId false and carries a copy of
// the batch with RetryCounter set to 3.
func extractorErrorFromLastRowMssql(
	lastRow models.UnknownRowValues,
	indexPrimaryKey int,
	batch *models.Partition,
	previousError error,
) *custom_errors.ExtractorError {
	lastId, ok := convert.ToInt64(lastRow[indexPrimaryKey])
	if !ok {
		currentBatch := *batch
		// NOTE(review): 3 presumably equals the consumer's retry limit so the
		// batch is not retried — confirm against the error handler.
		currentBatch.RetryCounter = 3

		return &custom_errors.ExtractorError{
			Batch: currentBatch,
			// No id could be derived, so do not advertise one (LastId stays zero).
			HasLastId: false,
			Msg:       fmt.Sprintf("Couldn't cast last id value as int: %s", previousError.Error()),
		}
	}

	return &custom_errors.ExtractorError{
		Batch:     *batch,
		HasLastId: true,
		LastId:    lastId,
		Msg:       previousError.Error(),
	}
}

// ProcessBatch runs the extract query for one partition and streams the
// result to chChunksOut in chunks of at most chunkSize rows, adding the number
// of rows sent to *rowsRead after each successful send.
//
// Return values:
//   - nil when the partition was read completely, or when ctx was cancelled
//     while waiting to send a chunk;
//   - ctx.Err() when row iteration failed because of ctx;
//   - *custom_errors.ExtractorError for query, scan or iteration failures.
//     When at least one row was read before the failure, every row read so
//     far is sent first and the error carries the last row's primary key, so
//     the reported id never runs ahead of (or behind) what was emitted.
func (mssqlEx *MssqlExtractor) ProcessBatch(
	ctx context.Context,
	tableInfo config.SourceTableInfo,
	columns []models.ColumnType,
	chunkSize int,
	batch models.Partition,
	indexPrimaryKey int,
	chChunksOut chan<- models.Batch,
	rowsRead *int64,
) error {
	query := buildExtractQueryMssql(tableInfo, columns, batch.ShouldUseRange, batch.IsLowerLimitInclusive)

	var queryArgs []any
	if batch.ShouldUseRange {
		queryArgs = append(queryArgs,
			sql.Named("min", batch.LowerLimit),
			sql.Named("max", batch.UpperLimit),
		)
	}

	rows, err := mssqlEx.db.QueryContext(ctx, query, queryArgs...)
	if err != nil {
		return &custom_errors.ExtractorError{Batch: batch, HasLastId: false, Msg: err.Error()}
	}
	defer rows.Close()

	// sendChunk hands chunk to the consumer and updates the row counter.
	// It reports false when ctx was cancelled before the send went through.
	sendChunk := func(chunk []models.UnknownRowValues) bool {
		select {
		case chChunksOut <- models.Batch{Id: uuid.New(), PartitionId: batch.Id, Data: chunk, RetryCounter: 0}:
			atomic.AddInt64(rowsRead, int64(len(chunk)))
			return true
		case <-ctx.Done():
			return false
		}
	}

	// fail turns cause into the error returned to the caller. Rows still
	// buffered in pending are flushed first so that the last id reported in
	// the error matches what the consumer has actually received.
	fail := func(pending []models.UnknownRowValues, lastRow models.UnknownRowValues, cause error) error {
		if lastRow == nil {
			// Nothing was read from this partition yet.
			return &custom_errors.ExtractorError{Batch: batch, HasLastId: false, Msg: cause.Error()}
		}
		if len(pending) > 0 && !sendChunk(pending) {
			return nil
		}
		return extractorErrorFromLastRowMssql(lastRow, indexPrimaryKey, &batch, cause)
	}

	rowsChunk := make([]models.UnknownRowValues, 0, chunkSize)
	// lastRow is the most recently scanned row; it survives chunk flushes so a
	// failure right after a flush can still report where reading stopped.
	var lastRow models.UnknownRowValues

	for rows.Next() {
		values := make([]any, len(columns))
		scanArgs := make([]any, len(columns))
		for i := range values {
			scanArgs[i] = &values[i]
		}

		if err := rows.Scan(scanArgs...); err != nil {
			return fail(rowsChunk, lastRow, err)
		}

		rowsChunk = append(rowsChunk, values)
		lastRow = values

		if len(rowsChunk) >= chunkSize {
			if !sendChunk(rowsChunk) {
				return nil
			}
			// The sent slice now belongs to the consumer; start a fresh one.
			rowsChunk = make([]models.UnknownRowValues, 0, chunkSize)
		}
	}

	if err := rows.Err(); err != nil {
		if errors.Is(err, ctx.Err()) {
			return ctx.Err()
		}
		return fail(rowsChunk, lastRow, err)
	}

	if len(rowsChunk) > 0 {
		// Cancellation during the final send is reported as nil, same as above.
		sendChunk(rowsChunk)
	}

	return nil
}

// Exec is the worker loop of the extractor: it receives partitions from
// chBatchesIn and processes each with ProcessBatch until ctx is cancelled or
// chBatchesIn is closed.
//
// If tableInfo.PrimaryKey does not match (case-insensitively) any entry of
// columns, a JobError with ShouldCancelJob set is sent to chJobErrorsOut and
// Exec returns. After a partition is processed without error,
// wgActiveBatches.Done() is called once. When ProcessBatch fails, the error is
// forwarded to chErrorsOut (ExtractorError) and/or chJobErrorsOut (JobError)
// and Exec returns.
func (mssqlEx *MssqlExtractor) Exec(
	ctx context.Context,
	tableInfo config.SourceTableInfo,
	columns []models.ColumnType,
	chunkSize int,
	chBatchesIn <-chan models.Partition,
	chChunksOut chan<- models.Batch,
	chErrorsOut chan<- custom_errors.ExtractorError,
	chJobErrorsOut chan<- custom_errors.JobError,
	wgActiveBatches *sync.WaitGroup,
	rowsRead *int64,
) {
	indexPrimaryKey := slices.IndexFunc(columns, func(col models.ColumnType) bool {
		return strings.EqualFold(col.Name(), tableInfo.PrimaryKey)
	})

	if indexPrimaryKey == -1 {
		select {
		case <-ctx.Done():
			return
		case chJobErrorsOut <- custom_errors.JobError{
			ShouldCancelJob: true,
			Msg:             "Primary key not found in provided columns",
		}:
		}
		return
	}

	for {
		// Checked up front so cancellation wins even when a batch is also ready.
		if ctx.Err() != nil {
			return
		}

		select {
		case <-ctx.Done():
			return

		case batch, ok := <-chBatchesIn:
			if !ok {
				return
			}

			err := mssqlEx.ProcessBatch(
				ctx,
				tableInfo,
				columns,
				chunkSize,
				batch,
				indexPrimaryKey,
				chChunksOut,
				rowsRead,
			)
			if err != nil {
				var exError *custom_errors.ExtractorError
				if errors.As(err, &exError) {
					select {
					case <-ctx.Done():
						return
					case chErrorsOut <- *exError:
					}
				}

				var jobError *custom_errors.JobError
				if errors.As(err, &jobError) {
					select {
					case <-ctx.Done():
						return
					case chJobErrorsOut <- *jobError:
					}
				}

				// NOTE(review): the worker stops after any batch error and does
				// not call wgActiveBatches.Done() for the failed batch —
				// presumably the error consumer re-queues it; confirm this is
				// intended and that enough workers remain to drain retries.
				return
			}

			wgActiveBatches.Done()
		}
	}
}