From f09284ecdc0f518ac643e4507f8563da1376934b Mon Sep 17 00:00:00 2001 From: Kylesoda <249518290+kylesoda@users.noreply.github.com> Date: Wed, 15 Apr 2026 22:55:14 -0500 Subject: [PATCH] feat: enhance db-wrapper with improved MSSQL and Postgres implementations; add row result handling and dialect support --- internal/app/db-wrapper/main.go | 13 ++-- internal/app/db-wrapper/mssql.go | 101 ++++++++++++++++++++++++---- internal/app/db-wrapper/postgres.go | 101 ++++++++++++++++++++++++---- internal/app/db-wrapper/types.go | 11 ++- 4 files changed, 193 insertions(+), 33 deletions(-) diff --git a/internal/app/db-wrapper/main.go b/internal/app/db-wrapper/main.go index e6e5e83..8e1288c 100644 --- a/internal/app/db-wrapper/main.go +++ b/internal/app/db-wrapper/main.go @@ -2,12 +2,17 @@ package dbwrapper import "fmt" +const ( + postgresDialect = "postgres" + mssqlDialect = "sqlserver" +) + func NewWrapper(driverType string) (DbWrapper, error) { switch driverType { - case "postgres": - return &postgresDbWrapper{}, nil - case "sqlserver": - return &mssqlDbWrapper{}, nil + case postgresDialect: + return &postgresDbWrapper{dialect: postgresDialect}, nil + case mssqlDialect: + return &mssqlDbWrapper{dialect: mssqlDialect}, nil default: return nil, fmt.Errorf("driver not yet supported: %s", driverType) } diff --git a/internal/app/db-wrapper/mssql.go b/internal/app/db-wrapper/mssql.go index 1682b17..88f2b03 100644 --- a/internal/app/db-wrapper/mssql.go +++ b/internal/app/db-wrapper/mssql.go @@ -3,18 +3,86 @@ package dbwrapper import ( "context" "database/sql" + + _ "github.com/microsoft/go-mssqldb" ) -type mssqlDbWrapper struct { - db *sql.DB +type mssqlRowResult struct { + columns []string + rows *sql.Rows } -func (wrapper *mssqlDbWrapper) Connect(ctx context.Context, dbUrl string) error { return nil } +func (mr *mssqlRowResult) Close() error { + return mr.rows.Close() +} -func (wrapper *mssqlDbWrapper) Close() error { return nil } +func (mr *mssqlRowResult) Columns() ([]string, error) { + if mr.columns != nil { + return mr.columns, nil + } -func (wrapper *mssqlDbWrapper) Exec(ctx context.Context, query string, args ...any) (ExecResult, error) { - result, execErr := wrapper.db.ExecContext(ctx, query, args...) + return mr.rows.Columns() +} + +func (mr *mssqlRowResult) Err() error { + return mr.rows.Err() +} + +func (mr *mssqlRowResult) Next() bool { + return mr.rows.Next() +} + +func (mr *mssqlRowResult) Scan(dest ...any) error { + return mr.rows.Scan(dest...) +} + +func (mr *mssqlRowResult) Values() ([]any, error) { + columns, err := mr.Columns() + if err != nil { + return nil, err + } + + rowValues := make([]any, len(columns)) + scanArgs := make([]any, len(columns)) + for i := range rowValues { + scanArgs[i] = &rowValues[i] + } + + if err := mr.rows.Scan(scanArgs...); err != nil { + return nil, err + } + + return rowValues, nil +} + +type mssqlDbWrapper struct { + db *sql.DB + dialect string +} + +func (mw *mssqlDbWrapper) Connect(ctx context.Context, dbUrl string) error { + db, err := sql.Open("sqlserver", dbUrl) + if err != nil { + return err + } + + if err := db.PingContext(ctx); err != nil { + if err := db.Close(); err != nil { + return err + } + return err + } + + mw.db = db + return nil +} + +func (mw *mssqlDbWrapper) Close() error { + return mw.db.Close() +} + +func (mw *mssqlDbWrapper) Exec(ctx context.Context, query string, args ...any) (ExecResult, error) { + result, execErr := mw.db.ExecContext(ctx, query, args...) if execErr != nil { return ExecResult{}, execErr } @@ -24,15 +92,22 @@ func (wrapper *mssqlDbWrapper) Exec(ctx context.Context, query string, args ...a return ExecResult{}, err } - return ExecResult{ - AffectedRows: affectedRows, - }, nil + return ExecResult{AffectedRows: affectedRows}, nil } -func (wrapper *mssqlDbWrapper) Query(ctx context.Context, query string, args ...any) (RowsResult, error) { - return nil, nil +func (mw *mssqlDbWrapper) GetDialect() string { + return mw.dialect } -func (wrapper *mssqlDbWrapper) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) { - return 0, nil +func (mw *mssqlDbWrapper) Query(ctx context.Context, query string, args ...any) (RowsResult, error) { + rows, err := mw.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + + return &mssqlRowResult{columns: nil, rows: rows}, nil +} + +func (mw *mssqlDbWrapper) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) { + return 0, MethodNotSupported } diff --git a/internal/app/db-wrapper/postgres.go b/internal/app/db-wrapper/postgres.go index ac99093..4013a64 100644 --- a/internal/app/db-wrapper/postgres.go +++ b/internal/app/db-wrapper/postgres.go @@ -2,33 +2,108 @@ package dbwrapper import ( "context" + "errors" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" ) -type postgresDbWrapper struct { - db *pgxpool.Pool +type postgresRowResult struct { + columns []string + rows pgx.Rows } -func (wrapper *postgresDbWrapper) Connect(ctx context.Context, dbUrl string) error { return nil } +func (pr *postgresRowResult) Close() error { + pr.rows.Close() + return nil +} -func (wrapper *postgresDbWrapper) Close() error { return nil } +func (pr *postgresRowResult) Columns() ([]string, error) { + if pr.columns != nil { + return pr.columns, nil + } -func (wrapper *postgresDbWrapper) Exec(ctx context.Context, query string, args ...any) (ExecResult, error) { - result, err := wrapper.db.Exec(ctx, query, args...) + rawColumns := pr.rows.FieldDescriptions() + if rawColumns == nil { + return nil, errors.New("error retrieving columns") + } + + columns := make([]string, 0, len(rawColumns)) + for _, rc := range rawColumns { + columns = append(columns, rc.Name) + } + + return columns, nil +} + +func (pr *postgresRowResult) Err() error { + return pr.rows.Err() +} + +func (pr *postgresRowResult) Next() bool { + return pr.rows.Next() +} + +func (pr *postgresRowResult) Scan(dest ...any) error { + return pr.rows.Scan(dest...) +} + +func (pr *postgresRowResult) Values() ([]any, error) { + return pr.rows.Values() +} + +type postgresDbWrapper struct { + db *pgxpool.Pool + dialect string +} + +func (pw *postgresDbWrapper) Connect(ctx context.Context, dbUrl string) error { + pool, err := pgxpool.New(ctx, dbUrl) + if err != nil { + return err + } + + if err := pool.Ping(ctx); err != nil { + pool.Close() + return err + } + + pw.db = pool + return nil +} + +func (pw *postgresDbWrapper) Close() error { + pw.db.Close() + return nil +} + +func (pw *postgresDbWrapper) Exec(ctx context.Context, query string, args ...any) (ExecResult, error) { + result, err := pw.db.Exec(ctx, query, args...) if err != nil { return ExecResult{}, err } - return ExecResult{ - AffectedRows: result.RowsAffected(), - }, nil + return ExecResult{AffectedRows: result.RowsAffected()}, nil } -func (wrapper *postgresDbWrapper) Query(ctx context.Context, query string, args ...any) (RowsResult, error) { - return nil, nil +func (pw *postgresDbWrapper) GetDialect() string { + return pw.dialect } -func (wrapper *postgresDbWrapper) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) { - return 0, nil +func (pw *postgresDbWrapper) Query(ctx context.Context, query string, args ...any) (RowsResult, error) { + rows, err := pw.db.Query(ctx, query, args...) + if err != nil { + return nil, err + } + + return &postgresRowResult{columns: nil, rows: rows}, nil +} + +func (pw *postgresDbWrapper) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) { + affectedRows, err := pw.db.CopyFrom(ctx, pgx.Identifier{schema, table}, columnNames, pgx.CopyFromRows(rows)) + if err != nil { + return 0, err + } + + return affectedRows, nil } diff --git a/internal/app/db-wrapper/types.go b/internal/app/db-wrapper/types.go index a10ca21..421294e 100644 --- a/internal/app/db-wrapper/types.go +++ b/internal/app/db-wrapper/types.go @@ -2,24 +2,29 @@ package dbwrapper import ( "context" + "errors" ) +var MethodNotSupported error = errors.New("Method not supported by driver... yet :P") + type ExecResult struct { AffectedRows int64 } type RowsResult interface { - Close() + Close() error + Columns() ([]string, error) Err() error Next() bool + Scan(dest ...any) error Values() ([]any, error) - Columns() ([]string, error) } type DbWrapper interface { - Connect(ctx context.Context, dbUrl string) error Close() error + Connect(ctx context.Context, dbUrl string) error Exec(ctx context.Context, query string, args ...any) (ExecResult, error) + GetDialect() string Query(ctx context.Context, query string, args ...any) (RowsResult, error) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) }