all repos

onasty @ 51d3b53

a one-time notes service
10 files changed, 74 insertions(+), 17 deletions(-)
fix: oauth state (#128)

* pass app env to handlers

* fix(oauth): set cookie with state on redirect

* feat(oauth): check oatuh state in callback

* fixup! pass app env to handlers

* fixup! fix(oauth): set cookie with state on redirect

* fix(userrepo): do update of user's identity if they were logged in before

* refactor: move oauth state cookie to sep var

* refactor: update age of the oauth cookie

* feat: set the domain to the oauth state cookie
Author: Smirnov Oleksandr ss2316544@gmail.com
Committed by: GitHub noreply@github.com
Committed at: 2025-06-07 16:45:46 +0300
Parent: 9f4dd4b
M cmd/api/main.go

@@ -57,7 +57,7 @@

slog.SetDefault(logger) // semi dev mode - if !cfg.IsDevMode() { + if !cfg.AppEnv.IsDevMode() { gin.SetMode(gin.ReleaseMode) }

@@ -129,6 +129,8 @@

handler := httptransport.NewTransport( usersrv, notesrv, + cfg.AppEnv, + cfg.AppURL, cfg.CORSAllowedOrigins, cfg.CORSMaxAge, rateLimiterConfig,
M e2e/e2e_test.go

@@ -140,6 +140,8 @@

handler := httptransport.NewTransport( usersrv, notesrv, + cfg.AppEnv, + cfg.AppURL, cfg.CORSAllowedOrigins, cfg.CORSMaxAge, ratelimitCfg,
M internal/config/config.go

@@ -8,8 +8,14 @@ "strings"

"time" ) +type Environment string + +func (e Environment) IsDevMode() bool { + return e == "debug" || e == "test" +} + type Config struct { - AppEnv string + AppEnv Environment AppURL string NatsURL string

@@ -61,7 +67,7 @@ }

func NewConfig() *Config { return &Config{ - AppEnv: getenvOrDefault("APP_ENV", "debug"), + AppEnv: Environment(getenvOrDefault("APP_ENV", "debug")), AppURL: getenvOrDefault("APP_URL", ""), NatsURL: getenvOrDefault("NATS_URL", ""),

@@ -114,10 +120,6 @@ RateLimiterRPS: mustGetenvOrDefaultInt("RATELIMITER_RPS", 100),

RateLimiterBurst: mustGetenvOrDefaultInt("RATELIMITER_BURST", 10), RateLimiterTTL: mustParseDuration(getenvOrDefault("RATELIMITER_TTL", "1m")), } -} - -func (c *Config) IsDevMode() bool { - return c.AppEnv == "debug" || c.AppEnv == "test" } func getenvOrDefault(key, def string) string {
M internal/dtos/user.go

@@ -30,6 +30,11 @@ Token string

NewPassword string } +type OAuthRedirect struct { + URL string + State string +} + type Tokens struct { Access string Refresh string
M internal/service/usersrv/oauth.go

@@ -19,14 +19,22 @@ googleProvider = "google"

githubProvider = "github" ) -func (u *UserSrv) GetOAuthURL(providerName string) (string, error) { +func (u *UserSrv) GetOAuthURL(providerName string) (dtos.OAuthRedirect, error) { + state := uuid.Must(uuid.NewV4()).String() + switch providerName { case googleProvider: - return u.googleOauth.GetAuthURL(""), nil + return dtos.OAuthRedirect{ + URL: u.googleOauth.GetAuthURL(state), + State: state, + }, nil case githubProvider: - return u.githubOauth.GetAuthURL(""), nil + return dtos.OAuthRedirect{ + URL: u.githubOauth.GetAuthURL(state), + State: state, + }, nil default: - return "", ErrProviderNotSupported + return dtos.OAuthRedirect{}, ErrProviderNotSupported } }
M internal/service/usersrv/usersrv.go

@@ -30,7 +30,7 @@ ChangePassword(ctx context.Context, userID uuid.UUID, inp dtos.ChangeUserPassword) error

RequestPasswordReset(ctx context.Context, inp dtos.RequestResetPassword) error ResetPassword(ctx context.Context, inp dtos.ResetPassword) error - GetOAuthURL(providerName string) (string, error) + GetOAuthURL(providerName string) (dtos.OAuthRedirect, error) HandleOAuthLogin(ctx context.Context, providerName, code string) (dtos.Tokens, error) Verify(ctx context.Context, verificationKey string) error
M internal/store/psql/userepo/userepo.go

@@ -157,8 +157,12 @@ userID uuid.UUID,

provider, providerID string, ) error { query := `--sql - insert into oauth_identities (user_id, provider, provider_id) - values ($1, $2, $3)` +insert into oauth_identities (user_id, provider, provider_id) +values ($1, $2, $3) +on conflict (provider, provider_id) do update +set user_id = $1, + provider = $2, + provider_id = $3` _, err := r.db.Exec(ctx, query, userID, provider, providerID) return err
M internal/transport/http/apiv1/apiv1.go

@@ -2,6 +2,7 @@ package apiv1

import ( "github.com/gin-gonic/gin" + "github.com/olexsmir/onasty/internal/config" "github.com/olexsmir/onasty/internal/service/notesrv" "github.com/olexsmir/onasty/internal/service/usersrv" )

@@ -9,15 +10,21 @@

type APIV1 struct { usersrv usersrv.UserServicer notesrv notesrv.NoteServicer + env config.Environment + domain string } func NewAPIV1( us usersrv.UserServicer, ns notesrv.NoteServicer, + env config.Environment, + domain string, ) *APIV1 { return &APIV1{ usersrv: us, notesrv: ns, + env: env, + domain: domain, } }
M internal/transport/http/apiv1/auth.go

@@ -198,17 +198,36 @@

c.Status(http.StatusOK) } +const oatuhStateCookie = "oauth_state" + func (a *APIV1) oauthLoginHandler(c *gin.Context) { - url, err := a.usersrv.GetOAuthURL(c.Param("provider")) + redirectInfo, err := a.usersrv.GetOAuthURL(c.Param("provider")) if err != nil { errorResponse(c, err) return } - c.Redirect(http.StatusSeeOther, url) + c.SetCookie( + oatuhStateCookie, + redirectInfo.State, + int(time.Minute.Seconds()), + "/", + a.domain, + !a.env.IsDevMode(), + true, + ) + + c.Redirect(http.StatusSeeOther, redirectInfo.URL) } func (a *APIV1) oauthCallbackHandler(c *gin.Context) { + state := c.Query("state") + storedState, err := c.Cookie(oatuhStateCookie) + if err != nil || state != storedState { + newError(c, http.StatusBadRequest, "invalid oauth state") + return + } + tokens, err := a.usersrv.HandleOAuthLogin( c.Request.Context(), c.Param("provider"),
M internal/transport/http/http.go

@@ -5,6 +5,7 @@ "net/http"

"time" "github.com/gin-gonic/gin" + "github.com/olexsmir/onasty/internal/config" "github.com/olexsmir/onasty/internal/service/notesrv" "github.com/olexsmir/onasty/internal/service/usersrv" "github.com/olexsmir/onasty/internal/transport/http/apiv1"

@@ -16,6 +17,9 @@ type Transport struct {

usersrv usersrv.UserServicer notesrv notesrv.NoteServicer + env config.Environment + domain string + corsAllowedOrigins []string corsMaxAge time.Duration ratelimitCfg ratelimit.Config

@@ -24,6 +28,8 @@

func NewTransport( us usersrv.UserServicer, ns notesrv.NoteServicer, + env config.Environment, + domain string, corsAllowedOrigins []string, corsMaxAge time.Duration, ratelimitCfg ratelimit.Config,

@@ -31,6 +37,8 @@ ) *Transport {

return &Transport{ usersrv: us, notesrv: ns, + env: env, + domain: domain, corsAllowedOrigins: corsAllowedOrigins, corsMaxAge: corsMaxAge, ratelimitCfg: ratelimitCfg,

@@ -49,7 +57,7 @@ )

api := r.Group("/api") api.GET("/ping", t.pingHandler) - apiv1.NewAPIV1(t.usersrv, t.notesrv).Routes(api.Group("/v1")) + apiv1.NewAPIV1(t.usersrv, t.notesrv, t.env, t.domain).Routes(api.Group("/v1")) return r.Handler() }