feat: register MSSQL and Postgres drivers in db-wrapper for improved factory pattern support
This commit is contained in:
@@ -2,18 +2,18 @@ package dbwrapper
|
||||
|
||||
import "fmt"
|
||||
|
||||
const (
|
||||
postgresDialect = "postgres"
|
||||
mssqlDialect = "sqlserver"
|
||||
)
|
||||
type Factory func() DbWrapper
|
||||
|
||||
func NewWrapper(driverType string) (DbWrapper, error) {
|
||||
switch driverType {
|
||||
case postgresDialect:
|
||||
return &postgresDbWrapper{dialect: postgresDialect}, nil
|
||||
case mssqlDialect:
|
||||
return &mssqlDbWrapper{dialect: mssqlDialect}, nil
|
||||
default:
|
||||
var drivers = make(map[string]Factory)
|
||||
|
||||
func Register(name string, factory Factory) {
|
||||
drivers[name] = factory
|
||||
}
|
||||
|
||||
func New(driverType string) (DbWrapper, error) {
|
||||
factory, ok := drivers[driverType]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("driver not yet supported: %s", driverType)
|
||||
}
|
||||
return factory(), nil
|
||||
}
|
||||
|
||||
@@ -3,10 +3,17 @@ package dbwrapper
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/microsoft/go-mssqldb"
|
||||
mssql "github.com/microsoft/go-mssqldb"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Register("sqlserver", func() DbWrapper {
|
||||
return &mssqlDbWrapper{dialect: "sqlserver"}
|
||||
})
|
||||
}
|
||||
|
||||
type mssqlRowResult struct {
|
||||
columns []string
|
||||
rows *sql.Rows
|
||||
@@ -109,5 +116,48 @@ func (mw *mssqlDbWrapper) Query(ctx context.Context, query string, args ...any)
|
||||
}
|
||||
|
||||
func (mw *mssqlDbWrapper) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) {
|
||||
return 0, MethodNotSupported
|
||||
tx, err := mw.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
fullTableName := fmt.Sprintf("[%s].[%s]", schema, table)
|
||||
|
||||
stmt, err := tx.PrepareContext(ctx, mssql.CopyIn(fullTableName, mssql.BulkOptions{}, columnNames...))
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
_, err = stmt.ExecContext(ctx, row...)
|
||||
if err != nil {
|
||||
stmt.Close()
|
||||
tx.Rollback()
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
result, err := stmt.ExecContext(ctx)
|
||||
if err != nil {
|
||||
stmt.Close()
|
||||
tx.Rollback()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if err := stmt.Close(); err != nil {
|
||||
tx.Rollback()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
rowsAffected, raErr := result.RowsAffected()
|
||||
if raErr != nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
return rowsAffected, nil
|
||||
}
|
||||
|
||||
@@ -8,6 +8,12 @@ import (
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Register("postgres", func() DbWrapper {
|
||||
return &postgresDbWrapper{dialect: "postgres"}
|
||||
})
|
||||
}
|
||||
|
||||
type postgresRowResult struct {
|
||||
columns []string
|
||||
rows pgx.Rows
|
||||
|
||||
Reference in New Issue
Block a user