129 lines
2.6 KiB
Go
129 lines
2.6 KiB
Go
package dbwrapper
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
)
|
|
|
|
func init() {
|
|
Register("postgres", func() DbWrapper {
|
|
return &postgresDbWrapper{dialect: "postgres"}
|
|
})
|
|
}
|
|
|
|
type postgresRowResult struct {
|
|
row pgx.Row
|
|
}
|
|
|
|
func (pr *postgresRowResult) Scan(dest ...any) error {
|
|
return pr.row.Scan(dest...)
|
|
}
|
|
|
|
type postgresRowsResult struct {
|
|
columns []string
|
|
rows pgx.Rows
|
|
}
|
|
|
|
func (pr *postgresRowsResult) Close() error {
|
|
pr.rows.Close()
|
|
return nil
|
|
}
|
|
|
|
func (pr *postgresRowsResult) Columns() ([]string, error) {
|
|
if pr.columns != nil {
|
|
return pr.columns, nil
|
|
}
|
|
|
|
rawColumns := pr.rows.FieldDescriptions()
|
|
if rawColumns == nil {
|
|
return nil, errors.New("error retrieving columns")
|
|
}
|
|
|
|
columns := make([]string, 0, len(rawColumns))
|
|
for _, rc := range rawColumns {
|
|
columns = append(columns, rc.Name)
|
|
}
|
|
|
|
return columns, nil
|
|
}
|
|
|
|
func (pr *postgresRowsResult) Err() error {
|
|
return pr.rows.Err()
|
|
}
|
|
|
|
func (pr *postgresRowsResult) Next() bool {
|
|
return pr.rows.Next()
|
|
}
|
|
|
|
func (pr *postgresRowsResult) Scan(dest ...any) error {
|
|
return pr.rows.Scan(dest...)
|
|
}
|
|
|
|
func (pr *postgresRowsResult) Values() ([]any, error) {
|
|
return pr.rows.Values()
|
|
}
|
|
|
|
type postgresDbWrapper struct {
|
|
db *pgxpool.Pool
|
|
dialect string
|
|
}
|
|
|
|
func (pw *postgresDbWrapper) Connect(ctx context.Context, dbUrl string) error {
|
|
pool, err := pgxpool.New(ctx, dbUrl)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := pool.Ping(ctx); err != nil {
|
|
pool.Close()
|
|
return err
|
|
}
|
|
|
|
pw.db = pool
|
|
return nil
|
|
}
|
|
|
|
func (pw *postgresDbWrapper) Close() error {
|
|
pw.db.Close()
|
|
return nil
|
|
}
|
|
|
|
func (pw *postgresDbWrapper) Exec(ctx context.Context, query string, args ...any) (ExecResult, error) {
|
|
result, err := pw.db.Exec(ctx, query, args...)
|
|
if err != nil {
|
|
return ExecResult{}, err
|
|
}
|
|
|
|
return ExecResult{AffectedRows: result.RowsAffected()}, nil
|
|
}
|
|
|
|
func (pw *postgresDbWrapper) GetDialect() string {
|
|
return pw.dialect
|
|
}
|
|
|
|
func (pw *postgresDbWrapper) Query(ctx context.Context, query string, args ...any) (RowsResult, error) {
|
|
rows, err := pw.db.Query(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &postgresRowsResult{columns: nil, rows: rows}, nil
|
|
}
|
|
|
|
func (pw *postgresDbWrapper) QueryRow(ctx context.Context, query string, args ...any) RowResult {
|
|
row := pw.db.QueryRow(ctx, query, args...)
|
|
return &postgresRowResult{row: row}
|
|
}
|
|
|
|
func (pw *postgresDbWrapper) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) {
|
|
affectedRows, err := pw.db.CopyFrom(ctx, pgx.Identifier{schema, table}, columnNames, pgx.CopyFromRows(rows))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return affectedRows, nil
|
|
}
|