all repos

onasty @ 040e38372994521cfeb76c57c76998be53bf6a17

a one-time notes service
36 files changed, 561 insertions(+), 301 deletions(-)
refactor: fix annoyances (#97)

* refactor(user): use model where needed instead of dto

* refactor(hasher): add .Compare method

* refactor: get user email instead of credentials, and check password hash manually

* refactor(usersrv): renaming, correct caching

* test: fix auth test

* fix(models): fix typo

* fix(models): return errors correctly

* refactor(dtos): correct naming

* chore(migrations): set default value for read_at

* refactor(notes): use models in internals instead of dtos

* refactor(dtos): delete unused struct

* refactor(dtos): reorganize how it's organized in files

* refactor(dtos): dont duplicate names in package and struct names

* fixup! refactor(dtos): dont duplicate names in package and struct names

* refactor(noterepo): write model and not dto

* refactor(userrepo): naming

* refactor(mailermq): idk why i added that variable in the first place

* docs(e2e): update doc

* refactor(e2e): naming and remove code duplication

* refactor(e2e): renaming

* fix(noterepo): remove whitespace from name of a field

* refactor(e2e): fix typos, add docs

* refactor(events): i was really into interface implementation checking

* chore(usersrv): fix formatting

* test(ratelimit): add tests

* refactor(ratelimit): keep naming consistent, and update comments

* refactor(e2e): fix typo in file name

* chore(jwtutil): update comments

* test(jwtutil): add tests

* fixup! chore(jwtutil): update comments

* test(hasher): test sha256 implementation

* test(e2e): test ping endpoint

* refactor(httpserver): add http server config

* chore(env): update example

* fix(mailer): update to new httpserver api

* fix(config): fix naming

* fix(e2e): actually apply defaults

* test(jwtutil): refactor

* fix(config): fix typos

* fix(metrics): change to the correct handler
Author: Smirnov Oleksandr ss2316544@gmail.com
Committed by: GitHub noreply@github.com
Committed at: 2025-04-22 00:54:46 +0300
Parent: c2e1526
M .env.example

@@ -1,8 +1,9 @@

APP_ENV=debug APP_URL=http://localhost:8000 -SERVER_PORT=8000 PASSWORD_SALT=onasty NOTE_PASSWORD_SALT=secret + +HTTP_PORT=8000 METRICS_ENABLED=true METRICS_PORT=8001
M cmd/server/main.go

@@ -76,7 +76,7 @@ return err

} userPasswordHasher := hasher.NewSHA256Hasher(cfg.PasswordSalt) - notePasswordHasher := hasher.NewSHA256Hasher(cfg.NotePassowrdSalt) + notePasswordHasher := hasher.NewSHA256Hasher(cfg.NotePasswordSalt) jwtTokenizer := jwtutil.NewJWTUtil(cfg.JwtSigningKey, cfg.JwtAccessTokenTTL) mailermq := mailermq.New(nc)

@@ -115,9 +115,9 @@ rateLimiterConfig,

) // http server - srv := httpserver.NewServer(cfg.ServerPort, handler.Handler()) + srv := httpserver.NewServer(handler.Handler(), httpConfig(cfg.HTTPPort, cfg)) go func() { - slog.Info("starting http server", "port", cfg.ServerPort) + slog.Info("starting http server", "port", cfg.HTTPPort) if err := srv.Start(); !errors.Is(err, http.ErrServerClosed) { slog.Error("failed to start http server", "error", err) }

@@ -125,7 +125,7 @@ }()

// metrics if cfg.MetricsEnabled { - mSrv := httpserver.NewServer(cfg.MetricsPort, metrics.Handler()) + mSrv := httpserver.NewServer(metrics.Handler(), httpConfig(cfg.MetricsPort, cfg)) go func() { slog.Info("starting metrics server", "port", cfg.MetricsPort) if err := mSrv.Start(); !errors.Is(err, http.ErrServerClosed) {

@@ -153,3 +153,12 @@ }

return nil } + +func httpConfig(port string, cfg *config.Config) httpserver.Config { + return httpserver.Config{ + Port: port, + ReadTimeout: cfg.HTTPReadTimeout, + WriteTimeout: cfg.HTTPWriteTimeout, + MaxHeaderSizeMb: cfg.HTTPHeaderMaxSizeMb, + } +}
A e2e/api_test.go

@@ -0,0 +1,17 @@

+package e2e_test + +import "net/http" + +type apiPingResponse struct { + Message string `json:"message"` +} + +func (e *AppTestSuite) TestPing() { + httpResp := e.httpRequest(http.MethodGet, "/api/ping", nil) + + var body apiPingResponse + e.readBodyAndUnjsonify(httpResp.Body, &body) + + e.Equal(http.StatusOK, httpResp.Code) + e.Equal(body.Message, "pong") +}
M e2e/apiv1_auth_test.go

@@ -28,7 +28,7 @@ Password: password,

}), ) - dbUser := e.getUserFromDBByUsername(username) + dbUser := e.getUserByUsername(username) hashedPasswd, err := e.hasher.Hash(password) e.require.NoError(err)

@@ -100,12 +100,12 @@ )

e.Equal(http.StatusCreated, httpResp.Code) - user := e.getLastInsertedUserByEmail(email) + user := e.getLastUserByEmail(email) token := e.getVerificationTokenByUserID(user.ID) httpResp = e.httpRequest(http.MethodGet, "/api/v1/auth/verify/"+token.Token, nil) e.Equal(http.StatusOK, httpResp.Code) - user = e.getLastInsertedUserByEmail(email) + user = e.getLastUserByEmail(email) e.Equal(user.Activated, true) }

@@ -140,7 +140,7 @@ }

func (e *AppTestSuite) TestAuthV1_ResendVerificationEmail_wrong() { email, password := e.uuid()+"@"+e.uuid()+".com", "password" - e.insertUserIntoDB(e.uuid(), email, password, true) + e.insertUser(e.uuid(), email, password, true) tests := []struct { name string

@@ -173,8 +173,7 @@ }))

e.Equal(httpResp.Code, t.expectedCode) - // no email should be sent - // e.Empty(e.mailer.GetLastSentEmailToEmail(t.email)) + // TODO: no email should be sent } }

