package dbwrapper import ( "context" "database/sql" "fmt" "strings" dbdialects "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper/db_dialects" mssql "github.com/microsoft/go-mssqldb" ) func init() { Register(dbdialects.SqlServer, func() DbWrapper { return &mssqlDbWrapper{dialect: dbdialects.SqlServer} }) } type mssqlRowResult struct { row *sql.Row } func (mr *mssqlRowResult) Scan(dest ...any) error { return mr.row.Scan(dest...) } type mssqlRowsResult struct { columns []string rows *sql.Rows } func (mr *mssqlRowsResult) Close() error { return mr.rows.Close() } func (mr *mssqlRowsResult) Columns() ([]string, error) { if mr.columns != nil { return mr.columns, nil } return mr.rows.Columns() } func (mr *mssqlRowsResult) Err() error { return mr.rows.Err() } func (mr *mssqlRowsResult) Next() bool { return mr.rows.Next() } func (mr *mssqlRowsResult) Scan(dest ...any) error { return mr.rows.Scan(dest...) } func (mr *mssqlRowsResult) Values() ([]any, error) { columns, err := mr.Columns() if err != nil { return nil, err } rowValues := make([]any, len(columns)) scanArgs := make([]any, len(columns)) for i := range rowValues { scanArgs[i] = &rowValues[i] } if err := mr.rows.Scan(scanArgs...); err != nil { return nil, err } return rowValues, nil } type mssqlDbWrapper struct { db *sql.DB dialect string } func (mw *mssqlDbWrapper) Connect(ctx context.Context, dbUrl string) error { db, err := sql.Open("sqlserver", dbUrl) if err != nil { return err } if err := db.PingContext(ctx); err != nil { if err := db.Close(); err != nil { return err } return err } mw.db = db return nil } func (mw *mssqlDbWrapper) Close() error { return mw.db.Close() } func (mw *mssqlDbWrapper) Exec(ctx context.Context, query string, args ...any) (ExecResult, error) { result, execErr := mw.db.ExecContext(ctx, query, args...) if execErr != nil { return ExecResult{}, execErr } affectedRows, err := result.RowsAffected() if err != nil { return ExecResult{}, err } return ExecResult{AffectedRows: affectedRows}, nil } func (mw *mssqlDbWrapper) GetDialect() string { return mw.dialect } func (mw *mssqlDbWrapper) Query(ctx context.Context, query string, args ...any) (RowsResult, error) { rows, err := mw.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } return &mssqlRowsResult{columns: nil, rows: rows}, nil } func (mw *mssqlDbWrapper) QueryRow(ctx context.Context, query string, args ...any) RowResult { row := mw.db.QueryRowContext(ctx, query, args...) return &mssqlRowResult{row: row} } func (mw *mssqlDbWrapper) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) { tx, err := mw.db.BeginTx(ctx, nil) if err != nil { return 0, err } fullTableName := fmt.Sprintf("[%s].[%s]", schema, table) stmt, err := tx.PrepareContext(ctx, mssql.CopyIn(fullTableName, mssql.BulkOptions{}, columnNames...)) if err != nil { tx.Rollback() return 0, err } for _, row := range rows { _, err = stmt.ExecContext(ctx, row...) if err != nil { stmt.Close() tx.Rollback() return 0, err } } result, err := stmt.ExecContext(ctx) if err != nil { stmt.Close() tx.Rollback() return 0, err } if err := stmt.Close(); err != nil { tx.Rollback() return 0, err } if err := tx.Commit(); err != nil { return 0, err } rowsAffected, raErr := result.RowsAffected() if raErr != nil { return 0, nil } return rowsAffected, nil } func (mw *mssqlDbWrapper) QueryFromObject(ctx context.Context, q ExtractionQuery) (RowsResult, error) { var sbQuery strings.Builder sbQuery.WriteString("SELECT ") if len(q.Columns) == 0 { sbQuery.WriteString("*") } else { for i, col := range q.Columns { fmt.Fprintf(&sbQuery, "[%s]", col.Name()) switch col.Type() { case "GEOMETRY": fmt.Fprintf(&sbQuery, ".STAsBinary() AS [%s]", col.Name()) } if i < len(q.Columns)-1 { sbQuery.WriteString(", ") } } } fmt.Fprintf(&sbQuery, " FROM [%s].[%s]", q.Schema, q.Table) if q.LowerLimit.IsValid || q.UpperLimit.IsValid { sbQuery.WriteString(" WHERE ") if q.LowerLimit.IsValid { fmt.Fprintf(&sbQuery, "[%s]", q.PrimaryKey) if q.LowerLimit.IsInclusive { sbQuery.WriteString(" >=") } else { sbQuery.WriteString(" >") } sbQuery.WriteString(" @min") } if q.LowerLimit.IsValid && q.UpperLimit.IsValid { sbQuery.WriteString(" AND ") } if q.UpperLimit.IsValid { fmt.Fprintf(&sbQuery, "[%s]", q.PrimaryKey) if q.UpperLimit.IsInclusive { sbQuery.WriteString(" <=") } else { sbQuery.WriteString(" <") } sbQuery.WriteString(" @max") } } fmt.Fprintf(&sbQuery, " ORDER BY [%s] ASC", q.PrimaryKey) queryString := sbQuery.String() var queryArgs []any if q.LowerLimit.IsValid { queryArgs = append(queryArgs, sql.Named("min", q.LowerLimit.Value)) } if q.UpperLimit.IsValid { queryArgs = append(queryArgs, sql.Named("max", q.UpperLimit.Value)) } return mw.Query(ctx, queryString, queryArgs...) }