151 lines
3.4 KiB
Go
151 lines
3.4 KiB
Go
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
|
|
"git.ksdemosapps.com/kylesoda/pgx-learning/internal/models"
|
|
"git.ksdemosapps.com/kylesoda/pgx-learning/internal/repository"
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
)
|
|
|
|
type pgxRepository struct {
|
|
db *pgxpool.Pool
|
|
}
|
|
|
|
func NewTaskRepository(pool *pgxpool.Pool) repository.TaskRepository {
|
|
return &pgxRepository{
|
|
db: pool,
|
|
}
|
|
}
|
|
|
|
func (r *pgxRepository) Save(ctx context.Context, task *models.Task) error {
|
|
sql := `
|
|
INSERT INTO public.tasks (text, completed)
|
|
VALUES ($1, $2)
|
|
RETURNING id, created_at, updated_at
|
|
`
|
|
|
|
return r.db.QueryRow(ctx, sql, task.Text, task.Completed).Scan(&task.Id, &task.CreatedAt, &task.UpdatedAt)
|
|
}
|
|
|
|
func (r *pgxRepository) GetById(ctx context.Context, id int) (*models.Task, error) {
|
|
sql := `
|
|
SELECT id, text, completed, created_at, updated_at
|
|
FROM public.tasks
|
|
WHERE id = $1
|
|
`
|
|
|
|
var task models.Task
|
|
err := r.db.QueryRow(ctx, sql, id).Scan(
|
|
&task.Id,
|
|
&task.Text,
|
|
&task.Completed,
|
|
&task.CreatedAt,
|
|
&task.UpdatedAt,
|
|
)
|
|
|
|
if err != nil {
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return nil, nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("Unexpected error when retrieving task: %w", err)
|
|
}
|
|
|
|
return &task, nil
|
|
}
|
|
|
|
func (r *pgxRepository) GetAll(ctx context.Context, filters repository.GetAllTaskFilters) ([]models.Task, error) {
|
|
sql := `SELECT id, text, completed, created_at, updated_at FROM public.tasks WHERE 1 = 1`
|
|
args := []any{}
|
|
|
|
if filters.Completed != nil {
|
|
args = append(args, *filters.Completed)
|
|
sql += fmt.Sprintf(" AND completed = $%d", len(args))
|
|
}
|
|
|
|
limit := 1000
|
|
if filters.Limit > 0 {
|
|
limit = filters.Limit
|
|
}
|
|
args = append(args, limit)
|
|
sql += fmt.Sprintf(" LIMIT $%d", len(args))
|
|
|
|
if filters.Offset > 0 {
|
|
args = append(args, filters.Offset)
|
|
sql += fmt.Sprintf(" OFFSET $%d", len(args))
|
|
}
|
|
|
|
rows, err := r.db.Query(ctx, sql, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Unexpected error when querying tasks: %w", err)
|
|
}
|
|
|
|
defer rows.Close()
|
|
|
|
tasks, err := pgx.CollectRows(rows, pgx.RowToStructByName[models.Task])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Unexpected error when collecting tasks: %w", err)
|
|
}
|
|
|
|
return tasks, nil
|
|
}
|
|
|
|
func (r *pgxRepository) Update(ctx context.Context, id int, input *models.UpdateTaskInput) (*models.Task, error) {
|
|
sql := "UPDATE public.tasks SET updated_at = CURRENT_TIMESTAMP"
|
|
args := []any{}
|
|
|
|
if input.Text != nil {
|
|
args = append(args, input.Text)
|
|
sql += fmt.Sprintf(", text = $%d", len(args))
|
|
}
|
|
|
|
if input.Completed != nil {
|
|
args = append(args, input.Completed)
|
|
sql += fmt.Sprintf(", completed = $%d", len(args))
|
|
}
|
|
|
|
if len(args) == 0 {
|
|
return r.GetById(ctx, id)
|
|
}
|
|
|
|
args = append(args, id)
|
|
sql += fmt.Sprintf(" WHERE id = $%d RETURNING id, text, completed, created_at, updated_at", len(args))
|
|
|
|
var task models.Task
|
|
err := r.db.QueryRow(ctx, sql, args...).Scan(
|
|
&task.Id,
|
|
&task.Text,
|
|
&task.Completed,
|
|
&task.CreatedAt,
|
|
&task.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return nil, fmt.Errorf("task not found")
|
|
}
|
|
|
|
return nil, fmt.Errorf("updating task: %w", err)
|
|
}
|
|
|
|
return &task, nil
|
|
}
|
|
|
|
func (r *pgxRepository) Delete(ctx context.Context, id int) error {
|
|
sql := "DELETE FROM public.tasks WHERE id = $1"
|
|
|
|
commandTag, err := r.db.Exec(ctx, sql, id)
|
|
if err != nil {
|
|
return fmt.Errorf("Unexpected error when deleting task: %w", err)
|
|
}
|
|
|
|
if commandTag.RowsAffected() == 0 {
|
|
return fmt.Errorf("No task found with id %d", id)
|
|
}
|
|
|
|
return nil
|
|
}
|