@@ -1,8 +1,13 @@
APP_ENV="debug" SERVER_PORT=3000 +PASSWORD_SALT="onasty" LOG_LEVEL="debug" LOG_FORMAT="text" + +JWT_SIGNING_KEY="supersecret" +JWT_ACCESS_TOKEN_TTL="30m" +JWT_REFRESH_TOKEN_TTL="15d" POSTGRES_USERNAME="onasty" POSTGRES_PASSWORD="qwerty"
@@ -17,12 +17,6 @@ with:
go-version-file: go.mod cache-dependency-path: go.mod - - name: Golangci Lint - uses: golangci/golangci-lint-action@v3 - with: - version: latest - args: ./... - - name: Build API run: go build -o .bin/onasty ./cmd/server/
@@ -0,0 +1,17 @@
+name: linter + +on: + push: + pull_request: + +jobs: + golang: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Golangci Lint + uses: golangci/golangci-lint-action@v3 + with: + version: latest + args: ./...
@@ -20,6 +20,15 @@
docker:up: - docker compose up -d + docker:down: + aliases: [docker:stop] + cmds: + - docker compose stop + + test: + - task: test:unit + - task: test:e2e + test:unit: - go test -v --short ./...
@@ -11,7 +11,10 @@ "os/signal"
"github.com/gin-gonic/gin" "github.com/olexsmir/onasty/internal/config" + "github.com/olexsmir/onasty/internal/hasher" + "github.com/olexsmir/onasty/internal/jwtutil" "github.com/olexsmir/onasty/internal/service/usersrv" + "github.com/olexsmir/onasty/internal/store/psql/sessionrepo" "github.com/olexsmir/onasty/internal/store/psql/userepo" "github.com/olexsmir/onasty/internal/store/psqlutil" httptransport "github.com/olexsmir/onasty/internal/transport/http"@@ -44,8 +47,13 @@ return err
} // app deps + sha256Hasher := hasher.NewSHA256Hasher(cfg.PasswordSalt) + jwtTokenizer := jwtutil.NewJWTUtil(cfg.JwtSigningKey, cfg.JwtAccessTokenTTL) + + sessionrepo := sessionrepo.New(psqlDB) + userepo := userepo.New(psqlDB) - usersrv := usersrv.New(userepo) + usersrv := usersrv.New(userepo, sessionrepo, sha256Hasher, jwtTokenizer) handler := httptransport.NewTransport(usersrv)
@@ -0,0 +1,217 @@
+package e2e + +import ( + "net/http" + + "github.com/gofrs/uuid/v5" +) + +type apiv1AuthSignUpRequest struct { + Username string `json:"username"` + Email string `json:"email"` + Password string `json:"password"` +} + +func (e *AppTestSuite) TestAuthV1_SignUP() { + username := "test" + e.uuid() + email := e.uuid() + "test@test.com" + password := "password" + + httpResp := e.httpRequest( + http.MethodPost, + "/api/v1/auth/signup", + e.jsonify(apiv1AuthSignUpRequest{ + Username: username, + Email: email, + Password: password, + }), + ) + + dbUser := e.getUserFromDBByUsername(username) + hashedPasswd, err := e.hasher.Hash(password) + e.require.NoError(err) + + e.Equal(http.StatusCreated, httpResp.Code) + e.Equal(dbUser.Email, email) + e.Equal(dbUser.Password, hashedPasswd) +} + +func (e *AppTestSuite) TestAuthV1_SignUP_badrequest() { + tests := []struct { + name string + username string + email string + password string + }{ + {name: "all fiels empty", email: "", password: "", username: ""}, + { + name: "non valid email", + email: "email", + password: "password", + }, + { + name: "non valid password", + email: "test@test.com", + password: "12345", + username: "test", + }, + } + for _, t := range tests { + httpResp := e.httpRequest( + http.MethodPost, + "/api/v1/auth/signup", + e.jsonify(apiv1AuthSignUpRequest{ + Username: t.username, + Email: t.email, + Password: t.password, + }), + ) + + e.Equal(http.StatusBadRequest, httpResp.Code) + } +} + +type apiv1AuthSignInRequest struct { + Email string `json:"email"` + Password string `json:"password"` +} + +type apiv1AuthSignInResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` +} + +func (e *AppTestSuite) TestAuthV1_SignIn() { + email := e.uuid() + "email@email.com" + password := "qwerty" + + uid := e.insertUserIntoDB("test", email, password) + + httpResp := e.httpRequest( + http.MethodPost, + "/api/v1/auth/signin", + e.jsonify(apiv1AuthSignInRequest{ + Email: email, + Password: password, + }), + ) + + var body apiv1AuthSignInResponse + e.readBodyAndUnjsonify(httpResp.Body, &body) + + session := e.getLastUserSessionByUserID(uid) + parsedToken := e.parseJwtToken(body.AccessToken) + + e.Equal(http.StatusOK, httpResp.Code) + e.Equal(body.RefreshToken, session.RefreshToken) + e.Equal(parsedToken.UserID, uid.String()) +} + +func (e *AppTestSuite) TestAuthV1_SignIn_wrong() { + password := "password" + email := e.uuid() + "@test.com" + e.insertUserIntoDB(e.uuid(), email, "password") + + tests := []struct { + name string + email string + password string + }{ + { + name: "wrong email", + email: "wrong@emai.com", + password: password, + }, + { + name: "wrong password", + email: email, + password: "wrong-wrong", + }, + } + + for _, t := range tests { + httpResp := e.httpRequest( + http.MethodPost, + "/api/v1/auth/signin", + e.jsonify(apiv1AuthSignInRequest{ + Email: t.email, + Password: t.password, + }), + ) + + e.Equal(http.StatusUnauthorized, httpResp.Code) + } +} + +type apiv1AuthRefreshTokensRequest struct { + RefreshToken string `json:"refresh_token"` +} + +func (e *AppTestSuite) TestAuthV1_RefreshTokens() { + uid, toks := e.createAndSingIn(e.uuid()+"@test.com", e.uuid(), "password") + httpResp := e.httpRequest( + http.MethodPost, + "/api/v1/auth/refresh-tokens", + e.jsonify(apiv1AuthRefreshTokensRequest{ + RefreshToken: toks.RefreshToken, + }), + ) + + var body apiv1AuthSignInResponse + e.readBodyAndUnjsonify(httpResp.Body, &body) + + session := e.getLastUserSessionByUserID(uid) + parsedToken := e.parseJwtToken(body.AccessToken) + e.Equal(parsedToken.UserID, uid.String()) + + e.Equal(httpResp.Code, http.StatusOK) + e.NotEqual(toks.RefreshToken, body.RefreshToken) + e.Equal(body.RefreshToken, session.RefreshToken) +} + +func (e *AppTestSuite) TestAuthV1_RefreshTokens_wrong() { + httpResp := e.httpRequest( + http.MethodPost, + "/api/v1/auth/refresh-tokens", + e.jsonify(apiv1AuthRefreshTokensRequest{ + RefreshToken: e.uuid(), + }), + ) + + e.Equal(httpResp.Code, http.StatusBadRequest) +} + +func (e *AppTestSuite) TestAuthV1_Logout() { + uid, toks := e.createAndSingIn(e.uuid()+"@test.com", e.uuid(), "password") + + session := e.getLastUserSessionByUserID(uid) + e.NotEmpty(session.RefreshToken) + + httpResp := e.httpRequest(http.MethodPost, "/api/v1/auth/logout", nil, toks.AccessToken) + + e.Equal(httpResp.Code, http.StatusNoContent) + + session = e.getLastUserSessionByUserID(uid) + e.Empty(session.RefreshToken) +} + +func (e *AppTestSuite) createAndSingIn( + email, username, password string, +) (uuid.UUID, apiv1AuthSignInResponse) { + uid := e.insertUserIntoDB(username, email, password) + httpResp := e.httpRequest( + http.MethodPost, + "/api/v1/auth/signin", + e.jsonify(apiv1AuthSignInRequest{ + Email: email, + Password: password, + }), + ) + + e.Equal(httpResp.Code, http.StatusOK) + + var body apiv1AuthSignInResponse + e.readBodyAndUnjsonify(httpResp.Body, &body) + + return uid, body +}
@@ -5,12 +5,16 @@ "context"
"fmt" "net/http" "testing" + "time" "github.com/gin-gonic/gin" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database/pgx" "github.com/jackc/pgx/v5/stdlib" + "github.com/olexsmir/onasty/internal/hasher" + "github.com/olexsmir/onasty/internal/jwtutil" "github.com/olexsmir/onasty/internal/service/usersrv" + "github.com/olexsmir/onasty/internal/store/psql/sessionrepo" "github.com/olexsmir/onasty/internal/store/psql/userepo" "github.com/olexsmir/onasty/internal/store/psqlutil" httptransport "github.com/olexsmir/onasty/internal/transport/http"@@ -34,7 +38,9 @@
postgresDB *psqlutil.DB stopPostgres stopDBFunc - router http.Handler + router http.Handler + hasher hasher.Hasher + jwtTokenizer jwtutil.JWTTokenizer } )@@ -49,37 +55,42 @@
suite.Run(t, new(AppTestSuite)) } -func (s *AppTestSuite) SetupSuite() { - s.ctx = context.Background() - s.require = s.Require() +func (e *AppTestSuite) SetupSuite() { + e.ctx = context.Background() + e.require = e.Require() - db, stop, err := s.prepPostgres() - s.Require().NoError(err) + db, stop, err := e.prepPostgres() + e.Require().NoError(err) - s.postgresDB = db - s.stopPostgres = stop + e.postgresDB = db + e.stopPostgres = stop - s.initDeps() + e.initDeps() } -func (s *AppTestSuite) TearDownSuite() { - s.stopPostgres() +func (e *AppTestSuite) TearDownSuite() { + e.stopPostgres() } // initDeps initializes the dependencies for the app // and sets up the router for tests -func (s *AppTestSuite) initDeps() { - userepo := userepo.New(s.postgresDB) - usersrv := usersrv.New(userepo) +func (e *AppTestSuite) initDeps() { + e.hasher = hasher.NewSHA256Hasher("pass_salt") + e.jwtTokenizer = jwtutil.NewJWTUtil("jwt", time.Hour) + + sessionrepo := sessionrepo.New(e.postgresDB) + + userepo := userepo.New(e.postgresDB) + usersrv := usersrv.New(userepo, sessionrepo, e.hasher, e.jwtTokenizer) handler := httptransport.NewTransport(usersrv) - s.router = handler.Handler() + e.router = handler.Handler() } -func (s *AppTestSuite) prepPostgres() (*psqlutil.DB, stopDBFunc, error) { +func (e *AppTestSuite) prepPostgres() (*psqlutil.DB, stopDBFunc, error) { dbCredential := "testing" postgresContainer, err := postgres.RunContainer( - s.ctx, + e.ctx, testcontainers.WithImage("postgres:16-alpine"), postgres.WithUsername(dbCredential), postgres.WithPassword(dbCredential),@@ -87,46 +98,46 @@ postgres.WithDatabase(dbCredential),
testcontainers.WithWaitStrategy( wait.ForListeningPort("5432/tcp")), ) - s.require.NoError(err) + e.require.NoError(err) stop := func() { - err = postgresContainer.Terminate(s.ctx) - s.require.NoError(err) + err = postgresContainer.Terminate(e.ctx) + e.require.NoError(err) } // connect to the db - host, err := postgresContainer.Host(s.ctx) - s.require.NoError(err) + host, err := postgresContainer.Host(e.ctx) + e.require.NoError(err) - port, err := postgresContainer.MappedPort(s.ctx, "5432/tcp") - s.require.NoError(err) + port, err := postgresContainer.MappedPort(e.ctx, "5432/tcp") + e.require.NoError(err) db, err := psqlutil.Connect( - s.ctx, + e.ctx, fmt.Sprintf( //nolint:nosprintfhostport "postgres://%s:%s@%s:%s/%s", dbCredential, dbCredential, host, - port, + port.Port(), dbCredential, ), ) - s.require.NoError(err) + e.require.NoError(err) // run migrations sdb := stdlib.OpenDBFromPool(db.Pool) driver, err := pgx.WithInstance(sdb, &pgx.Config{}) - s.require.NoError(err) + e.require.NoError(err) m, err := migrate.NewWithDatabaseInstance( "file://../migrations/", "pgxv5", driver, ) - s.require.NoError(err) + e.require.NoError(err) err = m.Up() - s.require.NoError(err) + e.require.NoError(err) return db, stop, driver.Close() }
@@ -0,0 +1,68 @@
+package e2e + +import ( + "errors" + "time" + + "github.com/gofrs/uuid/v5" + "github.com/henvic/pgq" + "github.com/jackc/pgx/v5" + "github.com/olexsmir/onasty/internal/models" +) + +func (e *AppTestSuite) getUserFromDBByUsername(username string) models.User { + query, args, err := pgq. + Select("id", "username", "email", "password", "created_at", "last_login_at"). + From("users"). + Where(pgq.Eq{ + "username": username, + }). + SQL() + e.require.NoError(err) + + var user models.User + err = e.postgresDB.QueryRow(e.ctx, query, args...). + Scan(&user.ID, &user.Username, &user.Email, &user.Password, &user.CreatedAt, &user.LastLoginAt) + e.require.NoError(err) + + return user +} + +func (e *AppTestSuite) insertUserIntoDB(uname, email, passwd string) uuid.UUID { + p, err := e.hasher.Hash(passwd) + e.require.NoError(err) + + query, args, err := pgq. + Insert("users"). + Columns("username", "email", "password", "activated", "created_at", "last_login_at"). + Values(uname, email, p, true, time.Now(), time.Now()). + Returning("id"). + SQL() + e.require.NoError(err) + + var id uuid.UUID + err = e.postgresDB.QueryRow(e.ctx, query, args...).Scan(&id) + e.require.NoError(err) + + return id +} + +func (e *AppTestSuite) getLastUserSessionByUserID(uid uuid.UUID) models.Session { + query, args, err := pgq. + Select("refresh_token", "expires_at"). + From("sessions"). + Where(pgq.Eq{"user_id": uid.String()}). + OrderBy("expires_at DESC"). + SQL() + e.require.NoError(err) + + var session models.Session + err = e.postgresDB.QueryRow(e.ctx, query, args...). + Scan(&session.RefreshToken, &session.ExpiresAt) + if errors.Is(pgx.ErrNoRows, err) { + return models.Session{} + } + + e.require.NoError(err) + return session +}
@@ -0,0 +1,69 @@
+package e2e + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + + "github.com/gofrs/uuid/v5" + "github.com/olexsmir/onasty/internal/jwtutil" +) + +// jsonify marshalls v into json and returns it as []byte +func (e *AppTestSuite) jsonify(v any) []byte { + r, err := json.Marshal(v) + e.require.NoError(err) + return r +} + +// readBodyAndUnjsonify reads body of `httptest.ResponseRecorder` and unmarshalls it into res +// +// Example: +// +// var res struct { message string `json:"message"` } +// readBodyAndUnjsonify(httpResp.Body, &res) +func (e *AppTestSuite) readBodyAndUnjsonify(b *bytes.Buffer, res any) { + respData, err := io.ReadAll(b) + e.require.NoError(err) + + err = json.Unmarshal(respData, &res) + e.require.NoError(err) +} + +// httpRequest sends http request to the server and returns `httptest.ResponseRecorder` +// conteny-type always set to application/json +func (e *AppTestSuite) httpRequest( + method, url string, //nolint:unparam // TODO: fix me later + body []byte, + accessToken ...string, +) *httptest.ResponseRecorder { + req, err := http.NewRequest(method, url, bytes.NewBuffer(body)) + e.require.NoError(err) + + req.Header.Set("Content-type", "application/json") + + if len(accessToken) == 1 { + req.Header.Set("Authorization", "Bearer "+accessToken[0]) + } + + resp := httptest.NewRecorder() + e.router.ServeHTTP(resp, req) + + return resp +} + +// uuid generates a new UUID and returns it as a string +func (e *AppTestSuite) uuid() string { + u, err := uuid.NewV4() + e.require.NoError(err) + return u.String() +} + +// parseJwtToken util func that parses jwt token and returns payload +func (e *AppTestSuite) parseJwtToken(t string) jwtutil.Payload { + r, err := e.jwtTokenizer.Parse(t) + e.require.NoError(err) + return r +}
@@ -5,7 +5,10 @@
require ( github.com/gin-gonic/gin v1.10.0 github.com/gofrs/uuid/v5 v5.0.0 + github.com/golang-jwt/jwt/v5 v5.2.1 github.com/golang-migrate/migrate/v4 v4.17.1 + github.com/henvic/pgq v0.0.2 + github.com/jackc/pgconn v1.14.3 github.com/jackc/pgx-gofrs-uuid v0.0.0-20230224015001-1d428863c2e2 github.com/jackc/pgx/v5 v5.6.0 github.com/stretchr/testify v1.9.0@@ -47,7 +50,6 @@ github.com/google/uuid v1.6.0 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect - github.com/jackc/pgconn v1.14.3 // indirect github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa // indirect github.com/jackc/pgio v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect
@@ -81,6 +81,8 @@ github.com/gofrs/uuid/v5 v5.0.0 h1:p544++a97kEL+svbcFbCQVM9KFu0Yo25UoISXGNNH9M=
github.com/gofrs/uuid/v5 v5.0.0/go.mod h1:CDOjlDMVAtN56jqyRUZh58JT31Tiw7/oQyEXZV+9bD8= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-migrate/migrate/v4 v4.17.1 h1:4zQ6iqL6t6AiItphxJctQb3cFqWiSpMnX7wLTPnnYO4= github.com/golang-migrate/migrate/v4 v4.17.1/go.mod h1:m8hinFyWBn0SA4QKHuKh175Pm9wjmxj3S2Mia7dbXzM= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=@@ -100,6 +102,8 @@ github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/henvic/pgq v0.0.2 h1:4q/G/cW7zpxpwq672Xuh7BkcKcXonZJ6b9kR8ub3EwQ= +github.com/henvic/pgq v0.0.2/go.mod h1:1Q6dKMwtbe2glBXlusJvNZnJrvgbwub/KcfiB/7UXA4= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8=
@@ -2,11 +2,17 @@ package config
import ( "os" + "time" ) type Config struct { - AppEnv string - ServerPort string + AppEnv string + ServerPort string + PasswordSalt string + + JwtSigningKey string + JwtAccessTokenTTL time.Duration + JwtRefreshTokenTTL time.Duration LogLevel string LogFormat string@@ -16,11 +22,15 @@ }
func NewConfig() *Config { return &Config{ - AppEnv: getenvOrDefault("APP_ENV", "debug"), - ServerPort: getenvOrDefault("SERVER_PORT", "3000"), - LogLevel: getenvOrDefault("LOG_LEVEL", "debug"), - LogFormat: getenvOrDefault("LOG_FORMAT", "json"), - PostgresDSN: getenvOrDefault("POSTGRESQL_DSN", ""), + AppEnv: getenvOrDefault("APP_ENV", "debug"), + ServerPort: getenvOrDefault("SERVER_PORT", "3000"), + PasswordSalt: getenvOrDefault("PASSWORD_SALT", ""), + JwtSigningKey: getenvOrDefault("JWT_SIGNING_KEY", ""), + JwtAccessTokenTTL: mustParseDuration(getenvOrDefault("JWT_ACCESS_TOKEN_TTL", "15m")), + JwtRefreshTokenTTL: mustParseDuration(getenvOrDefault("JWT_REFRESH_TOKEN_TTL", "15d")), + LogLevel: getenvOrDefault("LOG_LEVEL", "debug"), + LogFormat: getenvOrDefault("LOG_FORMAT", "json"), + PostgresDSN: getenvOrDefault("POSTGRESQL_DSN", ""), } }@@ -34,3 +44,8 @@ return v
} return def } + +func mustParseDuration(dur string) time.Duration { + d, _ := time.ParseDuration(dur) + return d +}
@@ -0,0 +1,6 @@
+package dtos + +type TokensDTO struct { + Access string + Refresh string +}
@@ -0,0 +1,29 @@
+package dtos + +import ( + "time" + + "github.com/gofrs/uuid/v5" +) + +type UserDTO struct { + ID uuid.UUID + Username string + Email string + Password string + CreatedAt time.Time + LastLoginAt time.Time +} + +type CreateUserDTO struct { + Username string + Email string + Password string + CreatedAt time.Time + LastLoginAt time.Time +} + +type SignInDTO struct { + Email string + Password string +}
@@ -0,0 +1,6 @@
+package hasher + +type Hasher interface { + // Hash takes a string as input and returns its hash + Hash(string) (string, error) +}
@@ -0,0 +1,22 @@
+package hasher + +import ( + "crypto/sha256" + "encoding/hex" +) + +type SHA256Hasher struct { + salt string +} + +func NewSHA256Hasher(salt string) *SHA256Hasher { + return &SHA256Hasher{salt: salt} +} + +func (h *SHA256Hasher) Hash(inp string) (string, error) { + hash := sha256.New() + if _, err := hash.Write([]byte(inp)); err != nil { + return "", err + } + return hex.EncodeToString(hash.Sum([]byte(h.salt))), nil +}
@@ -0,0 +1,68 @@
+package jwtutil + +import ( + "crypto/rand" + "encoding/hex" + "errors" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +type JWTTokenizer interface { + // AccessToken generates a new access token with the given payload + AccessToken(pl Payload) (string, error) + + // RefreshToken generates a new refresh token + RefreshToken() (string, error) + + // Parse parses the token and returns the payload + Parse(token string) (Payload, error) +} + +type Payload struct { + UserID string +} + +var _ JWTTokenizer = (*JWTUtil)(nil) + +type JWTUtil struct { + signingKey string + accessTokenTTL time.Duration +} + +func NewJWTUtil(signingKey string, accessTokenTTL time.Duration) *JWTUtil { + return &JWTUtil{ + signingKey: signingKey, + accessTokenTTL: accessTokenTTL, + } +} + +func (j *JWTUtil) AccessToken(pl Payload) (string, error) { + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + Subject: pl.UserID, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(j.accessTokenTTL)), + }) + return tok.SignedString([]byte(j.signingKey)) +} + +func (j *JWTUtil) RefreshToken() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return hex.EncodeToString(b), nil +} + +func (j *JWTUtil) Parse(token string) (Payload, error) { + var claims jwt.RegisteredClaims + _, err := jwt.ParseWithClaims(token, &claims, func(t *jwt.Token) (interface{}, error) { + if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, errors.New("unexpected signing method") + } + return []byte(j.signingKey), nil + }) + return Payload{ + UserID: claims.Subject, + }, err +}
@@ -0,0 +1,17 @@
+package models + +import ( + "errors" + "time" + + "github.com/gofrs/uuid/v5" +) + +var ErrSessionNotFound = errors.New("user: session not found") + +type Session struct { + ID uuid.UUID + UserID uuid.UUID + RefreshToken string + ExpiresAt time.Time +}
@@ -0,0 +1,44 @@
+package models + +import ( + "errors" + "net/mail" + "time" + + "github.com/gofrs/uuid/v5" +) + +var ( + ErrUserEmailIsAlreadyInUse = errors.New("user: email is already in use") + ErrUsernameIsAlreadyInUse = errors.New("user: username is already in use") + + ErrUserNotFound = errors.New("user: not found") + ErrUserWrongCredentials = errors.New("user: wrong credentials") +) + +type User struct { + ID uuid.UUID + Username string + Email string + Password string + CreatedAt time.Time + LastLoginAt time.Time +} + +func (u User) Validate() error { + // NOTE: there's probably a better way to validate emails + _, err := mail.ParseAddress(u.Email) + if err != nil { + return errors.New("user: invalid email") + } + + if len(u.Password) < 6 { + return errors.New("user: password too short, minimum 6 chars") + } + + if len(u.Username) == 0 { + return errors.New("user: username is required") + } + + return nil +}
@@ -0,0 +1,70 @@
+package models + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUser_Validate(t *testing.T) { + tests := []struct { + name string + fail bool + + username string + email string + password string + }{ + { + name: "valid", + fail: false, + email: "test@example.org", + username: "iuserarchbtw", + password: "superhardasspassword", + }, + { + name: "all fields empty", + fail: true, + email: "", + username: "", + password: "", + }, + { + name: "invalid email", + fail: true, + email: "test", + username: "iuserarchbtw", + password: "superhardasspassword", + }, + { + name: "invalid password", + fail: true, + email: "test@example.org", + username: "iuserarchbtw", + password: "12345", + }, + { + name: "invalid username", + fail: true, + email: "test@example.org", + username: "", + password: "superhardasspassword", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := User{ + Username: tt.username, + Email: tt.email, + Password: tt.password, + }.Validate() + + if tt.fail { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +}
@@ -1,22 +1,142 @@
package usersrv -import "github.com/olexsmir/onasty/internal/store/psql/userepo" +import ( + "context" + "errors" + "time" + + "github.com/gofrs/uuid/v5" + "github.com/olexsmir/onasty/internal/dtos" + "github.com/olexsmir/onasty/internal/hasher" + "github.com/olexsmir/onasty/internal/jwtutil" + "github.com/olexsmir/onasty/internal/models" + "github.com/olexsmir/onasty/internal/store/psql/sessionrepo" + "github.com/olexsmir/onasty/internal/store/psql/userepo" +) type UserServicer interface { - SignUp() error + SignUp(ctx context.Context, inp dtos.CreateUserDTO) (uuid.UUID, error) + SignIn(ctx context.Context, inp dtos.SignInDTO) (dtos.TokensDTO, error) + RefreshTokens(ctx context.Context, refreshToken string) (dtos.TokensDTO, error) + Logout(ctx context.Context, userID uuid.UUID) error + + ParseToken(token string) (jwtutil.Payload, error) + CheckIfUserExists(ctx context.Context, userID uuid.UUID) (bool, error) } +var _ UserServicer = (*UserSrv)(nil) + type UserSrv struct { - store userepo.UserStorer + userstore userepo.UserStorer + sessionstore sessionrepo.SessionStorer + hasher hasher.Hasher + jwtTokenizer jwtutil.JWTTokenizer + + refreshTokenExpiredAt time.Time } -func New(store userepo.UserStorer) UserServicer { +func New( + userstore userepo.UserStorer, + sessionstore sessionrepo.SessionStorer, + hasher hasher.Hasher, + jwtTokenizer jwtutil.JWTTokenizer, +) UserServicer { return &UserSrv{ - store: store, + userstore: userstore, + sessionstore: sessionstore, + hasher: hasher, + jwtTokenizer: jwtTokenizer, } } -// type SignUp -func (s *UserSrv) SignUp() error { - return nil +func (u *UserSrv) SignUp(ctx context.Context, inp dtos.CreateUserDTO) (uuid.UUID, error) { + hashedPassword, err := u.hasher.Hash(inp.Password) + if err != nil { + return uuid.UUID{}, err + } + + return u.userstore.Create(ctx, dtos.CreateUserDTO{ + Username: inp.Username, + Email: inp.Email, + Password: hashedPassword, + CreatedAt: inp.CreatedAt, + LastLoginAt: inp.LastLoginAt, + }) +} + +func (u *UserSrv) SignIn(ctx context.Context, inp dtos.SignInDTO) (dtos.TokensDTO, error) { + hashedPassword, err := u.hasher.Hash(inp.Password) + if err != nil { + return dtos.TokensDTO{}, err + } + + user, err := u.userstore.GetUserByCredentials(ctx, inp.Email, hashedPassword) + if err != nil { + if errors.Is(err, models.ErrUserNotFound) { + return dtos.TokensDTO{}, models.ErrUserWrongCredentials + } + return dtos.TokensDTO{}, err + } + + tokens, err := u.getTokens(user.ID) + if err != nil { + return dtos.TokensDTO{}, err + } + + if err := u.sessionstore.Set(ctx, user.ID, tokens.Refresh, u.refreshTokenExpiredAt); err != nil { + return dtos.TokensDTO{}, err + } + + return dtos.TokensDTO{ + Access: tokens.Access, + Refresh: tokens.Refresh, + }, nil +} + +func (u *UserSrv) Logout(ctx context.Context, userID uuid.UUID) error { + return u.sessionstore.Delete(ctx, userID) +} + +func (u *UserSrv) RefreshTokens(ctx context.Context, rtoken string) (dtos.TokensDTO, error) { + userID, err := u.sessionstore.GetUserIDByRefreshToken(ctx, rtoken) + if err != nil { + return dtos.TokensDTO{}, err + } + + tokens, err := u.getTokens(userID) + if err != nil { + return dtos.TokensDTO{}, err + } + + err = u.sessionstore.Update(ctx, userID, rtoken, tokens.Refresh) + + return dtos.TokensDTO{ + Access: tokens.Access, + Refresh: tokens.Refresh, + }, err +} + +func (u *UserSrv) ParseToken(token string) (jwtutil.Payload, error) { + return u.jwtTokenizer.Parse(token) +} + +func (u UserSrv) CheckIfUserExists(ctx context.Context, id uuid.UUID) (bool, error) { + return u.userstore.CheckIfUserExists(ctx, id) +} + +func (u UserSrv) getTokens(userID uuid.UUID) (dtos.TokensDTO, error) { + accessToken, err := u.jwtTokenizer.AccessToken(jwtutil.Payload{UserID: userID.String()}) + if err != nil { + return dtos.TokensDTO{}, err + } + + refreshToken, err := u.jwtTokenizer.RefreshToken() + if err != nil { + return dtos.TokensDTO{}, err + } + + return dtos.TokensDTO{ + Access: accessToken, + Refresh: refreshToken, + }, err }
@@ -0,0 +1,106 @@
+package sessionrepo + +import ( + "context" + "errors" + "time" + + "github.com/gofrs/uuid/v5" + "github.com/henvic/pgq" + "github.com/jackc/pgx/v5" + "github.com/olexsmir/onasty/internal/models" + "github.com/olexsmir/onasty/internal/store/psqlutil" +) + +type SessionStorer interface { + Set(ctx context.Context, usedID uuid.UUID, refreshToken string, expiresAt time.Time) error + GetUserIDByRefreshToken(ctx context.Context, refreshToken string) (uuid.UUID, error) + Update(ctx context.Context, userID uuid.UUID, refreshToken string, newRefreshToken string) error + Delete(ctx context.Context, userID uuid.UUID) error +} + +var _ SessionStorer = (*SessionRepo)(nil) + +type SessionRepo struct { + db *psqlutil.DB +} + +func New(db *psqlutil.DB) SessionStorer { + return &SessionRepo{ + db: db, + } +} + +func (s *SessionRepo) Set( + ctx context.Context, + userID uuid.UUID, + refreshToken string, + expiresAt time.Time, +) error { + query, args, err := pgq. + Insert("sessions"). + Columns("user_id", "refresh_token", "expires_at"). + Values(userID, refreshToken, expiresAt). + SQL() + if err != nil { + return err + } + + _, err = s.db.Exec(ctx, query, args...) + return err +} + +func (s *SessionRepo) Update( + ctx context.Context, + userID uuid.UUID, + refreshToken string, + newRefreshToken string, +) error { + query := `--sql +update sessions +set refresh_token = $1 +where + user_id = $2 + and refresh_token = $3 + and expires_at < now() +` + + res, err := s.db.Exec(ctx, query, newRefreshToken, userID, refreshToken) + if res.RowsAffected() != 1 { + return models.ErrSessionNotFound + } + + return err +} + +func (s *SessionRepo) GetUserIDByRefreshToken( + ctx context.Context, + refreshToken string, +) (uuid.UUID, error) { + query, args, err := pgq. + Select("user_id"). + From("sessions"). + Where(pgq.Eq{"refresh_token": refreshToken}). + SQL() + if err != nil { + return uuid.UUID{}, err + } + + var userID uuid.UUID + err = s.db.QueryRow(ctx, query, args...).Scan(&userID) + if errors.Is(err, pgx.ErrNoRows) { + return uuid.UUID{}, models.ErrUserNotFound + } + + return userID, err +} + +func (s *SessionRepo) Delete(ctx context.Context, userID uuid.UUID) error { + query := `--sql +DELETE FROM sessions +WHERE user_id = $1 +` + + _, err := s.db.Exec(ctx, query, userID) + return err +}
@@ -1,26 +1,98 @@
package userepo import ( + "context" + "errors" + "github.com/gofrs/uuid/v5" + "github.com/henvic/pgq" + "github.com/jackc/pgx/v5" + "github.com/olexsmir/onasty/internal/dtos" + "github.com/olexsmir/onasty/internal/models" "github.com/olexsmir/onasty/internal/store/psqlutil" ) type UserStorer interface { - SignUp(inp SignUpInput) (uuid.UUID, error) + Create(ctx context.Context, inp dtos.CreateUserDTO) (uuid.UUID, error) + GetUserByCredentials(ctx context.Context, email, password string) (dtos.UserDTO, error) + + CheckIfUserExists(ctx context.Context, id uuid.UUID) (bool, error) } + +var _ UserStorer = (*UserRepo)(nil) type UserRepo struct { db *psqlutil.DB } -func New(db *psqlutil.DB) UserStorer { +func New(db *psqlutil.DB) *UserRepo { return &UserRepo{ db: db, } } -type SignUpInput struct{} +func (r *UserRepo) Create(ctx context.Context, inp dtos.CreateUserDTO) (uuid.UUID, error) { + query, args, err := pgq. + Insert("users"). + Columns("username", "email", "password", "created_at", "last_login_at"). + Values(inp.Username, inp.Email, inp.Password, inp.CreatedAt, inp.LastLoginAt). + Returning("id"). + SQL() + if err != nil { + return uuid.UUID{}, err + } -func (r *UserRepo) SignUp(_ SignUpInput) (uuid.UUID, error) { - return uuid.UUID{}, nil + var id uuid.UUID + err = r.db.QueryRow(ctx, query, args...).Scan(&id) + + // FIXME: somehow this does return errors but i can't errors.Is them in api layer + if psqlutil.IsDuplicateErr(err, "users_username_key") { + return uuid.UUID{}, models.ErrUsernameIsAlreadyInUse + } + + if psqlutil.IsDuplicateErr(err, "users_email_key") { + return uuid.UUID{}, models.ErrUserEmailIsAlreadyInUse + } + + return id, err +} + +func (r *UserRepo) GetUserByCredentials( + ctx context.Context, + email, password string, +) (dtos.UserDTO, error) { + query, args, err := pgq. + Select("id", "username", "email", "password", "created_at", "last_login_at"). + From("users"). + Where(pgq.Eq{ + "email": email, + "password": password, + }). + SQL() + if err != nil { + return dtos.UserDTO{}, err + } + + var user dtos.UserDTO + err = r.db.QueryRow(ctx, query, args...). + Scan(&user.ID, &user.Username, &user.Email, &user.Password, &user.CreatedAt, &user.LastLoginAt) + if errors.Is(err, pgx.ErrNoRows) { + return dtos.UserDTO{}, models.ErrUserNotFound + } + + return user, err +} + +func (r *UserRepo) CheckIfUserExists(ctx context.Context, id uuid.UUID) (bool, error) { + var exists bool + err := r.db.QueryRow( + ctx, + `SELECT EXISTS(SELECT 1 FROM users WHERE id = $1)`, + id.String(), + ).Scan(&exists) + if errors.Is(err, pgx.ErrNoRows) { + return false, models.ErrUserNotFound + } + + return exists, err }
@@ -2,7 +2,9 @@ package psqlutil
import ( "context" + "errors" + "github.com/jackc/pgconn" pgxuuid "github.com/jackc/pgx-gofrs-uuid" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool"@@ -39,3 +41,16 @@ func (db *DB) Close() error {
db.Pool.Close() return nil } + +// IsDuplicateErr function that checks if the error is a duplicate key violation. +func IsDuplicateErr(err error, constraintName ...string) bool { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + if len(constraintName) == 0 || len(constraintName) == 1 { + return pgErr.Code == "23505" && // unique_violation + pgErr.ConstraintName == constraintName[0] + } + return pgErr.Code == "23505" // unique_violation + } + return false +}
@@ -6,12 +6,12 @@ "github.com/olexsmir/onasty/internal/service/usersrv"
) type APIV1 struct { - userSrv usersrv.UserServicer + usersrv usersrv.UserServicer } func NewAPIV1(us usersrv.UserServicer) *APIV1 { return &APIV1{ - userSrv: us, + usersrv: us, } }
@@ -2,11 +2,15 @@ package apiv1
import ( "net/http" + "time" "github.com/gin-gonic/gin" + "github.com/olexsmir/onasty/internal/dtos" + "github.com/olexsmir/onasty/internal/models" ) type signUpRequest struct { + Username string `json:"username"` Email string `json:"email"` Password string `json:"password"` }@@ -17,10 +21,94 @@ if err := c.ShouldBindJSON(&req); err != nil {
newError(c, http.StatusBadRequest, "invalid request") return } + + user := models.User{ + Username: req.Username, + Email: req.Email, + Password: req.Password, + CreatedAt: time.Now(), + LastLoginAt: time.Now(), + } + if err := user.Validate(); err != nil { + // TODO: find a way to return all errors at once + newErrorStatus(c, http.StatusBadRequest, err.Error()) + return + } + + if _, err := a.usersrv.SignUp(c.Request.Context(), dtos.CreateUserDTO{ + Username: user.Username, + Email: user.Email, + Password: user.Password, + CreatedAt: user.CreatedAt, + LastLoginAt: user.LastLoginAt, + }); err != nil { + errorResponse(c, err) + return + } + + c.Status(http.StatusCreated) } -func (a *APIV1) signInHandler(_ *gin.Context) {} +type signInRequest struct { + Email string `json:"email"` + Password string `json:"password"` +} -func (a *APIV1) refreshTokensHandler(_ *gin.Context) {} +type signInResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` +} -func (a *APIV1) logOutHandler(_ *gin.Context) {} +func (a *APIV1) signInHandler(c *gin.Context) { + var req signInRequest + if err := c.ShouldBindJSON(&req); err != nil { + newError(c, http.StatusBadRequest, "invalid request") + return + } + + toks, err := a.usersrv.SignIn(c.Request.Context(), dtos.SignInDTO{ + Email: req.Email, + Password: req.Password, + }) + if err != nil { + errorResponse(c, err) + return + } + + c.JSON(http.StatusOK, signInResponse{ + AccessToken: toks.Access, + RefreshToken: toks.Refresh, + }) +} + +type refreshTokenRequest struct { + RefreshToken string `json:"refresh_token"` +} + +func (a *APIV1) refreshTokensHandler(c *gin.Context) { + var req refreshTokenRequest + if err := c.ShouldBindJSON(&req); err != nil { + newError(c, http.StatusBadRequest, "invalid request") + return + } + + toks, err := a.usersrv.RefreshTokens(c.Request.Context(), req.RefreshToken) + if err != nil { + errorResponse(c, err) + return + } + + c.JSON(http.StatusOK, signInResponse{ + AccessToken: toks.Access, + RefreshToken: toks.Refresh, + }) +} + +func (a *APIV1) logOutHandler(c *gin.Context) { + if err := a.usersrv.Logout(c.Request.Context(), getUserID(c)); err != nil { + errorResponse(c, err) + return + } + + c.Status(http.StatusNoContent) +}
@@ -1,8 +1,125 @@
package apiv1 -import "github.com/gin-gonic/gin" +import ( + "context" + "errors" + "strings" + + "github.com/gin-gonic/gin" + "github.com/gofrs/uuid/v5" + "github.com/olexsmir/onasty/internal/service/usersrv" +) + +var ErrUnauthorized = errors.New("unauthorized") + +const userIDCtxKey = "userID" + +func (a *APIV1) authorizedMiddleware(c *gin.Context) { + token, ok := getTokenFromAuthHeaders(c) + if !ok { + errorResponse(c, ErrUnauthorized) + return + } + + ok, err := checkIfUserIsReal(c.Request.Context(), token, a.usersrv) + if err != nil { + errorResponse(c, err) + return + } + + if !ok { + errorResponse(c, ErrUnauthorized) + return + } + + if err := saveUserIDToCtx(c, a.usersrv, token); err != nil { + errorResponse(c, err) + return + } + + c.Next() +} + +//nolint:unused // TODO: remove me later +func (a *APIV1) couldBeAuthorizedMiddleware(c *gin.Context) { + token, ok := getTokenFromAuthHeaders(c) + if ok { + ok, err := checkIfUserIsReal(c.Request.Context(), token, a.usersrv) + if err != nil { + errorResponse(c, err) + return + } + + if !ok { + errorResponse(c, ErrUnauthorized) + return + } + + if err := saveUserIDToCtx(c, a.usersrv, token); err != nil { + newInternalError(c, err) + return + } + } + + c.Next() +} + +//nolint:unused // TODO: remove me later +func (a *APIV1) isUserAuthorized(c *gin.Context) bool { + return !getUserID(c).IsNil() +} + +func getTokenFromAuthHeaders(c *gin.Context) (token string, ok bool) { //nolint:nonamedreturns + header := c.GetHeader("Authorization") + if header == "" { + return "", false + } + + headerParts := strings.Split(header, " ") + if len(headerParts) != 2 && headerParts[0] != "Bearer" { + return "", false + } + + if len(headerParts[1]) == 0 { + return "", false + } + + return headerParts[1], true +} + +func saveUserIDToCtx(c *gin.Context, us usersrv.UserServicer, token string) error { + pl, err := us.ParseToken(token) + if err != nil { + return err + } + + c.Set(userIDCtxKey, pl.UserID) -func (a *APIV1) authorizedMiddleware(_ *gin.Context) {} + return nil +} -func (a *APIV1) couldBeAuthorizedMiddleware(_ *gin.Context) { //nolint:unused +// getUserId returns userId from the context +// getting user id is only possible if user is authorized +func getUserID(c *gin.Context) uuid.UUID { + userID, exists := c.Get(userIDCtxKey) + if !exists { + return uuid.Nil + } + return uuid.Must(uuid.FromString(userID.(string))) +} + +func checkIfUserIsReal( + ctx context.Context, + accessToken string, + us usersrv.UserServicer, +) (bool, error) { + parsedToken, err := us.ParseToken(accessToken) + if err != nil { + return false, err + } + + return us.CheckIfUserExists( + ctx, + uuid.Must(uuid.FromString(parsedToken.UserID)), + ) }
@@ -1,19 +1,47 @@
package apiv1 import ( + "errors" "log/slog" "net/http" "github.com/gin-gonic/gin" + "github.com/olexsmir/onasty/internal/models" ) type response struct { Message string `json:"message"` } -func newError(c *gin.Context, status int, msg string) { - slog.With("status", status).Error(msg) +func errorResponse(c *gin.Context, err error) { + if errors.Is(err, models.ErrUserEmailIsAlreadyInUse) || + errors.Is(err, models.ErrUsernameIsAlreadyInUse) { + newError(c, http.StatusBadRequest, err.Error()) + return + } + + if errors.Is(err, models.ErrUserNotFound) { + newErrorStatus(c, http.StatusBadRequest, err.Error()) + return + } + + if errors.Is(err, ErrUnauthorized) || + errors.Is(err, models.ErrUserWrongCredentials) { + newErrorStatus(c, http.StatusUnauthorized, err.Error()) + return + } + + newInternalError(c, err) +} + +func newError(c *gin.Context, status int, msg string) { //nolint:unparam // TODO: remove me later + slog.Error(msg, "status", status) c.AbortWithStatusJSON(status, response{msg}) +} + +func newErrorStatus(c *gin.Context, status int, msg string) { + slog.Error(msg, "status", status) + c.AbortWithStatus(status) } func newInternalError(c *gin.Context, err error, msg ...string) {
@@ -7,6 +7,7 @@
"github.com/gin-gonic/gin" ) +// TODO: include requiest id func (t *Transport) logger() gin.HandlerFunc { return func(c *gin.Context) { start := time.Now()
@@ -0,0 +1,9 @@
+create table users ( + id uuid primary key default uuid_generate_v4(), + username varchar(255) not null unique, + email varchar(255) not null unique, + password varchar(255) not null, + activated boolean not null default false, + created_at timestamptz not null default now(), + last_login_at timestamptz not null default now() +);
@@ -0,0 +1,6 @@
+create table sessions ( + id uuid primary key default uuid_generate_v4(), + user_id uuid references users (id), + refresh_token varchar(255) not null unique, + expires_at timestamptz not null +);