Files

151 lines
3.4 KiB
Go

package postgres
import (
"context"
"errors"
"fmt"
"git.ksdemosapps.com/kylesoda/pgx-learning/internal/models"
"git.ksdemosapps.com/kylesoda/pgx-learning/internal/repository"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
// pgxRepository implements repository.TaskRepository on top of a pgx
// connection pool. The methods defined for it query the public.tasks table.
type pgxRepository struct {
	// db is the pool every query of this repository is issued through.
	db *pgxpool.Pool
}
// NewTaskRepository returns a repository.TaskRepository backed by pool.
// The pool is stored as given; no connection or query is made here.
func NewTaskRepository(pool *pgxpool.Pool) repository.TaskRepository {
	repo := new(pgxRepository)
	repo.db = pool
	return repo
}
// Save inserts task into public.tasks. On success it fills in the
// database-generated Id, CreatedAt and UpdatedAt fields of the passed-in
// struct.
//
// Only Text and Completed are written. Any Id or timestamps already set on
// task are overwritten by the values the INSERT returns. A nil task is
// rejected with an error instead of panicking. Database errors are wrapped
// with %w so callers can still inspect them with errors.Is / errors.As.
func (r *pgxRepository) Save(ctx context.Context, task *models.Task) error {
	if task == nil {
		return errors.New("saving task: task is nil")
	}

	const sql = `
		INSERT INTO public.tasks (text, completed)
		VALUES ($1, $2)
		RETURNING id, created_at, updated_at
	`
	err := r.db.QueryRow(ctx, sql, task.Text, task.Completed).
		Scan(&task.Id, &task.CreatedAt, &task.UpdatedAt)
	if err != nil {
		return fmt.Errorf("saving task: %w", err)
	}
	return nil
}
// GetById loads the task with the given id from public.tasks.
//
// When no row matches it returns (nil, nil), so callers must check the
// returned pointer as well as the error. Any other failure is returned
// wrapped with the id for context.
func (r *pgxRepository) GetById(ctx context.Context, id int) (*models.Task, error) {
	const sql = `
		SELECT id, text, completed, created_at, updated_at
		FROM public.tasks
		WHERE id = $1
	`
	var task models.Task
	err := r.db.QueryRow(ctx, sql, id).Scan(
		&task.Id,
		&task.Text,
		&task.Completed,
		&task.CreatedAt,
		&task.UpdatedAt,
	)
	if errors.Is(err, pgx.ErrNoRows) {
		// This method reports absence as a nil task, not as an error.
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("retrieving task %d: %w", id, err)
	}
	return &task, nil
}
// GetAll lists tasks from public.tasks, optionally filtered by completion
// state. Rows are ordered by id so that LIMIT/OFFSET pagination is stable
// across calls.
//
// If filters.Completed is non-nil, only tasks with that completed value
// are returned. If filters.Limit <= 0, at most 1000 rows are returned.
// filters.Offset is applied only when it is positive. Every filter value
// is bound as a positional query parameter and never interpolated into
// the SQL text.
func (r *pgxRepository) GetAll(ctx context.Context, filters repository.GetAllTaskFilters) ([]models.Task, error) {
	const defaultLimit = 1000

	sql := "SELECT id, text, completed, created_at, updated_at FROM public.tasks WHERE 1 = 1"
	var args []any
	// bind records v as the next positional argument and returns its
	// placeholder ($1, $2, ...).
	bind := func(v any) string {
		args = append(args, v)
		return fmt.Sprintf("$%d", len(args))
	}

	if filters.Completed != nil {
		sql += " AND completed = " + bind(*filters.Completed)
	}

	// Without ORDER BY, PostgreSQL may return rows in any order. Pages
	// fetched with LIMIT/OFFSET could then overlap or skip rows.
	sql += " ORDER BY id"

	limit := defaultLimit
	if filters.Limit > 0 {
		limit = filters.Limit
	}
	sql += " LIMIT " + bind(limit)

	if filters.Offset > 0 {
		sql += " OFFSET " + bind(filters.Offset)
	}

	rows, err := r.db.Query(ctx, sql, args...)
	if err != nil {
		return nil, fmt.Errorf("querying tasks: %w", err)
	}
	defer rows.Close()

	tasks, err := pgx.CollectRows(rows, pgx.RowToStructByName[models.Task])
	if err != nil {
		return nil, fmt.Errorf("collecting tasks: %w", err)
	}
	return tasks, nil
}
// Update applies the non-nil fields of input to the task with the given id.
// It returns the row as stored after the update. When at least one field
// is written, updated_at is also set to CURRENT_TIMESTAMP.
//
// When input is nil or sets no fields, nothing is written and the current
// row is returned as-is. In every path a missing task yields a
// "task not found" error; Update never returns (nil, nil).
func (r *pgxRepository) Update(ctx context.Context, id int, input *models.UpdateTaskInput) (*models.Task, error) {
	sql := "UPDATE public.tasks SET updated_at = CURRENT_TIMESTAMP"
	var args []any
	if input != nil {
		if input.Text != nil {
			args = append(args, input.Text)
			sql += fmt.Sprintf(", text = $%d", len(args))
		}
		if input.Completed != nil {
			args = append(args, input.Completed)
			sql += fmt.Sprintf(", completed = $%d", len(args))
		}
	}

	if len(args) == 0 {
		// Nothing to change, so return the current row. GetById reports a
		// missing row as (nil, nil). Translate that into the same error
		// the UPDATE path below produces, so callers see one contract.
		task, err := r.GetById(ctx, id)
		if err != nil {
			return nil, err
		}
		if task == nil {
			return nil, errors.New("task not found")
		}
		return task, nil
	}

	args = append(args, id)
	sql += fmt.Sprintf(" WHERE id = $%d RETURNING id, text, completed, created_at, updated_at", len(args))

	var task models.Task
	err := r.db.QueryRow(ctx, sql, args...).Scan(
		&task.Id,
		&task.Text,
		&task.Completed,
		&task.CreatedAt,
		&task.UpdatedAt,
	)
	if errors.Is(err, pgx.ErrNoRows) {
		return nil, errors.New("task not found")
	}
	if err != nil {
		return nil, fmt.Errorf("updating task %d: %w", id, err)
	}
	return &task, nil
}
// Delete removes the task with the given id from public.tasks.
//
// It returns a wrapped error if the statement fails. It returns a separate
// error if no row matched id.
func (r *pgxRepository) Delete(ctx context.Context, id int) error {
	const sql = "DELETE FROM public.tasks WHERE id = $1"
	tag, err := r.db.Exec(ctx, sql, id)
	if err != nil {
		return fmt.Errorf("deleting task %d: %w", id, err)
	}
	// Exec succeeds even when the WHERE clause matched nothing. The affected
	// row count is the only signal that the task did not exist.
	if tag.RowsAffected() == 0 {
		return fmt.Errorf("no task found with id %d", id)
	}
	return nil
}