all repos

onasty @ dependabot/github_actions/marocchino/tool-versions-action-2

a one-time notes service

onasty/internal/store/psql/noterepo/noterepo.go (view raw)

Oleksandr Smirnov Oleksandr Smirnov
olexsmir@gmail.com
refactor!: rename "burn before expiration" to "keep before expiration" (#199)..., 9 months ago
1
package noterepo
2
3
import (
4
	"context"
5
	"database/sql"
6
	"errors"
7
	"time"
8
9
	"github.com/gofrs/uuid/v5"
10
	"github.com/henvic/pgq"
11
	"github.com/jackc/pgx/v5"
12
	"github.com/olexsmir/onasty/internal/dtos"
13
	"github.com/olexsmir/onasty/internal/models"
14
	"github.com/olexsmir/onasty/internal/store/psqlutil"
15
)
16
17
type NoteStorer interface {
18
	// Create creates a note.
19
	Create(ctx context.Context, note models.Note) error
20
21
	// GetBySlug gets a note by slug.
22
	// Returns [models.ErrNoteNotFound] if note is not found.
23
	GetBySlug(ctx context.Context, slug dtos.NoteSlug) (models.Note, error)
24
25
	// GetMetadataBySlug gets note's metadata by its slug.
26
	// Returns [models.ErrNoteNotFound] if note is not found OR read.
27
	GetMetadataBySlug(ctx context.Context, slug dtos.NoteSlug) (dtos.NoteMetadata, error)
28
29
	// GetAllByAuthorID returns all notes with specified author.
30
	GetAllByAuthorID(ctx context.Context, authorID uuid.UUID) ([]models.Note, error)
31
32
	// GetAllReadByAuthorID returns all notes that are read and authored by specified author.
33
	GetAllReadByAuthorID(ctx context.Context, authorID uuid.UUID) ([]models.Note, error)
34
35
	// GetAllUnreadByAuthorID returns all notes that are unread and authored by specified author.
36
	GetAllUnreadByAuthorID(ctx context.Context, authorID uuid.UUID) ([]models.Note, error)
37
38
	// GetCountOfNotesByAuthorID returns count of notes created by specified author.
39
	GetCountOfNotesByAuthorID(ctx context.Context, authorID uuid.UUID) (int64, error)
40
41
	// GetBySlugAndPassword gets a note by slug and password.
42
	// the "password" should be hashed.
43
	//
44
	// Returns [models.ErrNoteNotFound] if note is not found.
45
	GetBySlugAndPassword(
46
		ctx context.Context,
47
		slug dtos.NoteSlug,
48
		password string,
49
	) (models.Note, error)
50
51
	// UpdateExpirationTimeSettingsBySlug patches note by updating expiresAt and keepBeforeExpiration if one is passwd
52
	// Returns [models.ErrNoteNotFound] if note is not found.
53
	UpdateExpirationTimeSettingsBySlug(
54
		ctx context.Context,
55
		slug dtos.NoteSlug,
56
		patch dtos.PatchNote,
57
		authorID uuid.UUID,
58
	) error
59
60
	// RemoveBySlug marks note as read, deletes it's content, and keeps meta data
61
	// Returns [models.ErrNoteNotFound] if note is not found.
62
	RemoveBySlug(ctx context.Context, slug dtos.NoteSlug, readAt time.Time) error
63
64
	// DeleteNoteBySlug deletes(unlike [RemoveBySlug]) note by slug.
65
	// Returns [models.ErrNoteNotFound] if note is not found.
66
	DeleteNoteBySlug(ctx context.Context, slug dtos.NoteSlug, authorID uuid.UUID) error
67
68
	// SetAuthorIDBySlug assigns author to note by slug.
69
	// Returns [models.ErrNoteNotFound] if note is not found.
70
	SetAuthorIDBySlug(ctx context.Context, slug dtos.NoteSlug, authorID uuid.UUID) error
71
72
	// UpdatePasswordBySlug updates or sets password on a note.
73
	UpdatePasswordBySlug(
74
		ctx context.Context,
75
		slug dtos.NoteSlug,
76
		authorID uuid.UUID,
77
		passwd string,
78
	) error
79
}
80
81
// Compile-time assertion that *NoteRepo implements NoteStorer.
var _ NoteStorer = (*NoteRepo)(nil)
82
83
type NoteRepo struct {
84
	db *psqlutil.DB
85
}
86
87
// New returns a NoteRepo that issues its queries through db.
func New(db *psqlutil.DB) *NoteRepo {
	repo := NoteRepo{db: db}
	return &repo
}
90
91
func (s *NoteRepo) Create(ctx context.Context, inp models.Note) error {
92
	query, args, err := pgq.
93
		Insert("notes").
94
		Columns("content", "slug", "password", "keep_before_expiration", "created_at", "expires_at").
95
		Values(inp.Content, inp.Slug, inp.Password, inp.KeepBeforeExpiration, inp.CreatedAt, inp.ExpiresAt).
96
		SQL()
97
	if err != nil {
98
		return err
99
	}
100
101
	_, err = s.db.Exec(ctx, query, args...)
102
	if psqlutil.IsDuplicateErr(err, "notes_slug_key") {
103
		return models.ErrNoteSlugIsAlreadyInUse
104
	}
105
106
	return err
107
}
108
109
func (s *NoteRepo) GetBySlug(ctx context.Context, slug dtos.NoteSlug) (models.Note, error) {
110
	query, args, err := pgq.
111
		Select("content", "slug", "keep_before_expiration", "read_at", "created_at", "expires_at").
112
		From("notes").
113
		Where("(password is null or password = '')").
114
		Where(pgq.Eq{"slug": slug}).
115
		SQL()
116
	if err != nil {
117
		return models.Note{}, err
118
	}
119
120
	var note models.Note
121
	var readAt sql.NullTime
122
	err = s.db.QueryRow(ctx, query, args...).
123
		Scan(&note.Content, &note.Slug, &note.KeepBeforeExpiration, &readAt, &note.CreatedAt, &note.ExpiresAt)
124
	if errors.Is(err, pgx.ErrNoRows) {
125
		return models.Note{}, models.ErrNoteNotFound
126
	}
127
128
	note.ReadAt = psqlutil.NullTimeToTime(readAt)
129
130
	return note, err
131
}
132
133
func (s *NoteRepo) GetMetadataBySlug(
134
	ctx context.Context,
135
	slug dtos.NoteSlug,
136
) (dtos.NoteMetadata, error) {
137
	query := `--sql
138
select n.created_at, (n.password is not null and n.password <> '') has_password, n.read_at
139
from notes n
140
where slug = $1`
141
142
	var readAt sql.NullTime
143
	var metadata dtos.NoteMetadata
144
	err := s.db.QueryRow(ctx, query, slug).Scan(&metadata.CreatedAt, &metadata.HasPassword, &readAt)
145
	if errors.Is(err, pgx.ErrNoRows) {
146
		return dtos.NoteMetadata{}, models.ErrNoteNotFound
147
	}
148
149
	if !psqlutil.NullTimeToTime(readAt).IsZero() {
150
		return dtos.NoteMetadata{}, models.ErrNoteNotFound
151
	}
152
153
	return metadata, err
154
}
155
156
func (s *NoteRepo) GetAllByAuthorID(
157
	ctx context.Context,
158
	authorID uuid.UUID,
159
) ([]models.Note, error) {
160
	query := `--sql
161
select n.content, n.slug, n.keep_before_expiration, n.password, n.read_at, n.created_at, n.expires_at
162
from notes n
163
inner join notes_authors na on n.id = na.note_id
164
where na.user_id = $1`
165
166
	return s.getAllNotes(ctx, query, authorID)
167
}
168
169
func (s *NoteRepo) GetAllReadByAuthorID(
170
	ctx context.Context,
171
	authorID uuid.UUID,
172
) ([]models.Note, error) {
173
	query := `--sql
174
select n.content, n.slug, n.keep_before_expiration, n.password, n.read_at, n.created_at, n.expires_at
175
from notes n
176
inner join notes_authors na on n.id = na.note_id
177
where na.user_id = $1
178
	and n.read_at is not null`
179
180
	return s.getAllNotes(ctx, query, authorID)
181
}
182
183
func (s *NoteRepo) GetAllUnreadByAuthorID(
184
	ctx context.Context,
185
	authorID uuid.UUID,
186
) ([]models.Note, error) {
187
	query := `--sql
188
select n.content, n.slug, n.keep_before_expiration, n.password, n.read_at, n.created_at, n.expires_at
189
from notes n
190
inner join notes_authors na on n.id = na.note_id
191
where na.user_id = $1
192
	and n.read_at is null`
193
194
	return s.getAllNotes(ctx, query, authorID)
195
}
196
197
func (s *NoteRepo) GetCountOfNotesByAuthorID(
198
	ctx context.Context,
199
	authorID uuid.UUID,
200
) (int64, error) {
201
	var count int64
202
	err := s.db.QueryRow(
203
		ctx,
204
		`select count(*) from notes_authors where user_id = $1`,
205
		authorID.String(),
206
	).Scan(&count)
207
208
	return count, err
209
}
210
211
func (s *NoteRepo) GetBySlugAndPassword(
212
	ctx context.Context,
213
	slug dtos.NoteSlug,
214
	passwd string,
215
) (models.Note, error) {
216
	query, args, err := pgq.
217
		Select("content", "slug", "keep_before_expiration", "read_at", "created_at", "expires_at").
218
		From("notes").
219
		Where(pgq.Eq{
220
			"slug":     slug,
221
			"password": passwd,
222
		}).
223
		SQL()
224
	if err != nil {
225
		return models.Note{}, err
226
	}
227
228
	var note models.Note
229
	var readAt sql.NullTime
230
	err = s.db.QueryRow(ctx, query, args...).
231
		Scan(&note.Content, &note.Slug, &note.KeepBeforeExpiration, &readAt, &note.CreatedAt, &note.ExpiresAt)
232
233
	if errors.Is(err, pgx.ErrNoRows) {
234
		return models.Note{}, models.ErrNoteNotFound
235
	}
236
237
	note.ReadAt = psqlutil.NullTimeToTime(readAt)
238
239
	return note, err
240
}
241
242
func (s *NoteRepo) UpdateExpirationTimeSettingsBySlug(
243
	ctx context.Context,
244
	slug dtos.NoteSlug,
245
	patch dtos.PatchNote,
246
	authorID uuid.UUID,
247
) error {
248
	query := `--sql
249
update notes n
250
set keep_before_expiration = COALESCE($1, n.keep_before_expiration),
251
    expires_at = COALESCE($2, n.expires_at)
252
from notes_authors na
253
where n.slug = $3
254
  and na.user_id = $4
255
  and na.note_id = n.id`
256
257
	ct, err := s.db.Exec(ctx, query,
258
		patch.KeepBeforeExpiration, patch.ExpiresAt,
259
		slug, authorID.String())
260
	if err != nil {
261
		return err
262
	}
263
264
	if ct.RowsAffected() == 0 {
265
		return models.ErrNoteNotFound
266
	}
267
268
	return nil
269
}
270
271
func (s *NoteRepo) RemoveBySlug(
272
	ctx context.Context,
273
	slug dtos.NoteSlug,
274
	readAt time.Time,
275
) error {
276
	query, args, err := pgq.
277
		Update("notes").
278
		Set("content", "").
279
		Set("read_at", readAt).
280
		Where(pgq.Eq{"slug": slug}).
281
		Where("read_at is null").
282
		SQL()
283
	if err != nil {
284
		return err
285
	}
286
287
	_, err = s.db.Exec(ctx, query, args...)
288
	if errors.Is(err, pgx.ErrNoRows) {
289
		return models.ErrNoteNotFound
290
	}
291
292
	return err
293
}
294
295
func (s *NoteRepo) DeleteNoteBySlug(
296
	ctx context.Context,
297
	slug dtos.NoteSlug,
298
	authorID uuid.UUID,
299
) error {
300
	query := `--sql
301
delete from notes n
302
using notes_authors na
303
where n.slug = $1
304
  and na.user_id = $2`
305
306
	ct, err := s.db.Exec(ctx, query, slug, authorID.String())
307
	if err != nil {
308
		return err
309
	}
310
311
	if ct.RowsAffected() == 0 {
312
		return models.ErrNoteNotFound
313
	}
314
315
	return nil
316
}
317
318
func (s *NoteRepo) SetAuthorIDBySlug(
319
	ctx context.Context,
320
	slug dtos.NoteSlug,
321
	authorID uuid.UUID,
322
) error {
323
	tx, err := s.db.Begin(ctx)
324
	if err != nil {
325
		return err
326
	}
327
	defer tx.Rollback(ctx) //nolint:errcheck
328
329
	var noteID uuid.UUID
330
	err = tx.QueryRow(ctx, "select id from notes where slug = $1", slug).Scan(&noteID)
331
	if err != nil {
332
		if errors.Is(err, pgx.ErrNoRows) {
333
			return models.ErrNoteNotFound
334
		}
335
		return err
336
	}
337
338
	_, err = tx.Exec(
339
		ctx,
340
		"insert into notes_authors (note_id, user_id) values ($1, $2)",
341
		noteID, authorID,
342
	)
343
	if err != nil {
344
		return err
345
	}
346
347
	return tx.Commit(ctx)
348
}
349
350
func (s *NoteRepo) UpdatePasswordBySlug(
351
	ctx context.Context,
352
	slug dtos.NoteSlug,
353
	authorID uuid.UUID,
354
	passwd string,
355
) error {
356
	query := `--sql
357
update notes n
358
set password = $1
359
from notes_authors na
360
where n.slug = $2
361
  and na.user_id = $3
362
  and na.note_id = n.id`
363
364
	ct, err := s.db.Exec(ctx, query, passwd, slug, authorID.String())
365
	if err != nil {
366
		return err
367
	}
368
369
	if ct.RowsAffected() == 0 {
370
		return models.ErrNoteNotFound
371
	}
372
373
	return nil
374
}
375
376
// getAllNotes is a helper function for [NoteRepo.GetAllByAuthorID], [NoteRepo.GetAllReadByAuthorID],
377
// and [NoteRepo.GetAllUnreadByAuthorID].
378
// The query's SELECT elements order should be consistent across all function calls.
379
func (s *NoteRepo) getAllNotes(
380
	ctx context.Context,
381
	query string,
382
	authorID uuid.UUID,
383
) ([]models.Note, error) {
384
	rows, err := s.db.Query(ctx, query, authorID.String())
385
	if err != nil {
386
		return nil, err
387
	}
388
389
	defer rows.Close()
390
391
	var notes []models.Note
392
	for rows.Next() {
393
		var note models.Note
394
		var readAt sql.NullTime
395
		if err := rows.Scan(&note.Content, &note.Slug, &note.KeepBeforeExpiration, &note.Password,
396
			&readAt, &note.CreatedAt, &note.ExpiresAt); err != nil {
397
			return nil, err
398
		}
399
400
		note.ReadAt = psqlutil.NullTimeToTime(readAt)
401
		notes = append(notes, note)
402
	}
403
404
	return notes, rows.Err()
405
}