package extractors import ( "context" "fmt" "strings" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors" dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models" "github.com/google/uuid" ) type PostgresExtractor struct { db dbwrapper.DbWrapper } func NewPostgresExtractor(db dbwrapper.DbWrapper) etl.Extractor { return &PostgresExtractor{db: db} } func buildExtractQueryPostgres( sourceDbInfo config.SourceTableInfo, columns []models.ColumnType, includeRange bool, isMinInclusive bool, isMaxInclusive bool, hasMin bool, hasMax bool, ) string { var sbColumns strings.Builder if len(columns) == 0 { sbColumns.WriteString("*") } else { for i, col := range columns { if col.Type() == "GEOMETRY" { sbColumns.WriteString(`ST_AsEWKB("`) sbColumns.WriteString(col.Name()) sbColumns.WriteString(`") AS "`) sbColumns.WriteString(col.Name()) sbColumns.WriteString(`"`) } else { sbColumns.WriteString(`"`) sbColumns.WriteString(col.Name()) sbColumns.WriteString(`"`) } if i < len(columns)-1 { sbColumns.WriteString(", ") } } } query := fmt.Sprintf(`SELECT %s FROM "%s"."%s"`, sbColumns.String(), sourceDbInfo.Schema, sourceDbInfo.Table) if includeRange && (hasMin || hasMax) { query += " WHERE " paramIdx := 1 if hasMin { query += fmt.Sprintf(`"%s"`, sourceDbInfo.PrimaryKey) if isMinInclusive { query += " >=" } else { query += " >" } query += fmt.Sprintf(" $%d", paramIdx) paramIdx++ } if hasMin && hasMax { query += " AND " } if hasMax { query += fmt.Sprintf(`"%s"`, sourceDbInfo.PrimaryKey) if isMaxInclusive { query += " <=" } else { query += " <" } query += fmt.Sprintf(" $%d", paramIdx) } } query += fmt.Sprintf(` ORDER BY "%s" ASC`, sourceDbInfo.PrimaryKey) return query } func (postgresEx *PostgresExtractor) Exec( ctx context.Context, tableInfo config.SourceTableInfo, columns []models.ColumnType, batchSize int, partition models.Partition, indexPrimaryKey int, chBatchesOut chan<- models.Batch, ) (int, error) { hasMin := partition.HasRange && partition.Range.Min > 0 hasMax := partition.HasRange && partition.Range.Max > 0 query := buildExtractQueryPostgres(tableInfo, columns, partition.HasRange, partition.Range.IsMinInclusive, partition.Range.IsMaxInclusive, hasMin, hasMax) var queryArgs []any if hasMin { queryArgs = append(queryArgs, partition.Range.Min) } if hasMax { queryArgs = append(queryArgs, partition.Range.Max) } rowsRead := 0 rows, err := postgresEx.db.Query(ctx, query, queryArgs...) if err != nil { return rowsRead, &custom_errors.ExtractorError{Partition: partition, HasLastId: false, Msg: err.Error()} } defer rows.Close() batchRows := make([]models.UnknownRowValues, 0, batchSize) for rows.Next() { values, err := rows.Values() if err != nil { return rowsRead, &custom_errors.ExtractorError{Partition: partition, HasLastId: false, Msg: err.Error()} } rowsRead++ batchRows = append(batchRows, values) if len(batchRows) >= batchSize { select { case chBatchesOut <- models.Batch{Id: uuid.New(), PartitionId: partition.Id, Rows: batchRows, RetryCounter: 0}: case <-ctx.Done(): return rowsRead, ctx.Err() } batchRows = make([]models.UnknownRowValues, 0, batchSize) } } if err := rows.Err(); err != nil { return rowsRead, &custom_errors.ExtractorError{Partition: partition, HasLastId: false, Msg: err.Error()} } if len(batchRows) > 0 { select { case chBatchesOut <- models.Batch{Id: uuid.New(), PartitionId: partition.Id, Rows: batchRows, RetryCounter: 0}: case <-ctx.Done(): return rowsRead, nil } } return rowsRead, nil }