package dbwrapper

import (
	"context"
	"database/sql"
	"fmt"
	"strings"

	"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config"
	dbdialects "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper/db_dialects"
	"git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models"
	mssql "github.com/microsoft/go-mssqldb"
	"github.com/sirupsen/logrus"
)

// init registers the SQL Server wrapper factory under the SqlServer dialect.
func init() {
	Register(dbdialects.SqlServer, func() DbWrapper {
		return &mssqlDbWrapper{dialect: dbdialects.SqlServer}
	})
}

// mssqlRowResult adapts a single *sql.Row to the RowResult interface.
type mssqlRowResult struct {
	row *sql.Row
}

// Scan copies the columns of the row into dest.
func (mr *mssqlRowResult) Scan(dest ...any) error {
	return mr.row.Scan(dest...)
}

// mssqlRowsResult adapts *sql.Rows to the RowsResult interface.
type mssqlRowsResult struct {
	// columns caches the column names once they have been fetched, so that
	// per-row helpers such as Values do not query them again for every row.
	columns []string
	rows    *sql.Rows
}

// Close releases the underlying result set.
func (mr *mssqlRowsResult) Close() error {
	return mr.rows.Close()
}

// Columns returns the column names of the result set. The first successful
// lookup is cached and reused by later calls.
func (mr *mssqlRowsResult) Columns() ([]string, error) {
	if mr.columns != nil {
		return mr.columns, nil
	}
	columns, err := mr.rows.Columns()
	if err != nil {
		return nil, err
	}
	mr.columns = columns
	return columns, nil
}

// Err returns the error, if any, reported by the underlying result set.
func (mr *mssqlRowsResult) Err() error {
	return mr.rows.Err()
}

// Next advances to the next row and reports whether one is available.
func (mr *mssqlRowsResult) Next() bool {
	return mr.rows.Next()
}

// Scan copies the columns of the current row into dest.
func (mr *mssqlRowsResult) Scan(dest ...any) error {
	return mr.rows.Scan(dest...)
}

// Values scans the current row into a freshly allocated []any holding one
// element per column.
func (mr *mssqlRowsResult) Values() ([]any, error) {
	columns, err := mr.Columns()
	if err != nil {
		return nil, err
	}

	rowValues := make([]any, len(columns))
	scanArgs := make([]any, len(columns))
	for i := range rowValues {
		// Scan needs pointers; point each argument at its slot in rowValues.
		scanArgs[i] = &rowValues[i]
	}

	if err := mr.rows.Scan(scanArgs...); err != nil {
		return nil, err
	}
	return rowValues, nil
}

// mssqlDbWrapper is the SQL Server implementation registered in init.
type mssqlDbWrapper struct {
	db      *sql.DB
	dialect string
}

// Connect opens a "sqlserver" database handle for dbUrl and verifies it with
// a ping. The handle is stored on the wrapper only when the ping succeeds.
//
// When the ping fails the handle is closed again and the ping error is
// returned; if closing fails as well, the close failure is appended to the
// message while the ping error remains the wrapped cause.
func (mw *mssqlDbWrapper) Connect(ctx context.Context, dbUrl string) error {
	db, err := sql.Open("sqlserver", dbUrl)
	if err != nil {
		return err
	}

	if pingErr := db.PingContext(ctx); pingErr != nil {
		if closeErr := db.Close(); closeErr != nil {
			return fmt.Errorf("%w (closing the connection also failed: %v)", pingErr, closeErr)
		}
		return pingErr
	}

	mw.db = db
	return nil
}

// Close closes the database handle. It is a no-op when no handle has been
// stored by Connect.
func (mw *mssqlDbWrapper) Close() error {
	if mw.db == nil {
		return nil
	}
	return mw.db.Close()
}

// Exec runs a statement that returns no rows and reports the number of
// affected rows.
func (mw *mssqlDbWrapper) Exec(ctx context.Context, query string, args ...any) (ExecResult, error) {
	result, execErr := mw.db.ExecContext(ctx, query, args...)
	if execErr != nil {
		return ExecResult{}, execErr
	}

	affectedRows, err := result.RowsAffected()
	if err != nil {
		return ExecResult{}, err
	}
	return ExecResult{AffectedRows: affectedRows}, nil
}

