From ea41a7c218bbfe084e65052a704fa6cb5e0b9bb3 Mon Sep 17 00:00:00 2001 From: Kylesoda <249518290+kylesoda@users.noreply.github.com> Date: Wed, 15 Apr 2026 23:09:56 -0500 Subject: [PATCH] feat: register MSSQL and Postgres drivers in db-wrapper for improved factory pattern support --- internal/app/db-wrapper/main.go | 22 ++++++------ internal/app/db-wrapper/mssql.go | 54 +++++++++++++++++++++++++++-- internal/app/db-wrapper/postgres.go | 6 ++++ 3 files changed, 69 insertions(+), 13 deletions(-) diff --git a/internal/app/db-wrapper/main.go b/internal/app/db-wrapper/main.go index 8e1288c..74dbea8 100644 --- a/internal/app/db-wrapper/main.go +++ b/internal/app/db-wrapper/main.go @@ -2,18 +2,18 @@ package dbwrapper import "fmt" -const ( - postgresDialect = "postgres" - mssqlDialect = "sqlserver" -) +type Factory func() DbWrapper -func NewWrapper(driverType string) (DbWrapper, error) { - switch driverType { - case postgresDialect: - return &postgresDbWrapper{dialect: postgresDialect}, nil - case mssqlDialect: - return &mssqlDbWrapper{dialect: mssqlDialect}, nil - default: +var drivers = make(map[string]Factory) + +func Register(name string, factory Factory) { + drivers[name] = factory +} + +func New(driverType string) (DbWrapper, error) { + factory, ok := drivers[driverType] + if !ok { return nil, fmt.Errorf("driver not yet supported: %s", driverType) } + return factory(), nil } diff --git a/internal/app/db-wrapper/mssql.go b/internal/app/db-wrapper/mssql.go index 88f2b03..2716d93 100644 --- a/internal/app/db-wrapper/mssql.go +++ b/internal/app/db-wrapper/mssql.go @@ -3,10 +3,17 @@ package dbwrapper import ( "context" "database/sql" + "fmt" - _ "github.com/microsoft/go-mssqldb" + mssql "github.com/microsoft/go-mssqldb" ) +func init() { + Register("sqlserver", func() DbWrapper { + return &mssqlDbWrapper{dialect: "sqlserver"} + }) +} + type mssqlRowResult struct { columns []string rows *sql.Rows @@ -109,5 +116,48 @@ func (mw *mssqlDbWrapper) Query(ctx context.Context, query string, args ...any) } func (mw *mssqlDbWrapper) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) { - return 0, MethodNotSupported + tx, err := mw.db.BeginTx(ctx, nil) + if err != nil { + return 0, err + } + + fullTableName := fmt.Sprintf("[%s].[%s]", schema, table) + + stmt, err := tx.PrepareContext(ctx, mssql.CopyIn(fullTableName, mssql.BulkOptions{}, columnNames...)) + if err != nil { + tx.Rollback() + return 0, err + } + + for _, row := range rows { + _, err = stmt.ExecContext(ctx, row...) + if err != nil { + stmt.Close() + tx.Rollback() + return 0, err + } + } + + result, err := stmt.ExecContext(ctx) + if err != nil { + stmt.Close() + tx.Rollback() + return 0, err + } + + if err := stmt.Close(); err != nil { + tx.Rollback() + return 0, err + } + + if err := tx.Commit(); err != nil { + return 0, err + } + + rowsAffected, raErr := result.RowsAffected() + if raErr != nil { + return 0, nil + } + + return rowsAffected, nil } diff --git a/internal/app/db-wrapper/postgres.go b/internal/app/db-wrapper/postgres.go index 4013a64..6f2e220 100644 --- a/internal/app/db-wrapper/postgres.go +++ b/internal/app/db-wrapper/postgres.go @@ -8,6 +8,12 @@ import ( "github.com/jackc/pgx/v5/pgxpool" ) +func init() { + Register("postgres", func() DbWrapper { + return &postgresDbWrapper{dialect: "postgres"} + }) +} + type postgresRowResult struct { columns []string rows pgx.Rows