204 lines
4.3 KiB
Go
204 lines
4.3 KiB
Go
package dbwrapper
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
dbdialects "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper/db-dialects"
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
)
|
|
|
|
func init() {
|
|
Register(dbdialects.Postgres, func() DbWrapper {
|
|
return &postgresDbWrapper{dialect: dbdialects.Postgres}
|
|
})
|
|
}
|
|
|
|
type postgresRowResult struct {
|
|
row pgx.Row
|
|
}
|
|
|
|
func (pr *postgresRowResult) Scan(dest ...any) error {
|
|
return pr.row.Scan(dest...)
|
|
}
|
|
|
|
type postgresRowsResult struct {
|
|
columns []string
|
|
rows pgx.Rows
|
|
}
|
|
|
|
func (pr *postgresRowsResult) Close() error {
|
|
pr.rows.Close()
|
|
return nil
|
|
}
|
|
|
|
func (pr *postgresRowsResult) Columns() ([]string, error) {
|
|
if pr.columns != nil {
|
|
return pr.columns, nil
|
|
}
|
|
|
|
rawColumns := pr.rows.FieldDescriptions()
|
|
if rawColumns == nil {
|
|
return nil, errors.New("error retrieving columns")
|
|
}
|
|
|
|
columns := make([]string, 0, len(rawColumns))
|
|
for _, rc := range rawColumns {
|
|
columns = append(columns, rc.Name)
|
|
}
|
|
|
|
return columns, nil
|
|
}
|
|
|
|
func (pr *postgresRowsResult) Err() error {
|
|
return pr.rows.Err()
|
|
}
|
|
|
|
func (pr *postgresRowsResult) Next() bool {
|
|
return pr.rows.Next()
|
|
}
|
|
|
|
func (pr *postgresRowsResult) Scan(dest ...any) error {
|
|
return pr.rows.Scan(dest...)
|
|
}
|
|
|
|
func (pr *postgresRowsResult) Values() ([]any, error) {
|
|
return pr.rows.Values()
|
|
}
|
|
|
|
type postgresDbWrapper struct {
|
|
db *pgxpool.Pool
|
|
dialect string
|
|
}
|
|
|
|
func (pw *postgresDbWrapper) Connect(ctx context.Context, dbUrl string) error {
|
|
pool, err := pgxpool.New(ctx, dbUrl)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := pool.Ping(ctx); err != nil {
|
|
pool.Close()
|
|
return err
|
|
}
|
|
|
|
pw.db = pool
|
|
return nil
|
|
}
|
|
|
|
func (pw *postgresDbWrapper) Close() error {
|
|
pw.db.Close()
|
|
return nil
|
|
}
|
|
|
|
func (pw *postgresDbWrapper) Exec(ctx context.Context, query string, args ...any) (ExecResult, error) {
|
|
result, err := pw.db.Exec(ctx, query, args...)
|
|
if err != nil {
|
|
return ExecResult{}, err
|
|
}
|
|
|
|
return ExecResult{AffectedRows: result.RowsAffected()}, nil
|
|
}
|
|
|
|
func (pw *postgresDbWrapper) GetDialect() string {
|
|
return pw.dialect
|
|
}
|
|
|
|
func (pw *postgresDbWrapper) Query(ctx context.Context, query string, args ...any) (RowsResult, error) {
|
|
rows, err := pw.db.Query(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &postgresRowsResult{columns: nil, rows: rows}, nil
|
|
}
|
|
|
|
func (pw *postgresDbWrapper) QueryRow(ctx context.Context, query string, args ...any) RowResult {
|
|
row := pw.db.QueryRow(ctx, query, args...)
|
|
return &postgresRowResult{row: row}
|
|
}
|
|
|
|
func (pw *postgresDbWrapper) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) {
|
|
affectedRows, err := pw.db.CopyFrom(ctx, pgx.Identifier{schema, table}, columnNames, pgx.CopyFromRows(rows))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return affectedRows, nil
|
|
}
|
|
|
|
func (pw *postgresDbWrapper) QueryFromObject(ctx context.Context, q ExtractionQuery) (RowsResult, error) {
|
|
var sbQuery strings.Builder
|
|
|
|
sbQuery.WriteString("SELECT ")
|
|
|
|
if len(q.columns) == 0 {
|
|
sbQuery.WriteString("*")
|
|
} else {
|
|
for i, col := range q.columns {
|
|
switch col.Type() {
|
|
case "GEOMETRY":
|
|
fmt.Fprintf(&sbQuery, `ST_AsEWKB("%s") AS "%s"`, col.Name(), col.Name())
|
|
default:
|
|
fmt.Fprintf(&sbQuery, `"%s"`, col.Name())
|
|
}
|
|
|
|
if i < len(q.columns)-1 {
|
|
sbQuery.WriteString(", ")
|
|
}
|
|
}
|
|
}
|
|
|
|
fmt.Fprintf(&sbQuery, ` FROM "%s"."%s"`, q.Schema, q.Table)
|
|
|
|
if q.LowerLimit.IsValid || q.UpperLimit.IsValid {
|
|
sbQuery.WriteString(" WHERE ")
|
|
paramIdx := 1
|
|
|
|
if q.LowerLimit.IsValid {
|
|
fmt.Fprintf(&sbQuery, `"%s"`, q.PrimaryKey)
|
|
if q.LowerLimit.IsInclusive {
|
|
sbQuery.WriteString(" >=")
|
|
} else {
|
|
sbQuery.WriteString(" >")
|
|
}
|
|
fmt.Fprintf(&sbQuery, " $%d", paramIdx)
|
|
paramIdx++
|
|
}
|
|
|
|
if q.LowerLimit.IsValid && q.UpperLimit.IsValid {
|
|
sbQuery.WriteString(" AND ")
|
|
}
|
|
|
|
if q.UpperLimit.IsValid {
|
|
fmt.Fprintf(&sbQuery, `"%s"`, q.PrimaryKey)
|
|
if q.UpperLimit.IsInclusive {
|
|
sbQuery.WriteString(" <=")
|
|
} else {
|
|
sbQuery.WriteString(" <")
|
|
}
|
|
fmt.Fprintf(&sbQuery, " $%d", paramIdx)
|
|
paramIdx++
|
|
}
|
|
}
|
|
|
|
fmt.Fprintf(&sbQuery, ` ORDER BY "%s" ASC`, q.PrimaryKey)
|
|
|
|
queryString := sbQuery.String()
|
|
|
|
var queryArgs []any
|
|
|
|
if q.LowerLimit.IsValid {
|
|
queryArgs = append(queryArgs, q.LowerLimit.Value)
|
|
}
|
|
|
|
if q.UpperLimit.IsValid {
|
|
queryArgs = append(queryArgs, q.UpperLimit.Value)
|
|
}
|
|
|
|
return pw.Query(ctx, queryString, queryArgs...)
|
|
}
|