// GetDialect returns the dialect identifier this wrapper was created with.
func (mw *mssqlDbWrapper) GetDialect() string {
	return mw.dialect
}

// Query runs a statement that returns rows. The caller must close the result.
func (mw *mssqlDbWrapper) Query(ctx context.Context, query string, args ...any) (RowsResult, error) {
	rows, err := mw.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	return &mssqlRowsResult{rows: rows}, nil
}

// QueryRow runs a statement expected to return at most one row. Any error is
// deferred until Scan is called on the result.
func (mw *mssqlDbWrapper) QueryRow(ctx context.Context, query string, args ...any) RowResult {
	return &mssqlRowResult{row: mw.db.QueryRowContext(ctx, query, args...)}
}

// SaveMassive bulk-inserts rows into schema.table inside a single transaction
// using the driver's CopyIn statement. Each element of rows must supply its
// values in the order given by columnNames.
//
// It returns the row count reported by the final, argument-less Exec. Any
// failure before the commit rolls the transaction back.
func (mw *mssqlDbWrapper) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) {
	tx, err := mw.db.BeginTx(ctx, nil)
	if err != nil {
		return 0, err
	}
	// Covers every early return below. Once Commit has been attempted the
	// transaction is finished and Rollback only returns sql.ErrTxDone, which
	// is deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	fullTableName := quoteIdentMssql(schema) + "." + quoteIdentMssql(table)
	stmt, err := tx.PrepareContext(ctx, mssql.CopyIn(fullTableName, mssql.BulkOptions{}, columnNames...))
	if err != nil {
		return 0, err
	}

	for _, row := range rows {
		if _, err := stmt.ExecContext(ctx, row...); err != nil {
			_ = stmt.Close() // best effort; the Exec error is the one to report
			return 0, err
		}
	}

	// An Exec without arguments finishes the bulk copy.
	result, err := stmt.ExecContext(ctx)
	if err != nil {
		_ = stmt.Close() // best effort; the Exec error is the one to report
		return 0, err
	}

	if err := stmt.Close(); err != nil {
		return 0, err
	}
	if err := tx.Commit(); err != nil {
		return 0, err
	}

	rowsAffected, raErr := result.RowsAffected()
	if raErr != nil {
		// NOTE(review): the data is already committed at this point, so a
		// missing row count is reported as 0 without an error. This looks
		// deliberate — confirm with callers that rely on the count.
		return 0, nil
	}
	return rowsAffected, nil
}

// quoteIdentMssql wraps name in square brackets, doubling every closing
// bracket so that the name cannot end the delimited identifier early.
func quoteIdentMssql(name string) string {
	return "[" + strings.ReplaceAll(name, "]", "]]") + "]"
}

