feat: enhance db-wrapper with improved MSSQL and Postgres implementations; add row result handling and dialect support

This commit is contained in:
2026-04-15 22:55:14 -05:00
parent 0384d5423f
commit f09284ecdc
4 changed files with 193 additions and 33 deletions

View File

@@ -2,12 +2,17 @@ package dbwrapper
import "fmt"
const (
postgresDialect = "postgres"
mssqlDialect = "sqlserver"
)
func NewWrapper(driverType string) (DbWrapper, error) {
switch driverType {
case "postgres":
return &postgresDbWrapper{}, nil
case "sqlserver":
return &mssqlDbWrapper{}, nil
case postgresDialect:
return &postgresDbWrapper{dialect: postgresDialect}, nil
case mssqlDialect:
return &mssqlDbWrapper{dialect: mssqlDialect}, nil
default:
return nil, fmt.Errorf("driver not yet supported: %s", driverType)
}

View File

@@ -3,18 +3,86 @@ package dbwrapper
import (
"context"
"database/sql"
_ "github.com/microsoft/go-mssqldb"
)
type mssqlRowResult struct {
columns []string
rows *sql.Rows
}
func (mr *mssqlRowResult) Close() error {
return mr.rows.Close()
}
func (mr *mssqlRowResult) Columns() ([]string, error) {
if mr.columns != nil {
return mr.columns, nil
}
return mr.rows.Columns()
}
func (mr *mssqlRowResult) Err() error {
return mr.rows.Err()
}
func (mr *mssqlRowResult) Next() bool {
return mr.rows.Next()
}
func (mr *mssqlRowResult) Scan(dest ...any) error {
return mr.rows.Scan(dest...)
}
func (mr *mssqlRowResult) Values() ([]any, error) {
columns, err := mr.Columns()
if err != nil {
return nil, err
}
rowValues := make([]any, len(columns))
scanArgs := make([]any, len(columns))
for i := range rowValues {
scanArgs[i] = &rowValues[i]
}
if err := mr.rows.Scan(scanArgs...); err != nil {
return nil, err
}
return rowValues, nil
}
type mssqlDbWrapper struct {
db *sql.DB
dialect string
}
func (wrapper *mssqlDbWrapper) Connect(ctx context.Context, dbUrl string) error { return nil }
func (mw *mssqlDbWrapper) Connect(ctx context.Context, dbUrl string) error {
db, err := sql.Open("sqlserver", dbUrl)
if err != nil {
return err
}
func (wrapper *mssqlDbWrapper) Close() error { return nil }
if err := db.PingContext(ctx); err != nil {
if err := db.Close(); err != nil {
return err
}
return err
}
func (wrapper *mssqlDbWrapper) Exec(ctx context.Context, query string, args ...any) (ExecResult, error) {
result, execErr := wrapper.db.ExecContext(ctx, query, args...)
mw.db = db
return nil
}
func (mw *mssqlDbWrapper) Close() error {
return mw.db.Close()
}
func (mw *mssqlDbWrapper) Exec(ctx context.Context, query string, args ...any) (ExecResult, error) {
result, execErr := mw.db.ExecContext(ctx, query, args...)
if execErr != nil {
return ExecResult{}, execErr
}
@@ -24,15 +92,22 @@ func (wrapper *mssqlDbWrapper) Exec(ctx context.Context, query string, args ...a
return ExecResult{}, err
}
return ExecResult{
AffectedRows: affectedRows,
}, nil
return ExecResult{AffectedRows: affectedRows}, nil
}
func (wrapper *mssqlDbWrapper) Query(ctx context.Context, query string, args ...any) (RowsResult, error) {
return nil, nil
func (mw *mssqlDbWrapper) GetDialect() string {
return mw.dialect
}
func (wrapper *mssqlDbWrapper) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) {
return 0, nil
func (mw *mssqlDbWrapper) Query(ctx context.Context, query string, args ...any) (RowsResult, error) {
rows, err := mw.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
return &mssqlRowResult{columns: nil, rows: rows}, nil
}
func (mw *mssqlDbWrapper) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) {
return 0, MethodNotSupported
}

View File

@@ -2,33 +2,108 @@ package dbwrapper
import (
"context"
"errors"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
type postgresDbWrapper struct {
db *pgxpool.Pool
type postgresRowResult struct {
columns []string
rows pgx.Rows
}
func (wrapper *postgresDbWrapper) Connect(ctx context.Context, dbUrl string) error { return nil }
func (pr *postgresRowResult) Close() error {
pr.rows.Close()
return nil
}
func (wrapper *postgresDbWrapper) Close() error { return nil }
func (pr *postgresRowResult) Columns() ([]string, error) {
if pr.columns != nil {
return pr.columns, nil
}
func (wrapper *postgresDbWrapper) Exec(ctx context.Context, query string, args ...any) (ExecResult, error) {
result, err := wrapper.db.Exec(ctx, query, args...)
rawColumns := pr.rows.FieldDescriptions()
if rawColumns == nil {
return nil, errors.New("error retrieving columns")
}
columns := make([]string, 0, len(rawColumns))
for _, rc := range rawColumns {
columns = append(columns, rc.Name)
}
return columns, nil
}
func (pr *postgresRowResult) Err() error {
return pr.rows.Err()
}
func (pr *postgresRowResult) Next() bool {
return pr.rows.Next()
}
func (pr *postgresRowResult) Scan(dest ...any) error {
return pr.rows.Scan(dest...)
}
func (pr *postgresRowResult) Values() ([]any, error) {
return pr.rows.Values()
}
type postgresDbWrapper struct {
db *pgxpool.Pool
dialect string
}
func (pw *postgresDbWrapper) Connect(ctx context.Context, dbUrl string) error {
pool, err := pgxpool.New(ctx, dbUrl)
if err != nil {
return err
}
if err := pool.Ping(ctx); err != nil {
pool.Close()
return err
}
pw.db = pool
return nil
}
func (pw *postgresDbWrapper) Close() error {
pw.db.Close()
return nil
}
func (pw *postgresDbWrapper) Exec(ctx context.Context, query string, args ...any) (ExecResult, error) {
result, err := pw.db.Exec(ctx, query, args...)
if err != nil {
return ExecResult{}, err
}
return ExecResult{
AffectedRows: result.RowsAffected(),
}, nil
return ExecResult{AffectedRows: result.RowsAffected()}, nil
}
func (wrapper *postgresDbWrapper) Query(ctx context.Context, query string, args ...any) (RowsResult, error) {
return nil, nil
func (pw *postgresDbWrapper) GetDialect() string {
return pw.dialect
}
func (wrapper *postgresDbWrapper) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) {
return 0, nil
func (pw *postgresDbWrapper) Query(ctx context.Context, query string, args ...any) (RowsResult, error) {
rows, err := pw.db.Query(ctx, query, args...)
if err != nil {
return nil, err
}
return &postgresRowResult{columns: nil, rows: rows}, nil
}
func (pw *postgresDbWrapper) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) {
affectedRows, err := pw.db.CopyFrom(ctx, pgx.Identifier{schema, table}, columnNames, pgx.CopyFromRows(rows))
if err != nil {
return 0, err
}
return affectedRows, nil
}

View File

@@ -2,24 +2,29 @@ package dbwrapper
import (
"context"
"errors"
)
var MethodNotSupported error = errors.New("Method not supported by driver... yet :P")
type ExecResult struct {
AffectedRows int64
}
type RowsResult interface {
Close()
Close() error
Columns() ([]string, error)
Err() error
Next() bool
Scan(dest ...any) error
Values() ([]any, error)
Columns() ([]string, error)
}
type DbWrapper interface {
Connect(ctx context.Context, dbUrl string) error
Close() error
Connect(ctx context.Context, dbUrl string) error
Exec(ctx context.Context, query string, args ...any) (ExecResult, error)
GetDialect() string
Query(ctx context.Context, query string, args ...any) (RowsResult, error)
SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error)
}