package dbwrapper import ( "context" "errors" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" ) func init() { Register("postgres", func() DbWrapper { return &postgresDbWrapper{dialect: "postgres"} }) } type postgresRowResult struct { columns []string rows pgx.Rows } func (pr *postgresRowResult) Close() error { pr.rows.Close() return nil } func (pr *postgresRowResult) Columns() ([]string, error) { if pr.columns != nil { return pr.columns, nil } rawColumns := pr.rows.FieldDescriptions() if rawColumns == nil { return nil, errors.New("error retrieving columns") } columns := make([]string, 0, len(rawColumns)) for _, rc := range rawColumns { columns = append(columns, rc.Name) } return columns, nil } func (pr *postgresRowResult) Err() error { return pr.rows.Err() } func (pr *postgresRowResult) Next() bool { return pr.rows.Next() } func (pr *postgresRowResult) Scan(dest ...any) error { return pr.rows.Scan(dest...) } func (pr *postgresRowResult) Values() ([]any, error) { return pr.rows.Values() } type postgresDbWrapper struct { db *pgxpool.Pool dialect string } func (pw *postgresDbWrapper) Connect(ctx context.Context, dbUrl string) error { pool, err := pgxpool.New(ctx, dbUrl) if err != nil { return err } if err := pool.Ping(ctx); err != nil { pool.Close() return err } pw.db = pool return nil } func (pw *postgresDbWrapper) Close() error { pw.db.Close() return nil } func (pw *postgresDbWrapper) Exec(ctx context.Context, query string, args ...any) (ExecResult, error) { result, err := pw.db.Exec(ctx, query, args...) if err != nil { return ExecResult{}, err } return ExecResult{AffectedRows: result.RowsAffected()}, nil } func (pw *postgresDbWrapper) GetDialect() string { return pw.dialect } func (pw *postgresDbWrapper) Query(ctx context.Context, query string, args ...any) (RowsResult, error) { rows, err := pw.db.Query(ctx, query, args...) if err != nil { return nil, err } return &postgresRowResult{columns: nil, rows: rows}, nil } func (pw *postgresDbWrapper) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) { affectedRows, err := pw.db.CopyFrom(ctx, pgx.Identifier{schema, table}, columnNames, pgx.CopyFromRows(rows)) if err != nil { return 0, err } return affectedRows, nil }