// buildExtractQueryMssql renders the SELECT statement for q.
//
// Select list:
//   - a column matched by an entry of q.FromJsonColumns is replaced by one
//     JSON_VALUE expression per matching entry, aliased to the column name;
//   - a column whose type is "GEOMETRY" is selected through STAsBinary();
//   - every other column is selected as is;
//   - "*" is used when q lists neither columns nor JSON columns.
//
// Range filters on q.PrimaryKey reference the named parameters @min and @max,
// and the result is always ordered by q.PrimaryKey ascending.
//
// An error is returned when a JSON column pattern matches no column.
func buildExtractQueryMssql(q ExtractionQuery) (string, error) {
	// Resolve each JSON pattern to a concrete column name first, so that an
	// unmatched pattern fails before any SQL is assembled.
	resolvedJson := make(map[string][]config.FromJsonItem, len(q.FromJsonColumns))
	for _, jsonConfig := range q.FromJsonColumns {
		actualColumnName, err := findColumnByPattern(q.Columns, jsonConfig.Column)
		if err != nil {
			return "", err
		}
		resolvedJson[actualColumnName] = append(resolvedJson[actualColumnName], jsonConfig)
	}

	selectParts := make([]string, 0, len(q.Columns)+len(q.FromJsonColumns))
	switch {
	case len(q.Columns) > 0:
		for _, col := range q.Columns {
			ident := quoteIdentMssql(col.Name())

			if jsonConfigs, isJsonColumn := resolvedJson[col.Name()]; isJsonColumn {
				// NOTE(review): several JSON fields taken from the same
				// column all receive the same alias — confirm this is the
				// intended output shape.
				for _, jsonConfig := range jsonConfigs {
					// The path sits inside a single-quoted literal, so
					// embedded quotes have to be doubled.
					jsonPath := strings.ReplaceAll(buildJsonPathMssql(jsonConfig.Field), "'", "''")
					selectParts = append(selectParts, fmt.Sprintf("JSON_VALUE(%s, '%s') AS %s", ident, jsonPath, ident))
				}
				continue
			}

			if col.Type() == "GEOMETRY" {
				selectParts = append(selectParts, fmt.Sprintf("%s.STAsBinary() AS %s", ident, ident))
				continue
			}
			selectParts = append(selectParts, ident)
		}
	case len(q.FromJsonColumns) == 0:
		selectParts = append(selectParts, "*")
	}

	primaryKey := quoteIdentMssql(q.PrimaryKey)

	var conditions []string
	if q.LowerLimit.IsValid {
		op := ">"
		if q.LowerLimit.IsInclusive {
			op = ">="
		}
		conditions = append(conditions, primaryKey+" "+op+" @min")
	}
	if q.UpperLimit.IsValid {
		op := "<"
		if q.UpperLimit.IsInclusive {
			op = "<="
		}
		conditions = append(conditions, primaryKey+" "+op+" @max")
	}

	var sbQuery strings.Builder
	sbQuery.WriteString("SELECT ")
	sbQuery.WriteString(strings.Join(selectParts, ", "))
	sbQuery.WriteString(" FROM " + quoteIdentMssql(q.Schema) + "." + quoteIdentMssql(q.Table))
	if len(conditions) > 0 {
		sbQuery.WriteString(" WHERE ")
		sbQuery.WriteString(strings.Join(conditions, " AND "))
	}
	sbQuery.WriteString(" ORDER BY " + primaryKey + " ASC")

	return sbQuery.String(), nil
}

// findColumnByPattern returns the name of the first column matching pattern.
//
// A pattern ending in "*" matches by prefix (the text before the "*"); any
// other pattern must equal the column name exactly. An empty pattern, or a
// pattern that matches no column, yields an error.
func findColumnByPattern(columns []models.ColumnType, pattern string) (string, error) {
	if pattern == "" {
		return "", fmt.Errorf("column pattern cannot be empty")
	}

	if prefix, ok := strings.CutSuffix(pattern, "*"); ok {
		for _, col := range columns {
			if strings.HasPrefix(col.Name(), prefix) {
				return col.Name(), nil
			}
		}
		return "", fmt.Errorf("no column found matching pattern '%s'", pattern)
	}

	for _, col := range columns {
		if col.Name() == pattern {
			return col.Name(), nil
		}
	}
	return "", fmt.Errorf("column '%s' not found in table columns", pattern)
}

// QueryFromObject builds the extraction query for q, logs it at debug level
// and runs it, binding q.LowerLimit.Value to @min and q.UpperLimit.Value to
// @max for whichever limits are valid.
func (mw *mssqlDbWrapper) QueryFromObject(ctx context.Context, q ExtractionQuery) (RowsResult, error) {
	queryString, err := buildExtractQueryMssql(q)
	if err != nil {
		return nil, err
	}
	logrus.Debugf("Query: %s", queryString)

	var queryArgs []any
	if q.LowerLimit.IsValid {
		queryArgs = append(queryArgs, sql.Named("min", q.LowerLimit.Value))
	}
	if q.UpperLimit.IsValid {
		queryArgs = append(queryArgs, sql.Named("max", q.UpperLimit.Value))
	}

	return mw.Query(ctx, queryString, queryArgs...)
}

// buildJsonPathMssql turns field into a JSON path rooted at "$". A single
// leading "." on field is dropped so that "a.b" and ".a.b" both give "$.a.b".
func buildJsonPathMssql(field string) string {
	return "$." + strings.TrimPrefix(field, ".")
}