feat: enhance db-wrapper with improved MSSQL and Postgres implementations; add row result handling and dialect support

This commit is contained in:
2026-04-15 22:55:14 -05:00
parent 0384d5423f
commit f09284ecdc
4 changed files with 193 additions and 33 deletions

View File

@@ -2,12 +2,17 @@ package dbwrapper
import "fmt" import "fmt"
const (
postgresDialect = "postgres"
mssqlDialect = "sqlserver"
)
func NewWrapper(driverType string) (DbWrapper, error) { func NewWrapper(driverType string) (DbWrapper, error) {
switch driverType { switch driverType {
case "postgres": case postgresDialect:
return &postgresDbWrapper{}, nil return &postgresDbWrapper{dialect: postgresDialect}, nil
case "sqlserver": case mssqlDialect:
return &mssqlDbWrapper{}, nil return &mssqlDbWrapper{dialect: mssqlDialect}, nil
default: default:
return nil, fmt.Errorf("driver not yet supported: %s", driverType) return nil, fmt.Errorf("driver not yet supported: %s", driverType)
} }

View File

@@ -3,18 +3,86 @@ package dbwrapper
import ( import (
"context" "context"
"database/sql" "database/sql"
_ "github.com/microsoft/go-mssqldb"
) )
type mssqlDbWrapper struct { type mssqlRowResult struct {
db *sql.DB columns []string
rows *sql.Rows
} }
func (wrapper *mssqlDbWrapper) Connect(ctx context.Context, dbUrl string) error { return nil } func (mr *mssqlRowResult) Close() error {
return mr.rows.Close()
}
func (wrapper *mssqlDbWrapper) Close() error { return nil } func (mr *mssqlRowResult) Columns() ([]string, error) {
if mr.columns != nil {
return mr.columns, nil
}
func (wrapper *mssqlDbWrapper) Exec(ctx context.Context, query string, args ...any) (ExecResult, error) { return mr.rows.Columns()
result, execErr := wrapper.db.ExecContext(ctx, query, args...) }
func (mr *mssqlRowResult) Err() error {
return mr.rows.Err()
}
func (mr *mssqlRowResult) Next() bool {
return mr.rows.Next()
}
func (mr *mssqlRowResult) Scan(dest ...any) error {
return mr.rows.Scan(dest...)
}
func (mr *mssqlRowResult) Values() ([]any, error) {
columns, err := mr.Columns()
if err != nil {
return nil, err
}
rowValues := make([]any, len(columns))
scanArgs := make([]any, len(columns))
for i := range rowValues {
scanArgs[i] = &rowValues[i]
}
if err := mr.rows.Scan(scanArgs...); err != nil {
return nil, err
}
return rowValues, nil
}
type mssqlDbWrapper struct {
db *sql.DB
dialect string
}
func (mw *mssqlDbWrapper) Connect(ctx context.Context, dbUrl string) error {
db, err := sql.Open("sqlserver", dbUrl)
if err != nil {
return err
}
if err := db.PingContext(ctx); err != nil {
if err := db.Close(); err != nil {
return err
}
return err
}
mw.db = db
return nil
}
func (mw *mssqlDbWrapper) Close() error {
return mw.db.Close()
}
func (mw *mssqlDbWrapper) Exec(ctx context.Context, query string, args ...any) (ExecResult, error) {
result, execErr := mw.db.ExecContext(ctx, query, args...)
if execErr != nil { if execErr != nil {
return ExecResult{}, execErr return ExecResult{}, execErr
} }
@@ -24,15 +92,22 @@ func (wrapper *mssqlDbWrapper) Exec(ctx context.Context, query string, args ...a
return ExecResult{}, err return ExecResult{}, err
} }
return ExecResult{ return ExecResult{AffectedRows: affectedRows}, nil
AffectedRows: affectedRows,
}, nil
} }
func (wrapper *mssqlDbWrapper) Query(ctx context.Context, query string, args ...any) (RowsResult, error) { func (mw *mssqlDbWrapper) GetDialect() string {
return nil, nil return mw.dialect
} }
func (wrapper *mssqlDbWrapper) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) { func (mw *mssqlDbWrapper) Query(ctx context.Context, query string, args ...any) (RowsResult, error) {
return 0, nil rows, err := mw.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
return &mssqlRowResult{columns: nil, rows: rows}, nil
}
func (mw *mssqlDbWrapper) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) {
return 0, MethodNotSupported
} }

