package dbwrapper import ( "context" "errors" "fmt" "strings" dbdialects "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper/db_dialects" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" ) func init() { Register(dbdialects.Postgres, func() DbWrapper { return &postgresDbWrapper{dialect: dbdialects.Postgres} }) } type postgresRowResult struct { row pgx.Row } func (pr *postgresRowResult) Scan(dest ...any) error { return pr.row.Scan(dest...) } type postgresRowsResult struct { columns []string rows pgx.Rows } func (pr *postgresRowsResult) Close() error { pr.rows.Close() return nil } func (pr *postgresRowsResult) Columns() ([]string, error) { if pr.columns != nil { return pr.columns, nil } rawColumns := pr.rows.FieldDescriptions() if rawColumns == nil { return nil, errors.New("error retrieving columns") } columns := make([]string, 0, len(rawColumns)) for _, rc := range rawColumns { columns = append(columns, rc.Name) } return columns, nil } func (pr *postgresRowsResult) Err() error { return pr.rows.Err() } func (pr *postgresRowsResult) Next() bool { return pr.rows.Next() } func (pr *postgresRowsResult) Scan(dest ...any) error { return pr.rows.Scan(dest...) } func (pr *postgresRowsResult) Values() ([]any, error) { return pr.rows.Values() } type postgresDbWrapper struct { db *pgxpool.Pool dialect string } func (pw *postgresDbWrapper) Connect(ctx context.Context, dbUrl string) error { pool, err := pgxpool.New(ctx, dbUrl) if err != nil { return err } if err := pool.Ping(ctx); err != nil { pool.Close() return err } pw.db = pool return nil } func (pw *postgresDbWrapper) Close() error { pw.db.Close() return nil } func (pw *postgresDbWrapper) Exec(ctx context.Context, query string, args ...any) (ExecResult, error) { result, err := pw.db.Exec(ctx, query, args...) if err != nil { return ExecResult{}, err } return ExecResult{AffectedRows: result.RowsAffected()}, nil } func (pw *postgresDbWrapper) GetDialect() string { return pw.dialect } func (pw *postgresDbWrapper) Query(ctx context.Context, query string, args ...any) (RowsResult, error) { rows, err := pw.db.Query(ctx, query, args...) if err != nil { return nil, err } return &postgresRowsResult{columns: nil, rows: rows}, nil } func (pw *postgresDbWrapper) QueryRow(ctx context.Context, query string, args ...any) RowResult { row := pw.db.QueryRow(ctx, query, args...) return &postgresRowResult{row: row} } func (pw *postgresDbWrapper) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) { affectedRows, err := pw.db.CopyFrom(ctx, pgx.Identifier{schema, table}, columnNames, pgx.CopyFromRows(rows)) if err != nil { return 0, err } return affectedRows, nil } func (pw *postgresDbWrapper) QueryFromObject(ctx context.Context, q ExtractionQuery) (RowsResult, error) { var sbQuery strings.Builder sbQuery.WriteString("SELECT ") if len(q.Columns) == 0 { sbQuery.WriteString("*") } else { for i, col := range q.Columns { switch col.Type() { case "GEOMETRY": fmt.Fprintf(&sbQuery, `ST_AsEWKB("%s") AS "%s"`, col.Name(), col.Name()) default: fmt.Fprintf(&sbQuery, `"%s"`, col.Name()) } if i < len(q.Columns)-1 { sbQuery.WriteString(", ") } } } fmt.Fprintf(&sbQuery, ` FROM "%s"."%s"`, q.Schema, q.Table) if q.LowerLimit.IsValid || q.UpperLimit.IsValid { sbQuery.WriteString(" WHERE ") paramIdx := 1 if q.LowerLimit.IsValid { fmt.Fprintf(&sbQuery, `"%s"`, q.PrimaryKey) if q.LowerLimit.IsInclusive { sbQuery.WriteString(" >=") } else { sbQuery.WriteString(" >") } fmt.Fprintf(&sbQuery, " $%d", paramIdx) paramIdx++ } if q.LowerLimit.IsValid && q.UpperLimit.IsValid { sbQuery.WriteString(" AND ") } if q.UpperLimit.IsValid { fmt.Fprintf(&sbQuery, `"%s"`, q.PrimaryKey) if q.UpperLimit.IsInclusive { sbQuery.WriteString(" <=") } else { sbQuery.WriteString(" <") } fmt.Fprintf(&sbQuery, " $%d", paramIdx) paramIdx++ } } fmt.Fprintf(&sbQuery, ` ORDER BY "%s" ASC`, q.PrimaryKey) queryString := sbQuery.String() var queryArgs []any if q.LowerLimit.IsValid { queryArgs = append(queryArgs, q.LowerLimit.Value) } if q.UpperLimit.IsValid { queryArgs = append(queryArgs, q.UpperLimit.Value) } return pw.Query(ctx, queryString, queryArgs...) }