package postgres import ( "context" "errors" "fmt" "git.ksdemosapps.com/kylesoda/pgx-learning/internal/models" "git.ksdemosapps.com/kylesoda/pgx-learning/internal/repository" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" ) type pgxRepository struct { db *pgxpool.Pool } func NewTaskRepository(pool *pgxpool.Pool) repository.TaskRepository { return &pgxRepository{ db: pool, } } func (r *pgxRepository) Save(ctx context.Context, task *models.Task) error { sql := ` INSERT INTO public.tasks (text, completed) VALUES ($1, $2) RETURNING id, created_at, updated_at ` return r.db.QueryRow(ctx, sql, task.Text, task.Completed).Scan(&task.Id, &task.CreatedAt, &task.UpdatedAt) } func (r *pgxRepository) GetById(ctx context.Context, id int) (*models.Task, error) { sql := ` SELECT id, text, completed, created_at, updated_at FROM public.tasks WHERE id = $1 ` var task models.Task err := r.db.QueryRow(ctx, sql, id).Scan( &task.Id, &task.Text, &task.Completed, &task.CreatedAt, &task.UpdatedAt, ) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return nil, nil } return nil, fmt.Errorf("Unexpected error when retrieving task: %w", err) } return &task, nil } func (r *pgxRepository) GetAll(ctx context.Context, filters repository.GetAllTaskFilters) ([]models.Task, error) { sql := `SELECT id, text, completed, created_at, updated_at FROM public.tasks WHERE 1 = 1` args := []any{} if filters.Completed != nil { args = append(args, *filters.Completed) sql += fmt.Sprintf(" AND completed = $%d", len(args)) } limit := 1000 if filters.Limit > 0 { limit = filters.Limit } args = append(args, limit) sql += fmt.Sprintf(" LIMIT $%d", len(args)) if filters.Offset > 0 { args = append(args, filters.Offset) sql += fmt.Sprintf(" OFFSET $%d", len(args)) } rows, err := r.db.Query(ctx, sql, args...) if err != nil { return nil, fmt.Errorf("Unexpected error when querying tasks: %w", err) } defer rows.Close() tasks, err := pgx.CollectRows(rows, pgx.RowToStructByName[models.Task]) if err != nil { return nil, fmt.Errorf("Unexpected error when collecting tasks: %w", err) } return tasks, nil } func (r *pgxRepository) Update(ctx context.Context, id int, input *models.UpdateTaskInput) (*models.Task, error) { sql := "UPDATE public.tasks SET updated_at = CURRENT_TIMESTAMP" args := []any{} if input.Text != nil { args = append(args, input.Text) sql += fmt.Sprintf(", text = $%d", len(args)) } if input.Completed != nil { args = append(args, input.Completed) sql += fmt.Sprintf(", completed = $%d", len(args)) } if len(args) == 0 { return r.GetById(ctx, id) } args = append(args, id) sql += fmt.Sprintf(" WHERE id = $%d RETURNING id, text, completed, created_at, updated_at", len(args)) var task models.Task err := r.db.QueryRow(ctx, sql, args...).Scan( &task.Id, &task.Text, &task.Completed, &task.CreatedAt, &task.UpdatedAt, ) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return nil, fmt.Errorf("task not found") } return nil, fmt.Errorf("updating task: %w", err) } return &task, nil } func (r *pgxRepository) Delete(ctx context.Context, id int) error { sql := "DELETE FROM public.tasks WHERE id = $1" commandTag, err := r.db.Exec(ctx, sql, id) if err != nil { return fmt.Errorf("Unexpected error when deleting task: %w", err) } if commandTag.RowsAffected() == 0 { return fmt.Errorf("No task found with id %d", id) } return nil }