View File

@@ -2,33 +2,108 @@ package dbwrapper
import ( import (
"context" "context"
"errors"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
) )
type postgresDbWrapper struct { type postgresRowResult struct {
db *pgxpool.Pool columns []string
rows pgx.Rows
} }
func (wrapper *postgresDbWrapper) Connect(ctx context.Context, dbUrl string) error { return nil } func (pr *postgresRowResult) Close() error {
pr.rows.Close()
return nil
}
func (wrapper *postgresDbWrapper) Close() error { return nil } func (pr *postgresRowResult) Columns() ([]string, error) {
if pr.columns != nil {
return pr.columns, nil
}
func (wrapper *postgresDbWrapper) Exec(ctx context.Context, query string, args ...any) (ExecResult, error) { rawColumns := pr.rows.FieldDescriptions()
result, err := wrapper.db.Exec(ctx, query, args...) if rawColumns == nil {
return nil, errors.New("error retrieving columns")
}
columns := make([]string, 0, len(rawColumns))
for _, rc := range rawColumns {
columns = append(columns, rc.Name)
}
return columns, nil
}
func (pr *postgresRowResult) Err() error {
return pr.rows.Err()
}
func (pr *postgresRowResult) Next() bool {
return pr.rows.Next()
}
func (pr *postgresRowResult) Scan(dest ...any) error {
return pr.rows.Scan(dest...)
}
func (pr *postgresRowResult) Values() ([]any, error) {
return pr.rows.Values()
}
type postgresDbWrapper struct {
db *pgxpool.Pool
dialect string
}
func (pw *postgresDbWrapper) Connect(ctx context.Context, dbUrl string) error {
pool, err := pgxpool.New(ctx, dbUrl)
if err != nil {
return err
}
if err := pool.Ping(ctx); err != nil {
pool.Close()
return err
}
pw.db = pool
return nil
}
func (pw *postgresDbWrapper) Close() error {
pw.db.Close()
return nil
}
func (pw *postgresDbWrapper) Exec(ctx context.Context, query string, args ...any) (ExecResult, error) {
result, err := pw.db.Exec(ctx, query, args...)
if err != nil { if err != nil {
return ExecResult{}, err return ExecResult{}, err
} }
return ExecResult{ return ExecResult{AffectedRows: result.RowsAffected()}, nil
AffectedRows: result.RowsAffected(),
}, nil
} }
func (wrapper *postgresDbWrapper) Query(ctx context.Context, query string, args ...any) (RowsResult, error) { func (pw *postgresDbWrapper) GetDialect() string {
return nil, nil return pw.dialect
} }
func (wrapper *postgresDbWrapper) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) { func (pw *postgresDbWrapper) Query(ctx context.Context, query string, args ...any) (RowsResult, error) {
return 0, nil rows, err := pw.db.Query(ctx, query, args...)
if err != nil {
return nil, err
}
return &postgresRowResult{columns: nil, rows: rows}, nil
}
func (pw *postgresDbWrapper) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) {
affectedRows, err := pw.db.CopyFrom(ctx, pgx.Identifier{schema, table}, columnNames, pgx.CopyFromRows(rows))
if err != nil {
return 0, err
}
return affectedRows, nil
} }

View File

@@ -2,24 +2,29 @@ package dbwrapper
import ( import (
"context" "context"
"errors"
) )
var MethodNotSupported error = errors.New("Method not supported by driver... yet :P")
type ExecResult struct { type ExecResult struct {
AffectedRows int64 AffectedRows int64
} }
type RowsResult interface { type RowsResult interface {
Close() Close() error
Columns() ([]string, error)
Err() error Err() error
Next() bool Next() bool
Scan(dest ...any) error
Values() ([]any, error) Values() ([]any, error)
Columns() ([]string, error)
} }
type DbWrapper interface { type DbWrapper interface {
Connect(ctx context.Context, dbUrl string) error
Close() error Close() error
Connect(ctx context.Context, dbUrl string) error
Exec(ctx context.Context, query string, args ...any) (ExecResult, error) Exec(ctx context.Context, query string, args ...any) (ExecResult, error)
GetDialect() string
Query(ctx context.Context, query string, args ...any) (RowsResult, error) Query(ctx context.Context, query string, args ...any) (RowsResult, error)
SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error)
} }