feat: register MSSQL and Postgres drivers in db-wrapper for improved factory pattern support

This commit is contained in:
2026-04-15 23:09:56 -05:00
parent f09284ecdc
commit ea41a7c218
3 changed files with 69 additions and 13 deletions

View File

@@ -2,18 +2,18 @@ package dbwrapper
import "fmt" import "fmt"
const ( type Factory func() DbWrapper
postgresDialect = "postgres"
mssqlDialect = "sqlserver"
)
func NewWrapper(driverType string) (DbWrapper, error) { var drivers = make(map[string]Factory)
switch driverType {
case postgresDialect: func Register(name string, factory Factory) {
return &postgresDbWrapper{dialect: postgresDialect}, nil drivers[name] = factory
case mssqlDialect: }
return &mssqlDbWrapper{dialect: mssqlDialect}, nil
default: func New(driverType string) (DbWrapper, error) {
factory, ok := drivers[driverType]
if !ok {
return nil, fmt.Errorf("driver not yet supported: %s", driverType) return nil, fmt.Errorf("driver not yet supported: %s", driverType)
} }
return factory(), nil
} }

View File

@@ -3,10 +3,17 @@ package dbwrapper
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
_ "github.com/microsoft/go-mssqldb" mssql "github.com/microsoft/go-mssqldb"
) )
func init() {
Register("sqlserver", func() DbWrapper {
return &mssqlDbWrapper{dialect: "sqlserver"}
})
}
type mssqlRowResult struct { type mssqlRowResult struct {
columns []string columns []string
rows *sql.Rows rows *sql.Rows
@@ -109,5 +116,48 @@ func (mw *mssqlDbWrapper) Query(ctx context.Context, query string, args ...any)
} }
func (mw *mssqlDbWrapper) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) { func (mw *mssqlDbWrapper) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) {
return 0, MethodNotSupported tx, err := mw.db.BeginTx(ctx, nil)
if err != nil {
return 0, err
}
fullTableName := fmt.Sprintf("[%s].[%s]", schema, table)
stmt, err := tx.PrepareContext(ctx, mssql.CopyIn(fullTableName, mssql.BulkOptions{}, columnNames...))
if err != nil {
tx.Rollback()
return 0, err
}
for _, row := range rows {
_, err = stmt.ExecContext(ctx, row...)
if err != nil {
stmt.Close()
tx.Rollback()
return 0, err
}
}
result, err := stmt.ExecContext(ctx)
if err != nil {
stmt.Close()
tx.Rollback()
return 0, err
}
if err := stmt.Close(); err != nil {
tx.Rollback()
return 0, err
}
if err := tx.Commit(); err != nil {
return 0, err
}
rowsAffected, raErr := result.RowsAffected()
if raErr != nil {
return 0, nil
}
return rowsAffected, nil
} }

View File

@@ -8,6 +8,12 @@ import (
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
) )
func init() {
Register("postgres", func() DbWrapper {
return &postgresDbWrapper{dialect: "postgres"}
})
}
type postgresRowResult struct { type postgresRowResult struct {
columns []string columns []string
rows pgx.Rows rows pgx.Rows