@@ -182,7 +181,7 @@ func (e *AppTestSuite) TestAuthV1_SignIn() {

email := e.uuid() + "email@email.com" password := "qwerty" - uid := e.insertUserIntoDB("test", email, password, true) + uid := e.insertUser("test", email, password, true) httpResp := e.httpRequest( http.MethodPost,

@@ -196,7 +195,7 @@

var body apiv1AuthSignInResponse e.readBodyAndUnjsonify(httpResp.Body, &body) - session := e.getLastUserSessionByUserID(uid) + session := e.getLastSessionByUserID(uid) parsedToken := e.parseJwtToken(body.AccessToken) e.Equal(http.StatusOK, httpResp.Code)

@@ -207,10 +206,10 @@

func (e *AppTestSuite) TestAuthV1_SignIn_wrong() { password := "password" email := e.uuid() + "@test.com" - e.insertUserIntoDB(e.uuid(), email, "password", true) + e.insertUser(e.uuid(), email, "password", true) unactivatedEmail := e.uuid() + "@test.com" - e.insertUserIntoDB(e.uuid(), unactivatedEmail, password, false) + e.insertUser(e.uuid(), unactivatedEmail, password, false) //exhaustruct:ignore tests := []struct {

@@ -223,7 +222,7 @@ expectMsg bool

expectedMsg string }{ { - name: "unactivated user", + name: "inactivated user", email: unactivatedEmail, password: password, expectedCode: http.StatusBadRequest,

@@ -234,7 +233,7 @@ {

name: "wrong email", email: "wrong@email.com", password: password, - expectedCode: http.StatusUnauthorized, + expectedCode: http.StatusBadRequest, }, { name: "wrong password",

@@ -282,7 +281,7 @@

var body apiv1AuthSignInResponse e.readBodyAndUnjsonify(httpResp.Body, &body) - sessionDB := e.getLastUserSessionByUserID(uid) + sessionDB := e.getLastSessionByUserID(uid) e.Equal(e.parseJwtToken(body.AccessToken).UserID, uid.String()) e.Equal(httpResp.Code, http.StatusOK)

@@ -307,13 +306,13 @@

func (e *AppTestSuite) TestAuthV1_Logout() { uid, toks := e.createAndSingIn(e.uuid()+"@test.com", e.uuid(), "password") - sessionDB := e.getLastUserSessionByUserID(uid) + sessionDB := e.getLastSessionByUserID(uid) e.NotEmpty(sessionDB.RefreshToken) httpResp := e.httpRequest(http.MethodPost, "/api/v1/auth/logout", nil, toks.AccessToken) e.Equal(httpResp.Code, http.StatusNoContent) - sessionDB = e.getLastUserSessionByUserID(uid) + sessionDB = e.getLastSessionByUserID(uid) e.Empty(sessionDB.RefreshToken) }

@@ -340,17 +339,17 @@ )

e.Equal(httpResp.Code, http.StatusOK) - userDB := e.getUserFromDBByUsername(username) - hashedNewPassword, err := e.hasher.Hash(newPassword) - e.require.NoError(err) - - e.Equal(userDB.Password, hashedNewPassword) + userDB := e.getUserByUsername(username) + e.Equal(userDB.Username, username) + e.NoError(e.hasher.Compare(userDB.Password, newPassword)) } +// createAndSingIn creates an activated username, logs them in, +// and returns their userID along with access and refresh tokens. func (e *AppTestSuite) createAndSingIn( email, username, password string, ) (uuid.UUID, apiv1AuthSignInResponse) { - uid := e.insertUserIntoDB(username, email, password, true) + uid := e.insertUser(username, email, password, true) httpResp := e.httpRequest( http.MethodPost, "/api/v1/auth/signin",
M e2e/apiv1_notes_authorized_test.goe2e/apiv1_notes_authorized_test.go

@@ -16,7 +16,7 @@

var body apiv1NoteCreateResponse e.readBodyAndUnjsonify(httpResp.Body, &body) - dbNote := e.getNoteFromDBbySlug(body.Slug) + dbNote := e.getNoteBySlug(body.Slug) dbNoteAuthor := e.getLastNoteAuthorsRecordByAuthorID(uid) e.Equal(http.StatusCreated, httpResp.Code)
M e2e/apiv1_notes_test.go

@@ -46,7 +46,7 @@

_, err := uuid.FromString(body.Slug) e.require.NoError(err) - dbNote := e.getNoteFromDBbySlug(body.Slug) + dbNote := e.getNoteBySlug(body.Slug) e.NotEmpty(dbNote) }, },

@@ -62,7 +62,7 @@

var body apiv1NoteCreateResponse e.readBodyAndUnjsonify(r.Body, &body) - dbNote := e.getNoteFromDBbySlug(inp.Slug) + dbNote := e.getNoteBySlug(inp.Slug) e.NotEmpty(dbNote) }, },

@@ -89,7 +89,7 @@

var body apiv1NoteCreateResponse e.readBodyAndUnjsonify(r.Body, &body) - dbNote := e.getNoteFromDBbySlug(body.Slug) + dbNote := e.getNoteBySlug(body.Slug) e.NotEmpty(dbNote) e.Equal(dbNote.Content, inp.Content)

@@ -134,7 +134,7 @@ e.readBodyAndUnjsonify(httpResp.Body, &body)

e.Equal(content, body.Content) - dbNote := e.getNoteFromDBbySlug(bodyCreated.Slug) + dbNote := e.getNoteBySlug(bodyCreated.Slug) e.Equal(dbNote.Content, "") e.Equal(dbNote.ReadAt.IsZero(), false) }

@@ -173,7 +173,7 @@ e.readBodyAndUnjsonify(httpResp.Body, &body)

e.Equal(content, body.Content) - dbNote := e.getNoteFromDBbySlug(bodyCreated.Slug) + dbNote := e.getNoteBySlug(bodyCreated.Slug) e.Equal(dbNote.Content, "") e.Equal(dbNote.ReadAt.IsZero(), false) }
M e2e/e2e_test.go

@@ -5,7 +5,6 @@ "context"

"fmt" "log/slog" "net/http" - "os" "testing" "time"

