all repos

onasty @ aebfa02

a one-time notes service

onasty/internal/store/psql/sessionrepo/sessionrepo.go(view raw)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
package sessionrepo

import (
	"context"
	"errors"
	"time"

	"github.com/gofrs/uuid/v5"
	"github.com/henvic/pgq"
	"github.com/jackc/pgx/v5"
	"github.com/olexsmir/onasty/internal/models"
	"github.com/olexsmir/onasty/internal/store/psqlutil"
)

// SessionStorer stores and manages user sessions, each identified by the
// owning user's ID and a refresh token.
type SessionStorer interface {
	// Set stores a session for userID with the given refresh token, expiring
	// at expiresAt.
	Set(ctx context.Context, userID uuid.UUID, refreshToken string, expiresAt time.Time) error
	// GetUserIDByRefreshToken returns the ID of the user that owns the
	// session with the given refresh token.
	GetUserIDByRefreshToken(ctx context.Context, refreshToken string) (uuid.UUID, error)
	// Update replaces refreshToken with newRefreshToken on userID's session.
	Update(ctx context.Context, userID uuid.UUID, refreshToken string, newRefreshToken string) error
	// Delete removes userID's session identified by refreshToken.
	Delete(ctx context.Context, userID uuid.UUID, refreshToken string) error
	// DeleteAllByUserID removes every session belonging to userID.
	DeleteAllByUserID(ctx context.Context, userID uuid.UUID) error
}

// Compile-time proof that *SessionRepo implements SessionStorer.
var _ SessionStorer = (*SessionRepo)(nil)

// SessionRepo is the SessionStorer implementation that keeps sessions in the
// "sessions" table of the database behind db.
type SessionRepo struct {
	db *psqlutil.DB // handle every query in this repository runs against
}

// New builds a SessionRepo that executes its queries through db.
func New(db *psqlutil.DB) *SessionRepo {
	repo := new(SessionRepo)
	repo.db = db
	return repo
}

// Set records a new session by inserting a row into the sessions table with
// the given user ID, refresh token and expiry time. Query-building and
// execution errors are returned unchanged.
func (s *SessionRepo) Set(
	ctx context.Context,
	userID uuid.UUID,
	refreshToken string,
	expiresAt time.Time,
) error {
	insert := pgq.Insert("sessions").
		Columns("user_id", "refresh_token", "expires_at").
		Values(userID, refreshToken, expiresAt)

	stmt, params, buildErr := insert.SQL()
	if buildErr != nil {
		return buildErr
	}

	_, execErr := s.db.Exec(ctx, stmt, params...)
	return execErr
}

// Update replaces refreshToken with newRefreshToken on the session row that
// belongs to userID and currently holds refreshToken.
//
// A database error from the UPDATE is returned as-is. If the statement
// succeeds but does not affect exactly one row, models.ErrSessionNotFound
// is returned.
func (s *SessionRepo) Update(
	ctx context.Context,
	userID uuid.UUID,
	refreshToken string,
	newRefreshToken string,
) error {
	// NOTE(review): the commented-out filter below would match sessions whose
	// expires_at is already in the past. If it is re-enabled to reject expired
	// sessions, the comparison should probably be `expires_at > now()`.
	query := `--sql
update sessions
set refresh_token = $1
where
  user_id = $2
  and refresh_token = $3
  -- and expires_at < now()
`

	res, err := s.db.Exec(ctx, query, newRefreshToken, userID, refreshToken)
	if err != nil {
		// Check the error before the row count. The result of a failed Exec is
		// not meaningful, so reading it first would misreport a database
		// failure as a missing session and drop the real error.
		return err
	}

	if res.RowsAffected() != 1 {
		return models.ErrSessionNotFound
	}

	return nil
}

// GetUserIDByRefreshToken returns the user_id stored on the session row
// whose refresh_token equals refreshToken. The query does not filter on
// expires_at.
//
// If no row matches, it returns models.ErrUserNotFound. Any other
// query-building or database error is returned unchanged.
func (s *SessionRepo) GetUserIDByRefreshToken(
	ctx context.Context,
	refreshToken string,
) (uuid.UUID, error) {
	lookup := pgq.Select("user_id").
		From("sessions").
		Where(pgq.Eq{"refresh_token": refreshToken})

	stmt, params, buildErr := lookup.SQL()
	if buildErr != nil {
		return uuid.UUID{}, buildErr
	}

	var owner uuid.UUID
	scanErr := s.db.QueryRow(ctx, stmt, params...).Scan(&owner)
	if errors.Is(scanErr, pgx.ErrNoRows) {
		return uuid.UUID{}, models.ErrUserNotFound
	}

	return owner, scanErr
}

// Delete removes the session row matching both userID and refreshToken.
// Matching no rows is not treated as an error; only database errors are
// returned.
func (s *SessionRepo) Delete(ctx context.Context, userID uuid.UUID, refreshToken string) error {
	const query = `--sql
DELETE FROM sessions
WHERE user_id = $1
  AND refresh_token = $2`

	if _, err := s.db.Exec(ctx, query, userID, refreshToken); err != nil {
		return err
	}
	return nil
}

// DeleteAllByUserID removes every session row owned by userID. A user with no
// sessions is not an error; only database errors are returned.
func (s *SessionRepo) DeleteAllByUserID(ctx context.Context, userID uuid.UUID) error {
	const query = `--sql
delete from sessions
where user_id = $1`

	if _, err := s.db.Exec(ctx, query, userID); err != nil {
		return err
	}
	return nil
}