package dbwrapper import ( "context" "database/sql" "fmt" mssql "github.com/microsoft/go-mssqldb" ) func init() { Register("sqlserver", func() DbWrapper { return &mssqlDbWrapper{dialect: "sqlserver"} }) } type mssqlRowResult struct { row *sql.Row } func (mr *mssqlRowResult) Scan(dest ...any) error { return mr.row.Scan(dest...) } type mssqlRowsResult struct { columns []string rows *sql.Rows } func (mr *mssqlRowsResult) Close() error { return mr.rows.Close() } func (mr *mssqlRowsResult) Columns() ([]string, error) { if mr.columns != nil { return mr.columns, nil } return mr.rows.Columns() } func (mr *mssqlRowsResult) Err() error { return mr.rows.Err() } func (mr *mssqlRowsResult) Next() bool { return mr.rows.Next() } func (mr *mssqlRowsResult) Scan(dest ...any) error { return mr.rows.Scan(dest...) } func (mr *mssqlRowsResult) Values() ([]any, error) { columns, err := mr.Columns() if err != nil { return nil, err } rowValues := make([]any, len(columns)) scanArgs := make([]any, len(columns)) for i := range rowValues { scanArgs[i] = &rowValues[i] } if err := mr.rows.Scan(scanArgs...); err != nil { return nil, err } return rowValues, nil } type mssqlDbWrapper struct { db *sql.DB dialect string } func (mw *mssqlDbWrapper) Connect(ctx context.Context, dbUrl string) error { db, err := sql.Open("sqlserver", dbUrl) if err != nil { return err } if err := db.PingContext(ctx); err != nil { if err := db.Close(); err != nil { return err } return err } mw.db = db return nil } func (mw *mssqlDbWrapper) Close() error { return mw.db.Close() } func (mw *mssqlDbWrapper) Exec(ctx context.Context, query string, args ...any) (ExecResult, error) { result, execErr := mw.db.ExecContext(ctx, query, args...) if execErr != nil { return ExecResult{}, execErr } affectedRows, err := result.RowsAffected() if err != nil { return ExecResult{}, err } return ExecResult{AffectedRows: affectedRows}, nil } func (mw *mssqlDbWrapper) GetDialect() string { return mw.dialect } func (mw *mssqlDbWrapper) Query(ctx context.Context, query string, args ...any) (RowsResult, error) { rows, err := mw.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } return &mssqlRowsResult{columns: nil, rows: rows}, nil } func (mw *mssqlDbWrapper) QueryRow(ctx context.Context, query string, args ...any) RowResult { row := mw.db.QueryRowContext(ctx, query, args...) return &mssqlRowResult{row: row} } func (mw *mssqlDbWrapper) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) { tx, err := mw.db.BeginTx(ctx, nil) if err != nil { return 0, err } fullTableName := fmt.Sprintf("[%s].[%s]", schema, table) stmt, err := tx.PrepareContext(ctx, mssql.CopyIn(fullTableName, mssql.BulkOptions{}, columnNames...)) if err != nil { tx.Rollback() return 0, err } for _, row := range rows { _, err = stmt.ExecContext(ctx, row...) if err != nil { stmt.Close() tx.Rollback() return 0, err } } result, err := stmt.ExecContext(ctx) if err != nil { stmt.Close() tx.Rollback() return 0, err } if err := stmt.Close(); err != nil { tx.Rollback() return 0, err } if err := tx.Commit(); err != nil { return 0, err } rowsAffected, raErr := result.RowsAffected() if raErr != nil { return 0, nil } return rowsAffected, nil }