@@ -199,18 +198,14 @@ return redis, stop

} func (e *AppTestSuite) getConfig() *config.Config { - return &config.Config{ //nolint:exhaustruct - AppEnv: "testing", - AppURL: "", - ServerPort: "3000", - PasswordSalt: "salty-password", - JwtSigningKey: "jwt-key", - JwtAccessTokenTTL: time.Hour, - JwtRefreshTokenTTL: 24 * time.Hour, - VerificationTokenTTL: 24 * time.Hour, - LogShowLine: os.Getenv("LOG_SHOW_LINE") == "true", - LogFormat: "text", - LogLevel: "debug", - CacheUsersTTL: time.Second, - } + e.T().Setenv("APP_ENV", "test") + e.T().Setenv("APP_URL", "localhost") + e.T().Setenv("PASSWORD_SALT", "salty-password") + e.T().Setenv("NOTE_PASSWORD_SALT", "salty-noted-password") + e.T().Setenv("JWT_SIGNING_KEY", "jwt-key") + e.T().Setenv("LOG_SHOW_LINE", "true") + e.T().Setenv("LOG_FORMAT", "text") + e.T().Setenv("LOG_LEVEL", "debug") + + return config.NewConfig() }
M e2e/e2e_utils_db_test.go

@@ -10,7 +10,8 @@ "github.com/jackc/pgx/v5"

"github.com/olexsmir/onasty/internal/models" ) -func (e *AppTestSuite) getUserFromDBByUsername(username string) models.User { +// getUserByUsername queries user from db by it's username +func (e *AppTestSuite) getUserByUsername(username string) models.User { query, args, err := pgq. Select("id", "username", "email", "password", "created_at", "last_login_at"). From("users").

@@ -26,7 +27,8 @@

return user } -func (e *AppTestSuite) insertUserIntoDB(uname, email, passwd string, activated ...bool) uuid.UUID { +// insertUser inserts user into db +func (e *AppTestSuite) insertUser(uname, email, passwd string, activated ...bool) uuid.UUID { p, err := e.hasher.Hash(passwd) e.require.NoError(err)

@@ -50,7 +52,8 @@

return id } -func (e *AppTestSuite) getLastUserSessionByUserID(uid uuid.UUID) models.Session { +// getLastSessionByUserID gets last inserted [models.Session] for particular user +func (e *AppTestSuite) getLastSessionByUserID(uid uuid.UUID) models.Session { query, args, err := pgq. Select("refresh_token", "expires_at"). From("sessions").

@@ -67,12 +70,14 @@ return models.Session{} //nolint:exhaustruct

} e.require.NoError(err) + session.UserID = uid return session } -func (e *AppTestSuite) getLastInsertedUserByEmail(em string) models.User { +// getLastUserByEmail gets last inserted [models.User] by user's email +func (e *AppTestSuite) getLastUserByEmail(em string) models.User { query, args, err := pgq. - Select("id", "username", "activated", "email", "password"). + Select("id", "username", "activated", "email", "password", "created_at", "last_login_at"). From("users"). Where(pgq.Eq{"email": em}). OrderBy("created_at DESC").

@@ -82,7 +87,7 @@ e.require.NoError(err)

var u models.User err = e.postgresDB.QueryRow(e.ctx, query, args...). - Scan(&u.ID, &u.Username, &u.Activated, &u.Email, &u.Password) + Scan(&u.ID, &u.Username, &u.Activated, &u.Email, &u.Password, &u.CreatedAt, &u.LastLoginAt) if errors.Is(err, pgx.ErrNoRows) { return models.User{} //nolint:exhaustruct }

@@ -91,19 +96,8 @@ e.require.NoError(err)

return u } -type noteModel struct { - ID uuid.UUID - Content string - Slug string - BurnBeforeExpiration bool - Password string - IsRead bool - ReadAt *time.Time - CreatedAt time.Time - ExpiresAt time.Time -} - -func (e *AppTestSuite) getNoteFromDBbySlug(slug string) noteModel { +// getNoteBySlug gets [models.Note] by slug +func (e *AppTestSuite) getNoteBySlug(slug string) models.Note { query, args, err := pgq. Select( "id",

@@ -119,11 +113,11 @@ Where(pgq.Eq{"slug": slug}).

SQL() e.require.NoError(err) - var note noteModel + var note models.Note err = e.postgresDB.QueryRow(e.ctx, query, args...). Scan(&note.ID, &note.Content, &note.Slug, &note.BurnBeforeExpiration, &note.ReadAt, &note.CreatedAt, &note.ExpiresAt) if errors.Is(err, pgx.ErrNoRows) { - return noteModel{} //nolint:exhaustruct + return models.Note{} //nolint:exhaustruct } e.require.NoError(err)
M e2e/e2e_utils_test.go

@@ -61,7 +61,7 @@ e.require.NoError(err)

return u.String() } -// parseJwtToken util func that parses jwt token and returns payload +// parseJwtToken gets payload from the jwt token func (e *AppTestSuite) parseJwtToken(t string) jwtutil.Payload { r, err := e.jwtTokenizer.Parse(t) e.require.NoError(err)
M internal/config/config.go

@@ -8,14 +8,18 @@ "time"

) type Config struct { - AppEnv string - AppURL string - ServerPort string - NatsURL string + AppEnv string + AppURL string + NatsURL string + + HTTPPort string + HTTPWriteTimeout time.Duration + HTTPReadTimeout time.Duration + HTTPHeaderMaxSizeMb int PostgresDSN string PasswordSalt string - NotePassowrdSalt string + NotePasswordSalt string RedisAddr string RedisPassword string

@@ -44,14 +48,18 @@ }

func NewConfig() *Config { return &Config{ - AppEnv: getenvOrDefault("APP_ENV", "debug"), - AppURL: getenvOrDefault("APP_URL", ""), - ServerPort: getenvOrDefault("SERVER_PORT", "3000"), - NatsURL: getenvOrDefault("NATS_URL", ""), + AppEnv: getenvOrDefault("APP_ENV", "debug"), + AppURL: getenvOrDefault("APP_URL", ""), + NatsURL: getenvOrDefault("NATS_URL", ""), + + HTTPPort: getenvOrDefault("HTTP_PORT", "3000"), + HTTPWriteTimeout: mustParseDuration(getenvOrDefault("HTTP_WRITE_TIMEOUT", "10s")), + HTTPReadTimeout: mustParseDuration(getenvOrDefault("HTTP_READ_TIMEOUT", "10s")), + HTTPHeaderMaxSizeMb: mustGetenvOrDefaultInt("HTTP_HEADER_MAX_SIZE_MB", 1), PostgresDSN: getenvOrDefault("POSTGRESQL_DSN", ""), PasswordSalt: getenvOrDefault("PASSWORD_SALT", ""), - NotePassowrdSalt: getenvOrDefault("NOTE_PASSWORD_SALT", ""), + NotePasswordSalt: getenvOrDefault("NOTE_PASSWORD_SALT", ""), RedisAddr: getenvOrDefault("REDIS_ADDR", ""), RedisPassword: getenvOrDefault("REDIS_PASSWORD", ""),
M internal/dtos/note.go

@@ -6,23 +6,19 @@

"github.com/gofrs/uuid/v5" ) -type NoteSlugDTO = string +type NoteSlug = string -type NoteDTO struct { - Content string - Slug string - BurnBeforeExpiration bool - Password string - IsRead bool - ReadAt *time.Time - CreatedAt time.Time - ExpiresAt time.Time +type GetNote struct { + Content string + ReadAt time.Time + CreatedAt time.Time + ExpiresAt time.Time } -type CreateNoteDTO struct { +type CreateNote struct { Content string UserID uuid.UUID - Slug string + Slug NoteSlug BurnBeforeExpiration bool Password string CreatedAt time.Time
D

@@ -1,6 +0,0 @@

-package dtos - -type TokensDTO struct { - Access string - Refresh string -}
M internal/dtos/user.go

@@ -2,34 +2,27 @@ package dtos

import ( "time" - - "github.com/gofrs/uuid/v5" ) -type UserDTO struct { - ID uuid.UUID +type SignUp struct { Username string Email string Password string - Activated bool CreatedAt time.Time LastLoginAt time.Time } -type ResetUserPasswordDTO struct { +type SignIn struct { + Email string + Password string +} + +type ChangeUserPassword struct { CurrentPassword string NewPassword string } -type CreateUserDTO struct { - Username string - Email string - Password string - CreatedAt time.Time - LastLoginAt time.Time -} - -type SignInDTO struct { - Email string - Password string +type Tokens struct { + Access string + Refresh string }
M internal/events/events.go

@@ -11,8 +11,6 @@ natsHeaderErrorCode = "Nats-Service-Error-Code"

natsHeaderErrorMsg = "Nats-Service-Error" ) -var _ error = (*Error)(nil) - type Error struct { Code string Message string
M internal/events/mailermq/mailermq.go

@@ -17,8 +17,6 @@ type MailerMQ struct {

nc *nats.Conn } -const sendMailSubject = "mailer.send" - func New(nc *nats.Conn) *MailerMQ { return &MailerMQ{ nc: nc,

@@ -53,7 +51,7 @@ if err != nil {

return err } - resp, err := m.nc.RequestWithContext(ctx, sendMailSubject, req) + resp, err := m.nc.RequestWithContext(ctx, "mailer.send", req) if err != nil { return err }
M internal/hasher/hasher.go

@@ -1,6 +1,14 @@

package hasher +import "errors" + +var ErrMismatchedHashes = errors.New("hashes are mismatched") + type Hasher interface { // Hash takes a string as input and returns its hash Hash(str string) (string, error) + + // Compare takes two hashes and compares them + // in case of mismatch returns [ErrMismatchedHashes] + Compare(hash, plain string) error }
M internal/hasher/sha256.go

@@ -20,3 +20,15 @@ return "", err

} return hex.EncodeToString(hash.Sum([]byte(h.salt))), nil } + +func (h *SHA256Hasher) Compare(hash, plain string) error { + expected, err := h.Hash(plain) + if err != nil { + return err + } + + if expected != hash { + return ErrMismatchedHashes + } + return nil +}
A internal/hasher/sha256_test.go

@@ -0,0 +1,38 @@

+package hasher + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSHA256Hasher_Hash(t *testing.T) { + hasher := NewSHA256Hasher("salt") + + hashed, err := hasher.Hash("qwerty123") + require.NoError(t, err) + require.NotEmpty(t, hashed) +} + +func TestSHA256Hasher_Compared(t *testing.T) { + hasher := NewSHA256Hasher("salt") + input := "qwerty123" + + t.Run("valid", func(t *testing.T) { + hashed, err := hasher.Hash(input) + require.NoError(t, err) + require.NotEmpty(t, hashed) + + err = hasher.Compare(hashed, input) + require.NoError(t, err) + }) + + t.Run("hashes mismatch", func(t *testing.T) { + hashed, err := hasher.Hash(input + "4") + require.NoError(t, err) + require.NotEmpty(t, hashed) + + err = hasher.Compare(hashed, input) + require.ErrorIs(t, err, ErrMismatchedHashes) + }) +}
M internal/jwtutil/jwtutil.go

@@ -12,16 +12,17 @@

var ErrUnexpectedSigningMethod = errors.New("unexpected signing method") type JWTTokenizer interface { - // AccessToken generates a new access token with the given payload + // AccessToken generates a new access token with the given [Payload]. AccessToken(pl Payload) (string, error) - // RefreshToken generates a new refresh token + // RefreshToken generates a random string of 64 chars. RefreshToken() (string, error) - // Parse parses the token and returns the payload + // Parse parses the token and returns its [Payload]. Parse(token string) (Payload, error) } +// Payload the access token payload type Payload struct { UserID string }
A internal/jwtutil/jwtutil_test.go

@@ -0,0 +1,60 @@

+package jwtutil + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestJWTUtil_AccessToken(t *testing.T) { + jwt := NewJWTUtil("key", time.Hour) + payload := Payload{UserID: "user.123"} + + token, err := jwt.AccessToken(payload) + require.NoError(t, err) + assert.NotEmpty(t, token) +} + +func TestJWTUtil_RefreshToken(t *testing.T) { + jwt := NewJWTUtil("key", time.Hour) + + tok, err := jwt.RefreshToken() + require.NoError(t, err) + assert.Len(t, tok, 64) + + secondTok, err := jwt.RefreshToken() + require.NoError(t, err) + + // tokens should be unique + assert.NotEqual(t, tok, secondTok) +} + +func TestJWTUtil_Parse(t *testing.T) { + jwt := NewJWTUtil("key", time.Hour) + payload := Payload{UserID: "qwerty"} + + token, err := jwt.AccessToken(payload) + require.NoError(t, err) + assert.NotEmpty(t, token) + + parsedPayload, err := jwt.Parse(token) + require.NoError(t, err) + + assert.Equal(t, payload, parsedPayload) +} + +func TestJWTUtil_Parse_expired(t *testing.T) { + ttl := 100 * time.Millisecond + jwt := NewJWTUtil("key", ttl) + payload := Payload{UserID: "qwerty"} + + token, err := jwt.AccessToken(payload) + require.NoError(t, err) + assert.NotEmpty(t, token) + + time.Sleep(ttl) + _, err = jwt.Parse(token) + require.Error(t, err) +}
M internal/models/user.go

@@ -11,13 +11,17 @@

var ( ErrUserEmailIsAlreadyInUse = errors.New("user: email is already in use") ErrUsernameIsAlreadyInUse = errors.New("user: username is already in use") - ErrUserIsAlreeadyVerified = errors.New("user: user is already verified") + ErrUserIsAlreadyVerified = errors.New("user: user is already verified") ErrVerificationTokenNotFound = errors.New("user: verification token not found") ErrUserIsNotActivated = errors.New("user: user is not activated") ErrUserNotFound = errors.New("user: not found") ErrUserWrongCredentials = errors.New("user: wrong credentials") + + ErrUserInvalidEmail = errors.New("user: invalid email") + ErrUserInvalidPassword = errors.New("user: password too short, minimum 6 chars") + ErrUserInvalidUsername = errors.New("user: username is required") ) type User struct {

@@ -33,16 +37,20 @@

func (u User) Validate() error { _, err := mail.ParseAddress(u.Email) if err != nil { - return errors.New("user: invalid email") //nolint:err113 + return ErrUserInvalidEmail } if len(u.Password) < 6 { - return errors.New("user: password too short, minimum 6 chars") //nolint:err113 + return ErrUserInvalidPassword } if len(u.Username) == 0 { - return errors.New("user: username is required") //nolint:err113 + return ErrUserInvalidUsername } return nil } + +func (u User) IsActivated() bool { + return u.Activated +}
M internal/service/notesrv/input.go

@@ -5,7 +5,7 @@

// GetNoteBySlugInput used as input for [GetBySlugAndRemoveIfNeeded] type GetNoteBySlugInput struct { // Slug is a note's slug :) *Required* - Slug dtos.NoteSlugDTO + Slug dtos.NoteSlug // Password is a note's password. // Optional, needed only if note has one.
M internal/service/notesrv/notesrv.go

@@ -17,10 +17,13 @@ type NoteServicer interface {

// Create creates note // if slug is empty it will be generated, otherwise used as is // if userID is empty it means user isn't authorized so it will be used - Create(ctx context.Context, note dtos.CreateNoteDTO, userID uuid.UUID) (dtos.NoteSlugDTO, error) + Create(ctx context.Context, note dtos.CreateNote, userID uuid.UUID) (dtos.NoteSlug, error) // GetBySlugAndRemoveIfNeeded returns note by slug, and removes if if needed - GetBySlugAndRemoveIfNeeded(ctx context.Context, input GetNoteBySlugInput) (dtos.NoteDTO, error) + GetBySlugAndRemoveIfNeeded( + ctx context.Context, + input GetNoteBySlugInput, + ) (dtos.GetNote, error) } var _ NoteServicer = (*NoteSrv)(nil)

@@ -41,9 +44,9 @@ }

func (n *NoteSrv) Create( ctx context.Context, - inp dtos.CreateNoteDTO, + inp dtos.CreateNote, userID uuid.UUID, -) (dtos.NoteSlugDTO, error) { +) (dtos.NoteSlug, error) { slog.DebugContext(ctx, "creating", "inp", inp) if inp.Slug == "" {

@@ -58,7 +61,20 @@ }

inp.Password = hashedPassword } - if err := n.noterepo.Create(ctx, inp); err != nil { + //nolint:exhaustruct // ID - cannot be predicted, and ReadAt will be set on read + note := models.Note{ + Content: inp.Content, + Slug: inp.Slug, + Password: inp.Password, + BurnBeforeExpiration: inp.BurnBeforeExpiration, + CreatedAt: inp.CreatedAt, + ExpiresAt: inp.ExpiresAt, + } + if err := note.Validate(); err != nil { + return "", err + } + + if err := n.noterepo.Create(ctx, note); err != nil { return "", err }

@@ -74,41 +90,43 @@

func (n *NoteSrv) GetBySlugAndRemoveIfNeeded( ctx context.Context, inp GetNoteBySlugInput, -) (dtos.NoteDTO, error) { +) (dtos.GetNote, error) { note, err := n.getNote(ctx, inp) if err != nil { - return dtos.NoteDTO{}, err + return dtos.GetNote{}, err } - m := models.Note{ //nolint:exhaustruct - ExpiresAt: note.ExpiresAt, - BurnBeforeExpiration: note.BurnBeforeExpiration, + if note.IsExpired() { + return dtos.GetNote{}, models.ErrNoteExpired } - if m.IsExpired() { - return dtos.NoteDTO{}, models.ErrNoteExpired + respNote := dtos.GetNote{ + Content: note.Content, + ReadAt: note.ReadAt, + CreatedAt: note.CreatedAt, + ExpiresAt: note.ExpiresAt, } // since not every note should be burn before expiration // we return early if it's not - if m.ShouldBeBurnt() { - return note, nil + if note.ShouldBeBurnt() { + return respNote, nil } - return note, n.noterepo.RemoveBySlug(ctx, inp.Slug, time.Now()) + return respNote, n.noterepo.RemoveBySlug(ctx, inp.Slug, time.Now()) } -func (n *NoteSrv) getNote(ctx context.Context, inp GetNoteBySlugInput) (dtos.NoteDTO, error) { +func (n *NoteSrv) getNote(ctx context.Context, inp GetNoteBySlugInput) (models.Note, error) { if r, err := n.cache.GetNote(ctx, inp.Slug); err == nil { return r, nil } note, err := n.getNoteFromDBasedOnInput(ctx, inp) if err != nil { - return dtos.NoteDTO{}, err + return models.Note{}, err } - if note.ReadAt != nil && !note.ReadAt.IsZero() { + if !note.IsRead() { if err = n.cache.SetNote(ctx, inp.Slug, note); err != nil { slog.ErrorContext(ctx, "notecache", "err", err) }

@@ -120,11 +138,11 @@

func (n *NoteSrv) getNoteFromDBasedOnInput( ctx context.Context, inp GetNoteBySlugInput, -) (dtos.NoteDTO, error) { +) (models.Note, error) { if inp.HasPassword() { hashedPassword, err := n.hasher.Hash(inp.Password) if err != nil { - return dtos.NoteDTO{}, err + return models.Note{}, err } return n.noterepo.GetBySlugAndPassword(ctx, inp.Slug, hashedPassword)
M internal/service/usersrv/usersrv.go

@@ -19,15 +19,15 @@ "github.com/olexsmir/onasty/internal/store/rdb/usercache"

) type UserServicer interface { - SignUp(ctx context.Context, inp dtos.CreateUserDTO) (uuid.UUID, error) - SignIn(ctx context.Context, inp dtos.SignInDTO) (dtos.TokensDTO, error) - RefreshTokens(ctx context.Context, refreshToken string) (dtos.TokensDTO, error) + SignUp(ctx context.Context, inp dtos.SignUp) (uuid.UUID, error) + SignIn(ctx context.Context, inp dtos.SignIn) (dtos.Tokens, error) + RefreshTokens(ctx context.Context, refreshToken string) (dtos.Tokens, error) Logout(ctx context.Context, userID uuid.UUID) error - ChangePassword(ctx context.Context, userID uuid.UUID, inp dtos.ResetUserPasswordDTO) error + ChangePassword(ctx context.Context, userID uuid.UUID, inp dtos.ChangeUserPassword) error Verify(ctx context.Context, verificationKey string) error - ResendVerificationEmail(ctx context.Context, credentials dtos.SignInDTO) error + ResendVerificationEmail(ctx context.Context, credentials dtos.SignIn) error ParseJWTToken(token string) (jwtutil.Payload, error)

@@ -73,66 +73,78 @@ verificationTokenTTL: verificationTokenTTL,

} } -func (u *UserSrv) SignUp(ctx context.Context, inp dtos.CreateUserDTO) (uuid.UUID, error) { +func (u *UserSrv) SignUp(ctx context.Context, inp dtos.SignUp) (uuid.UUID, error) { hashedPassword, err := u.hasher.Hash(inp.Password) if err != nil { return uuid.UUID{}, err } - uid, err := u.userstore.Create(ctx, dtos.CreateUserDTO{ + user := models.User{ + ID: uuid.Nil, // nil, because it does not get used here Username: inp.Username, Email: inp.Email, + Activated: false, Password: hashedPassword, CreatedAt: inp.CreatedAt, LastLoginAt: inp.LastLoginAt, - }) + } + if err = user.Validate(); err != nil { + return uuid.Nil, err + } + + userID, err := u.userstore.Create(ctx, user) if err != nil { return uuid.Nil, err } - vtok := uuid.Must(uuid.NewV4()).String() - if err := u.vertokrepo.Create(ctx, vtok, uid, time.Now(), time.Now().Add(u.verificationTokenTTL)); err != nil { + verificationToken := uuid.Must(uuid.NewV4()).String() + if err := u.vertokrepo.Create( + ctx, + verificationToken, + userID, + time.Now(), + time.Now().Add(u.verificationTokenTTL), + ); err != nil { return uuid.Nil, err } if err := u.mailermq.SendVerificationEmail(ctx, mailermq.SendVerificationEmailRequest{ Receiver: inp.Email, - Token: vtok, + Token: verificationToken, }); err != nil { return uuid.Nil, err } - return uid, nil + return userID, nil } -func (u *UserSrv) SignIn(ctx context.Context, inp dtos.SignInDTO) (dtos.TokensDTO, error) { - hashedPassword, err := u.hasher.Hash(inp.Password) +func (u *UserSrv) SignIn(ctx context.Context, inp dtos.SignIn) (dtos.Tokens, error) { + user, err := u.userstore.GetByEmail(ctx, inp.Email) if err != nil { - return dtos.TokensDTO{}, err + return dtos.Tokens{}, err } - user, err := u.userstore.GetUserByCredentials(ctx, inp.Email, hashedPassword) - if err != nil { - if errors.Is(err, models.ErrUserNotFound) { - return dtos.TokensDTO{}, models.ErrUserWrongCredentials + if err = u.hasher.Compare(user.Password, inp.Password); err != nil { + if errors.Is(err, hasher.ErrMismatchedHashes) { + return dtos.Tokens{}, models.ErrUserWrongCredentials } - return dtos.TokensDTO{}, err + return dtos.Tokens{}, err } - if !user.Activated { - return dtos.TokensDTO{}, models.ErrUserIsNotActivated + if !user.IsActivated() { + return dtos.Tokens{}, models.ErrUserIsNotActivated } - tokens, err := u.getTokens(user.ID) + tokens, err := u.createTokens(user.ID) if err != nil { - return dtos.TokensDTO{}, err + return dtos.Tokens{}, err } if err := u.sessionstore.Set(ctx, user.ID, tokens.Refresh, time.Now().Add(u.refreshTokenTTL)); err != nil { - return dtos.TokensDTO{}, err + return dtos.Tokens{}, err } - return dtos.TokensDTO{ + return dtos.Tokens{ Access: tokens.Access, Refresh: tokens.Refresh, }, nil

@@ -142,22 +154,22 @@ func (u *UserSrv) Logout(ctx context.Context, userID uuid.UUID) error {

return u.sessionstore.Delete(ctx, userID) } -func (u *UserSrv) RefreshTokens(ctx context.Context, rtoken string) (dtos.TokensDTO, error) { +func (u *UserSrv) RefreshTokens(ctx context.Context, rtoken string) (dtos.Tokens, error) { userID, err := u.sessionstore.GetUserIDByRefreshToken(ctx, rtoken) if err != nil { - return dtos.TokensDTO{}, err + return dtos.Tokens{}, err } - tokens, err := u.getTokens(userID) + tokens, err := u.createTokens(userID) if err != nil { - return dtos.TokensDTO{}, err + return dtos.Tokens{}, err } if err := u.sessionstore.Update(ctx, userID, rtoken, tokens.Refresh); err != nil { - return dtos.TokensDTO{}, err + return dtos.Tokens{}, err } - return dtos.TokensDTO{ + return dtos.Tokens{ Access: tokens.Access, Refresh: tokens.Refresh, }, nil

@@ -166,8 +178,10 @@

func (u *UserSrv) ChangePassword( ctx context.Context, userID uuid.UUID, - inp dtos.ResetUserPasswordDTO, + inp dtos.ChangeUserPassword, ) error { + // TODO: compare current password with providede, and assert on mismatch + oldPass, err := u.hasher.Hash(inp.CurrentPassword) if err != nil { return err

@@ -194,22 +208,18 @@

return u.userstore.MarkUserAsActivated(ctx, uid) } -func (u *UserSrv) ResendVerificationEmail(ctx context.Context, inp dtos.SignInDTO) error { - hashedPassword, err := u.hasher.Hash(inp.Password) +func (u *UserSrv) ResendVerificationEmail(ctx context.Context, inp dtos.SignIn) error { + user, err := u.userstore.GetByEmail(ctx, inp.Email) if err != nil { return err } - user, err := u.userstore.GetUserByCredentials(ctx, inp.Email, hashedPassword) - if err != nil { - if errors.Is(err, models.ErrUserNotFound) { - return models.ErrUserWrongCredentials - } - return err + if err = u.hasher.Compare(user.Password, inp.Password); err != nil { + return models.ErrUserWrongCredentials } if user.Activated { - return models.ErrUserIsAlreeadyVerified + return models.ErrUserIsAlreadyVerified } token, err := u.vertokrepo.GetTokenOrUpdateTokenByUserID(

@@ -236,11 +246,12 @@ return u.jwtTokenizer.Parse(token)

} func (u UserSrv) CheckIfUserExists(ctx context.Context, id uuid.UUID) (bool, error) { - if r, err := u.cache.GetIsExists(ctx, id.String()); err == nil { + r, err := u.cache.GetIsExists(ctx, id.String()) + if err == nil { return r, nil - } else { //nolint:revive - slog.ErrorContext(ctx, "usercache", "err", err) } + + slog.ErrorContext(ctx, "usercache", "err", err) isExists, err := u.userstore.CheckIfUserExists(ctx, id) if err != nil {

@@ -248,43 +259,44 @@ return false, err

} if err := u.cache.SetIsExists(ctx, id.String(), isExists); err != nil { - slog.Error("usercache", "err", err) + slog.ErrorContext(ctx, "usercache", "err", err) } return isExists, nil } -func (u UserSrv) CheckIfUserIsActivated(ctx context.Context, userID uuid.UUID) (bool, error) { - if r, err := u.cache.GetIsActivated(ctx, userID.String()); err == nil { +func (u *UserSrv) CheckIfUserIsActivated(ctx context.Context, userID uuid.UUID) (bool, error) { + r, err := u.cache.GetIsActivated(ctx, userID.String()) + if err == nil { return r, nil - } else { //nolint:revive - slog.ErrorContext(ctx, "usercache", "err", err) } - isActivated, err := u.userstore.CheckIfUserExists(ctx, userID) + slog.ErrorContext(ctx, "usercache", "err", err) + + isActivated, err := u.userstore.CheckIfUserIsActivated(ctx, userID) if err != nil { return false, err } if err := u.cache.SetIsActivated(ctx, userID.String(), isActivated); err != nil { - slog.Error("usercache", "err", err) + slog.ErrorContext(ctx, "usercache", "err", err) } return isActivated, nil } -func (u UserSrv) getTokens(userID uuid.UUID) (dtos.TokensDTO, error) { +func (u UserSrv) createTokens(userID uuid.UUID) (dtos.Tokens, error) { accessToken, err := u.jwtTokenizer.AccessToken(jwtutil.Payload{UserID: userID.String()}) if err != nil { - return dtos.TokensDTO{}, err + return dtos.Tokens{}, err } refreshToken, err := u.jwtTokenizer.RefreshToken() if err != nil { - return dtos.TokensDTO{}, err + return dtos.Tokens{}, err } - return dtos.TokensDTO{ + return dtos.Tokens{ Access: accessToken, Refresh: refreshToken, }, err
M internal/store/psql/noterepo/noterepo.go

@@ -15,11 +15,11 @@ )

type NoteStorer interface { // Create creates a note. - Create(ctx context.Context, inp dtos.CreateNoteDTO) error + Create(ctx context.Context, note models.Note) error // GetBySlug gets a note by slug. // Returns [models.ErrNoteNotFound] if note is not found. - GetBySlug(ctx context.Context, slug dtos.NoteSlugDTO) (dtos.NoteDTO, error) + GetBySlug(ctx context.Context, slug dtos.NoteSlug) (models.Note, error) // GetBySlugAndPassword gets a note by slug and password. // the "password" should be hashed.

@@ -27,17 +27,17 @@ //

// Returns [models.ErrNoteNotFound] if note is not found. GetBySlugAndPassword( ctx context.Context, - slug dtos.NoteSlugDTO, + slug dtos.NoteSlug, password string, - ) (dtos.NoteDTO, error) + ) (models.Note, error) // RemoveBySlug marks note as read, deletes it's content, and keeps meta data // Returns [models.ErrNoteNotFound] if note is not found. - RemoveBySlug(ctx context.Context, slug dtos.NoteSlugDTO, readAt time.Time) error + RemoveBySlug(ctx context.Context, slug dtos.NoteSlug, readAt time.Time) error // SetAuthorIDBySlug assigns author to note by slug. // Returns [models.ErrNoteNotFound] if note is not found. - SetAuthorIDBySlug(ctx context.Context, slug dtos.NoteSlugDTO, authorID uuid.UUID) error + SetAuthorIDBySlug(ctx context.Context, slug dtos.NoteSlug, authorID uuid.UUID) error } var _ NoteStorer = (*NoteRepo)(nil)

@@ -50,10 +50,10 @@ func New(db *psqlutil.DB) *NoteRepo {

return &NoteRepo{db} } -func (s *NoteRepo) Create(ctx context.Context, inp dtos.CreateNoteDTO) error { +func (s *NoteRepo) Create(ctx context.Context, inp models.Note) error { query, args, err := pgq. Insert("notes"). - Columns("content", "slug", "password", "burn_before_expiration ", "created_at", "expires_at"). + Columns("content", "slug", "password", "burn_before_expiration", "created_at", "expires_at"). Values(inp.Content, inp.Slug, inp.Password, inp.BurnBeforeExpiration, inp.CreatedAt, inp.ExpiresAt). SQL() if err != nil {

@@ -68,7 +68,7 @@

return err } -func (s *NoteRepo) GetBySlug(ctx context.Context, slug dtos.NoteSlugDTO) (dtos.NoteDTO, error) { +func (s *NoteRepo) GetBySlug(ctx context.Context, slug dtos.NoteSlug) (models.Note, error) { query, args, err := pgq. Select("content", "slug", "burn_before_expiration", "read_at", "created_at", "expires_at"). From("notes").

@@ -76,15 +76,15 @@ Where("(password is null or password = '')").

Where(pgq.Eq{"slug": slug}). SQL() if err != nil { - return dtos.NoteDTO{}, err + return models.Note{}, err } - var note dtos.NoteDTO + var note models.Note err = s.db.QueryRow(ctx, query, args...). Scan(&note.Content, &note.Slug, &note.BurnBeforeExpiration, &note.ReadAt, &note.CreatedAt, &note.ExpiresAt) if errors.Is(err, pgx.ErrNoRows) { - return dtos.NoteDTO{}, models.ErrNoteNotFound + return models.Note{}, models.ErrNoteNotFound } return note, err

@@ -92,9 +92,9 @@ }

func (s *NoteRepo) GetBySlugAndPassword( ctx context.Context, - slug dtos.NoteSlugDTO, + slug dtos.NoteSlug, passwd string, -) (dtos.NoteDTO, error) { +) (models.Note, error) { query, args, err := pgq. Select("content", "slug", "burn_before_expiration", "read_at", "created_at", "expires_at"). From("notes").

@@ -104,15 +104,15 @@ "password": passwd,

}). SQL() if err != nil { - return dtos.NoteDTO{}, err + return models.Note{}, err } - var note dtos.NoteDTO + var note models.Note err = s.db.QueryRow(ctx, query, args...). Scan(&note.Content, &note.Slug, &note.BurnBeforeExpiration, &note.ReadAt, &note.CreatedAt, &note.ExpiresAt) if errors.Is(err, pgx.ErrNoRows) { - return dtos.NoteDTO{}, models.ErrNoteNotFound + return models.Note{}, models.ErrNoteNotFound } return note, err

@@ -120,7 +120,7 @@ }

func (s *NoteRepo) RemoveBySlug( ctx context.Context, - slug dtos.NoteSlugDTO, + slug dtos.NoteSlug, readAt time.Time, ) error { query, args, err := pgq.

@@ -129,7 +129,7 @@ Set("content", "").

Set("read_at", readAt). Where(pgq.Eq{ "slug": slug, - "read_at": nil, + "read_at": time.Time{}, // check if time is null }). SQL() if err != nil {

@@ -146,7 +146,7 @@ }

func (s *NoteRepo) SetAuthorIDBySlug( ctx context.Context, - slug dtos.NoteSlugDTO, + slug dtos.NoteSlug, authorID uuid.UUID, ) error { tx, err := s.db.Begin(ctx)
M internal/store/psql/userepo/userepo.go

@@ -7,17 +7,16 @@

"github.com/gofrs/uuid/v5" "github.com/henvic/pgq" "github.com/jackc/pgx/v5" - "github.com/olexsmir/onasty/internal/dtos" "github.com/olexsmir/onasty/internal/models" "github.com/olexsmir/onasty/internal/store/psqlutil" ) type UserStorer interface { - Create(ctx context.Context, inp dtos.CreateUserDTO) (uuid.UUID, error) + Create(ctx context.Context, inp models.User) (uuid.UUID, error) // GetUserByCredentials returns user by email and password // the password should be hashed - GetUserByCredentials(ctx context.Context, email, password string) (dtos.UserDTO, error) + GetByEmail(ctx context.Context, email string) (models.User, error) GetUserIDByEmail(ctx context.Context, email string) (uuid.UUID, error) MarkUserAsActivated(ctx context.Context, id uuid.UUID) error

@@ -46,11 +45,11 @@ db: db,

} } -func (r *UserRepo) Create(ctx context.Context, inp dtos.CreateUserDTO) (uuid.UUID, error) { +func (r *UserRepo) Create(ctx context.Context, inp models.User) (uuid.UUID, error) { query, args, err := pgq. Insert("users"). - Columns("username", "email", "password", "created_at", "last_login_at"). - Values(inp.Username, inp.Email, inp.Password, inp.CreatedAt, inp.LastLoginAt). + Columns("username", "email", "password", "activated", "created_at", "last_login_at"). + Values(inp.Username, inp.Email, inp.Password, inp.Activated, inp.CreatedAt, inp.LastLoginAt). Returning("id"). SQL() if err != nil {

@@ -72,27 +71,24 @@

return id, err } -func (r *UserRepo) GetUserByCredentials( +func (r *UserRepo) GetByEmail( ctx context.Context, - email, password string, -) (dtos.UserDTO, error) { + email string, +) (models.User, error) { query, args, err := pgq. Select("id", "username", "email", "password", "activated", "created_at", "last_login_at"). From("users"). - Where(pgq.Eq{ - "email": email, - "password": password, - }). + Where(pgq.Eq{"email": email}). SQL() if err != nil { - return dtos.UserDTO{}, err + return models.User{}, err } - var user dtos.UserDTO + var user models.User err = r.db.QueryRow(ctx, query, args...). Scan(&user.ID, &user.Username, &user.Email, &user.Password, &user.Activated, &user.CreatedAt, &user.LastLoginAt) if errors.Is(err, pgx.ErrNoRows) { - return dtos.UserDTO{}, models.ErrUserNotFound + return models.User{}, models.ErrUserNotFound } return user, err
M internal/store/psql/vertokrepo/vertokrepo.go

@@ -82,7 +82,7 @@ return uuid.Nil, err

} if isUsed { - return uuid.Nil, models.ErrUserIsAlreeadyVerified + return uuid.Nil, models.ErrUserIsAlreadyVerified } query := `--sql
M internal/store/rdb/notecache/notecache.go

@@ -7,13 +7,13 @@ "encoding/gob"

"strings" "time" - "github.com/olexsmir/onasty/internal/dtos" + "github.com/olexsmir/onasty/internal/models" "github.com/olexsmir/onasty/internal/store/rdb" ) type NoteCacher interface { - SetNote(ctx context.Context, slug string, note dtos.NoteDTO) error - GetNote(ctx context.Context, slug string) (dtos.NoteDTO, error) + SetNote(ctx context.Context, slug string, note models.Note) error + GetNote(ctx context.Context, slug string) (models.Note, error) } type NoteCache struct {

@@ -28,7 +28,7 @@ ttl: ttl,

} } -func (n *NoteCache) SetNote(ctx context.Context, slug string, note dtos.NoteDTO) error { +func (n *NoteCache) SetNote(ctx context.Context, slug string, note models.Note) error { var buf bytes.Buffer if err := gob.NewEncoder(&buf).Encode(note); err != nil { return err

@@ -38,15 +38,15 @@ _, err := n.rdb.Set(ctx, getKey(slug), buf.Bytes(), n.ttl).Result()

return err } -func (n *NoteCache) GetNote(ctx context.Context, slug string) (dtos.NoteDTO, error) { +func (n *NoteCache) GetNote(ctx context.Context, slug string) (models.Note, error) { val, err := n.rdb.Get(ctx, getKey(slug)).Bytes() if err != nil { - return dtos.NoteDTO{}, err + return models.Note{}, err } - var note dtos.NoteDTO + var note models.Note if err = gob.NewDecoder(bytes.NewReader(val)).Decode(&note); err != nil { - return dtos.NoteDTO{}, err + return models.Note{}, err } return note, err
M internal/transport/http/apiv1/auth.go

@@ -6,7 +6,6 @@ "time"

"github.com/gin-gonic/gin" "github.com/olexsmir/onasty/internal/dtos" - "github.com/olexsmir/onasty/internal/models" ) type signUpRequest struct {

@@ -22,25 +21,12 @@ newError(c, http.StatusBadRequest, "invalid request")

return } - user := models.User{ //nolint:exhaustruct + if _, err := a.usersrv.SignUp(c.Request.Context(), dtos.SignUp{ Username: req.Username, Email: req.Email, Password: req.Password, CreatedAt: time.Now(), LastLoginAt: time.Now(), - } - if err := user.Validate(); err != nil { - // TODO: find a way to return all errors at once - newErrorStatus(c, http.StatusBadRequest, err.Error()) - return - } - - if _, err := a.usersrv.SignUp(c.Request.Context(), dtos.CreateUserDTO{ - Username: user.Username, - Email: user.Email, - Password: user.Password, - CreatedAt: user.CreatedAt, - LastLoginAt: user.LastLoginAt, }); err != nil { errorResponse(c, err) return

@@ -66,7 +52,7 @@ newError(c, http.StatusBadRequest, "invalid request")

return } - toks, err := a.usersrv.SignIn(c.Request.Context(), dtos.SignInDTO{ + toks, err := a.usersrv.SignIn(c.Request.Context(), dtos.SignIn{ Email: req.Email, Password: req.Password, })

@@ -120,10 +106,12 @@ newError(c, http.StatusBadRequest, "invalid request")

return } - if err := a.usersrv.ResendVerificationEmail(c.Request.Context(), dtos.SignInDTO{ - Email: req.Email, - Password: req.Password, - }); err != nil { + if err := a.usersrv.ResendVerificationEmail( + c.Request.Context(), + dtos.SignIn{ + Email: req.Email, + Password: req.Password, + }); err != nil { errorResponse(c, err) return }

@@ -155,7 +143,7 @@

if err := a.usersrv.ChangePassword( c.Request.Context(), a.getUserID(c), - dtos.ResetUserPasswordDTO{ + dtos.ChangeUserPassword{ CurrentPassword: req.CurrentPassword, NewPassword: req.NewPassword, }); err != nil {
M internal/transport/http/apiv1/note.go

@@ -45,7 +45,7 @@ newErrorStatus(c, http.StatusBadRequest, err.Error())

return } - slug, err := a.notesrv.Create(c.Request.Context(), dtos.CreateNoteDTO{ + slug, err := a.notesrv.Create(c.Request.Context(), dtos.CreateNote{ Content: note.Content, UserID: a.getUserID(c), Slug: note.Slug,

@@ -67,10 +67,10 @@ Password string `json:"password,omitempty"`

} type getNoteBySlugResponse struct { - Content string `json:"content,omitempty"` - ReadAt *time.Time `json:"read_at,omitempty"` - CratedAt time.Time `json:"crated_at"` - ExpiresAt time.Time `json:"expires_at"` + Content string `json:"content,omitempty"` + ReadAt time.Time `json:"read_at"` + CratedAt time.Time `json:"crated_at"` + ExpiresAt time.Time `json:"expires_at"` } func (a *APIV1) getNoteBySlugHandler(c *gin.Context) {

@@ -80,11 +80,10 @@ newError(c, http.StatusBadRequest, "invalid request")

return } - slug := c.Param("slug") note, err := a.notesrv.GetBySlugAndRemoveIfNeeded( c.Request.Context(), notesrv.GetNoteBySlugInput{ - Slug: slug, + Slug: c.Param("slug"), Password: req.Password, }, )

@@ -94,7 +93,7 @@ return

} status := http.StatusOK - if note.ReadAt != nil && !note.ReadAt.IsZero() { + if !note.ReadAt.IsZero() { status = http.StatusNotFound }
M internal/transport/http/apiv1/response.go

@@ -18,10 +18,14 @@

func errorResponse(c *gin.Context, err error) { if errors.Is(err, models.ErrUserEmailIsAlreadyInUse) || errors.Is(err, models.ErrUsernameIsAlreadyInUse) || - errors.Is(err, models.ErrNoteContentIsEmpty) || - errors.Is(err, models.ErrNoteSlugIsAlreadyInUse) || + errors.Is(err, models.ErrUserIsAlreadyVerified) || errors.Is(err, models.ErrUserIsNotActivated) || - errors.Is(err, models.ErrUserIsAlreeadyVerified) { + errors.Is(err, models.ErrUserInvalidEmail) || + errors.Is(err, models.ErrUserInvalidPassword) || + errors.Is(err, models.ErrUserInvalidUsername) || + // notes + errors.Is(err, models.ErrNoteContentIsEmpty) || + errors.Is(err, models.ErrNoteSlugIsAlreadyInUse) { newError(c, http.StatusBadRequest, err.Error()) return }
M internal/transport/http/httpserver/httpserver.go

@@ -10,15 +10,28 @@ type Server struct {

http *http.Server } -func NewServer(port string, handler http.Handler) *Server { - // TODO: add those settings to the config module +type Config struct { + // Port http server port + Port string + + // ReadTimeout read timeout + ReadTimeout time.Duration + + // WriteTimeout write timeout + WriteTimeout time.Duration + + // MaxHeaderSizeMb max size of headers in megabytes + MaxHeaderSizeMb int +} + +func NewServer(handler http.Handler, cfg Config) *Server { return &Server{ http: &http.Server{ - Addr: ":" + port, + Addr: ":" + cfg.Port, Handler: handler, - ReadTimeout: 10 * time.Second, - WriteTimeout: 10 * time.Second, - MaxHeaderBytes: 1 << 20, // 1mb + ReadTimeout: cfg.ReadTimeout, + WriteTimeout: cfg.WriteTimeout, + MaxHeaderBytes: cfg.MaxHeaderSizeMb << 20, }, } }
M internal/transport/http/ratelimit/ratelimit.go

@@ -43,8 +43,8 @@ ttl: ttl,

} } -// Retrieve and return the rate limiter for the current visitor if it -// already exists. Otherwise create a new rate limiter and add it to +// getVisitor Retrieve and return the rate limiter for the current visitor +// if it already exists. Otherwise create a new rate limiter and add it to // the visitors map, using the IP address as the key. func (r *rateLimiter) getVisitor(ip visitorIP) *rate.Limiter { r.mu.RLock()

@@ -71,19 +71,24 @@

return v.limiter } -// Every minute check the map for visitors that haven't been seen for -// more than 3 minutes and delete the entries. +// cleanUpVisitors checks the map of visitors that haven't been seed +// for more than [Config].TTL and delete those entries func (r *rateLimiter) cleanupVisitors() { + r.mu.Lock() + defer r.mu.Unlock() + + for ip, v := range r.visitors { + if time.Since(v.lastSeen) > r.ttl { + delete(r.visitors, ip) + } + } +} + +// cleanupVisitorsLoop runs [rateLimiter.cleanupVisitors] every minute +func (r *rateLimiter) cleanupVisitorsLoop() { for { time.Sleep(time.Minute) - - r.mu.Lock() - for ip, v := range r.visitors { - if time.Since(v.lastSeen) > r.ttl { - delete(r.visitors, ip) - } - } - r.mu.Unlock() + r.cleanupVisitors() } }

@@ -101,7 +106,7 @@

// MiddlewareWithConfig returns a new rate limiting middleware with the given config func MiddlewareWithConfig(c Config) gin.HandlerFunc { lmt := newLimiter(c.RPS, c.Burst, c.TTL) - go lmt.cleanupVisitors() + go lmt.cleanupVisitorsLoop() return func(c *gin.Context) { visitor := lmt.getVisitor(visitorIP(c.ClientIP()))
A internal/transport/http/ratelimit/ratelimit_test.go

@@ -0,0 +1,90 @@

+package ratelimit + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +func TestRateLimiter_getVisitor(t *testing.T) { + limiter := newLimiter(10, 20, time.Second) + ip := visitorIP("127.0.0.1") + + visitor := limiter.getVisitor(ip) + assert.NotNil(t, visitor) + + visitorAgain := limiter.getVisitor(ip) + assert.Equal(t, visitor, visitorAgain) + + assert.Len(t, limiter.visitors, 1) +} + +// TODO: rewrite to use "testing/synctest" when it gets merged +func TestRateLimiter_cleanupVisitors(t *testing.T) { + limiter := newLimiter(10, 20, time.Second/2) + limiter.getVisitor("192.168.9.1") + assert.Len(t, limiter.visitors, 1) + + time.Sleep(time.Second) + limiter.cleanupVisitors() + assert.Empty(t, limiter.visitors) +} + +func TestMiddleware(t *testing.T) { + gin.SetMode(gin.TestMode) + tests := map[string]struct { + config Config + requests int + expectedCode int + }{ + "allows requests with in limit": { + config: Config{ + RPS: 2, + Burst: 2, + TTL: time.Minute, + }, + requests: 1, + expectedCode: http.StatusOK, + }, + "blocks requests over limit": { + config: Config{ + RPS: 1, + Burst: 1, + TTL: time.Minute, + }, + requests: 2, + expectedCode: http.StatusTooManyRequests, + }, + "allows burst requests": { + config: Config{ + RPS: 1, + Burst: 3, + TTL: time.Minute, + }, + requests: 3, + expectedCode: http.StatusOK, + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + handler := MiddlewareWithConfig(tt.config) + var lastCode int + + for range tt.requests { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + handler(c) + lastCode = w.Code + } + + assert.Equal(t, tt.expectedCode, lastCode) + }) + } +}
M mailer/main.go

@@ -9,6 +9,7 @@ "os"

"os/signal" "strings" "syscall" + "time" "github.com/nats-io/nats.go" "github.com/nats-io/nats.go/micro"

@@ -62,7 +63,12 @@ return err

} if cfg.MetricsEnabled { - srv := httpserver.NewServer(cfg.MetricsPort, MetricsHandler()) + srv := httpserver.NewServer(MetricsHandler(), httpserver.Config{ + Port: cfg.MetricsPort, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + MaxHeaderSizeMb: 1, + }) go func() { slog.Info("starting metrics server", "port", cfg.MetricsPort) if err := srv.Start(); !errors.Is(err, http.ErrServerClosed) {
M migrations/20250401121105_notes_add_read.up.sql

@@ -1,2 +1,2 @@

ALTER TABLE notes - ADD COLUMN "read_at" timestamptz; + ADD COLUMN "read_at" timestamptz NOT NULL DEFAULT '0001-01-01 00:00:00';