onasty/internal/store/psql/sessionrepo/sessionrepo.go(view raw)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
package sessionrepo
import (
"context"
"errors"
"time"
"github.com/gofrs/uuid/v5"
"github.com/henvic/pgq"
"github.com/jackc/pgx/v5"
"github.com/olexsmir/onasty/internal/models"
"github.com/olexsmir/onasty/internal/store/psqlutil"
)
// SessionStorer persists refresh-token sessions that belong to users.
type SessionStorer interface {
	// Set stores a session for userID with the given refresh token and expiry time.
	Set(ctx context.Context, userID uuid.UUID, refreshToken string, expiresAt time.Time) error

	// GetUserIDByRefreshToken returns the ID of the user that owns refreshToken.
	GetUserIDByRefreshToken(ctx context.Context, refreshToken string) (uuid.UUID, error)

	// Update replaces refreshToken with newRefreshToken on userID's session.
	Update(ctx context.Context, userID uuid.UUID, refreshToken string, newRefreshToken string) error

	// Delete removes the session identified by the (userID, refreshToken) pair.
	Delete(ctx context.Context, userID uuid.UUID, refreshToken string) error

	// DeleteAllByUserID removes every session belonging to userID.
	DeleteAllByUserID(ctx context.Context, userID uuid.UUID) error
}
// SessionRepo is the SessionStorer implementation backed by a psqlutil.DB;
// every method sends its SQL through db.
type SessionRepo struct {
	db *psqlutil.DB
}

// Compile-time guarantee that *SessionRepo implements SessionStorer.
var _ SessionStorer = (*SessionRepo)(nil)
// New builds a SessionRepo that issues all of its queries against db.
func New(db *psqlutil.DB) *SessionRepo {
	repo := new(SessionRepo)
	repo.db = db
	return repo
}
// Set inserts a new row into sessions holding userID, refreshToken and
// expiresAt. Errors from building the statement or from executing it are
// returned as-is.
func (s *SessionRepo) Set(
	ctx context.Context,
	userID uuid.UUID,
	refreshToken string,
	expiresAt time.Time,
) error {
	insert := pgq.Insert("sessions").
		Columns("user_id", "refresh_token", "expires_at").
		Values(userID, refreshToken, expiresAt)

	stmt, args, err := insert.SQL()
	if err != nil {
		return err
	}

	if _, err := s.db.Exec(ctx, stmt, args...); err != nil {
		return err
	}
	return nil
}
// Update replaces refreshToken with newRefreshToken on the session row that
// matches both userID and refreshToken.
//
// It returns the database error unchanged when the statement fails, and
// models.ErrSessionNotFound when the statement succeeds but does not touch
// exactly one row.
//
// NOTE(review): expires_at is not checked. The expiry predicate inside the
// query is commented out, and as written (`expires_at < now()`) it would only
// match sessions that have already expired. Confirm whether expiry is
// enforced elsewhere.
func (s *SessionRepo) Update(
	ctx context.Context,
	userID uuid.UUID,
	refreshToken string,
	newRefreshToken string,
) error {
	query := `--sql
update sessions
set refresh_token = $1
where
user_id = $2
and refresh_token = $3
-- and expires_at < now()
`
	res, err := s.db.Exec(ctx, query, newRefreshToken, userID, refreshToken)
	if err != nil {
		// Check the error before the command tag: on failure the tag carries no
		// meaningful row count, and inspecting it first would turn every
		// database failure into ErrSessionNotFound.
		return err
	}
	if res.RowsAffected() != 1 {
		return models.ErrSessionNotFound
	}
	return nil
}
// GetUserIDByRefreshToken looks up the user_id of the session row whose
// refresh_token equals refreshToken.
//
// When no row matches it returns models.ErrUserNotFound; any other scan or
// query error is returned unchanged.
//
// NOTE(review): the lookup does not filter on expires_at — confirm callers
// reject expired sessions.
func (s *SessionRepo) GetUserIDByRefreshToken(
	ctx context.Context,
	refreshToken string,
) (uuid.UUID, error) {
	sel := pgq.Select("user_id").
		From("sessions").
		Where(pgq.Eq{"refresh_token": refreshToken})

	stmt, args, err := sel.SQL()
	if err != nil {
		return uuid.UUID{}, err
	}

	var id uuid.UUID
	scanErr := s.db.QueryRow(ctx, stmt, args...).Scan(&id)
	if errors.Is(scanErr, pgx.ErrNoRows) {
		return uuid.UUID{}, models.ErrUserNotFound
	}
	return id, scanErr
}
// Delete removes the session row matching both userID and refreshToken.
// The affected-row count is discarded, so deleting a session that does not
// exist is not reported as an error.
func (s *SessionRepo) Delete(ctx context.Context, userID uuid.UUID, refreshToken string) error {
	const deleteQuery = `--sql
DELETE FROM sessions
WHERE user_id = $1
AND refresh_token = $2`

	if _, err := s.db.Exec(ctx, deleteQuery, userID, refreshToken); err != nil {
		return err
	}
	return nil
}
// DeleteAllByUserID removes every session row whose user_id equals userID.
// Matching no rows is not treated as an error.
func (s *SessionRepo) DeleteAllByUserID(ctx context.Context, userID uuid.UUID) error {
	const deleteAllQuery = `--sql
delete from sessions
where user_id = $1`

	if _, err := s.db.Exec(ctx, deleteAllQuery, userID); err != nil {
		return err
	}
	return nil
}
|