10 files changed,
74 insertions(+),
17 deletions(-)
Author:
Smirnov Oleksandr
ss2316544@gmail.com
Committed by:
GitHub
noreply@github.com
Committed at:
2025-06-07 16:45:46 +0300
Parent:
9f4dd4b
jump to
M
cmd/api/main.go
@@ -57,7 +57,7 @@
slog.SetDefault(logger) // semi dev mode - if !cfg.IsDevMode() { + if !cfg.AppEnv.IsDevMode() { gin.SetMode(gin.ReleaseMode) }@@ -129,6 +129,8 @@
handler := httptransport.NewTransport( usersrv, notesrv, + cfg.AppEnv, + cfg.AppURL, cfg.CORSAllowedOrigins, cfg.CORSMaxAge, rateLimiterConfig,
M
e2e/e2e_test.go
@@ -140,6 +140,8 @@
handler := httptransport.NewTransport( usersrv, notesrv, + cfg.AppEnv, + cfg.AppURL, cfg.CORSAllowedOrigins, cfg.CORSMaxAge, ratelimitCfg,
M
internal/config/config.go
@@ -8,8 +8,14 @@ "strings"
"time" ) +type Environment string + +func (e Environment) IsDevMode() bool { + return e == "debug" || e == "test" +} + type Config struct { - AppEnv string + AppEnv Environment AppURL string NatsURL string@@ -61,7 +67,7 @@ }
func NewConfig() *Config { return &Config{ - AppEnv: getenvOrDefault("APP_ENV", "debug"), + AppEnv: Environment(getenvOrDefault("APP_ENV", "debug")), AppURL: getenvOrDefault("APP_URL", ""), NatsURL: getenvOrDefault("NATS_URL", ""),@@ -114,10 +120,6 @@ RateLimiterRPS: mustGetenvOrDefaultInt("RATELIMITER_RPS", 100),
RateLimiterBurst: mustGetenvOrDefaultInt("RATELIMITER_BURST", 10), RateLimiterTTL: mustParseDuration(getenvOrDefault("RATELIMITER_TTL", "1m")), } -} - -func (c *Config) IsDevMode() bool { - return c.AppEnv == "debug" || c.AppEnv == "test" } func getenvOrDefault(key, def string) string {
M
internal/dtos/user.go
@@ -30,6 +30,11 @@ Token string
NewPassword string } +type OAuthRedirect struct { + URL string + State string +} + type Tokens struct { Access string Refresh string
M
internal/service/usersrv/oauth.go
@@ -19,14 +19,22 @@ googleProvider = "google"
githubProvider = "github" ) -func (u *UserSrv) GetOAuthURL(providerName string) (string, error) { +func (u *UserSrv) GetOAuthURL(providerName string) (dtos.OAuthRedirect, error) { + state := uuid.Must(uuid.NewV4()).String() + switch providerName { case googleProvider: - return u.googleOauth.GetAuthURL(""), nil + return dtos.OAuthRedirect{ + URL: u.googleOauth.GetAuthURL(state), + State: state, + }, nil case githubProvider: - return u.githubOauth.GetAuthURL(""), nil + return dtos.OAuthRedirect{ + URL: u.githubOauth.GetAuthURL(state), + State: state, + }, nil default: - return "", ErrProviderNotSupported + return dtos.OAuthRedirect{}, ErrProviderNotSupported } }
M
internal/service/usersrv/usersrv.go
@@ -30,7 +30,7 @@ ChangePassword(ctx context.Context, userID uuid.UUID, inp dtos.ChangeUserPassword) error
RequestPasswordReset(ctx context.Context, inp dtos.RequestResetPassword) error ResetPassword(ctx context.Context, inp dtos.ResetPassword) error - GetOAuthURL(providerName string) (string, error) + GetOAuthURL(providerName string) (dtos.OAuthRedirect, error) HandleOAuthLogin(ctx context.Context, providerName, code string) (dtos.Tokens, error) Verify(ctx context.Context, verificationKey string) error
M
internal/store/psql/userepo/userepo.go
@@ -157,8 +157,12 @@ userID uuid.UUID,
provider, providerID string, ) error { query := `--sql - insert into oauth_identities (user_id, provider, provider_id) - values ($1, $2, $3)` +insert into oauth_identities (user_id, provider, provider_id) +values ($1, $2, $3) +on conflict (provider, provider_id) do update +set user_id = $1, + provider = $2, + provider_id = $3` _, err := r.db.Exec(ctx, query, userID, provider, providerID) return err
M
internal/transport/http/apiv1/apiv1.go
@@ -2,6 +2,7 @@ package apiv1
import ( "github.com/gin-gonic/gin" + "github.com/olexsmir/onasty/internal/config" "github.com/olexsmir/onasty/internal/service/notesrv" "github.com/olexsmir/onasty/internal/service/usersrv" )@@ -9,15 +10,21 @@
type APIV1 struct { usersrv usersrv.UserServicer notesrv notesrv.NoteServicer + env config.Environment + domain string } func NewAPIV1( us usersrv.UserServicer, ns notesrv.NoteServicer, + env config.Environment, + domain string, ) *APIV1 { return &APIV1{ usersrv: us, notesrv: ns, + env: env, + domain: domain, } }
M
internal/transport/http/apiv1/auth.go
@@ -198,17 +198,36 @@
c.Status(http.StatusOK) } +const oatuhStateCookie = "oauth_state" + func (a *APIV1) oauthLoginHandler(c *gin.Context) { - url, err := a.usersrv.GetOAuthURL(c.Param("provider")) + redirectInfo, err := a.usersrv.GetOAuthURL(c.Param("provider")) if err != nil { errorResponse(c, err) return } - c.Redirect(http.StatusSeeOther, url) + c.SetCookie( + oatuhStateCookie, + redirectInfo.State, + int(time.Minute.Seconds()), + "/", + a.domain, + !a.env.IsDevMode(), + true, + ) + + c.Redirect(http.StatusSeeOther, redirectInfo.URL) } func (a *APIV1) oauthCallbackHandler(c *gin.Context) { + state := c.Query("state") + storedState, err := c.Cookie(oatuhStateCookie) + if err != nil || state != storedState { + newError(c, http.StatusBadRequest, "invalid oauth state") + return + } + tokens, err := a.usersrv.HandleOAuthLogin( c.Request.Context(), c.Param("provider"),
M
internal/transport/http/http.go
@@ -5,6 +5,7 @@ "net/http"
"time" "github.com/gin-gonic/gin" + "github.com/olexsmir/onasty/internal/config" "github.com/olexsmir/onasty/internal/service/notesrv" "github.com/olexsmir/onasty/internal/service/usersrv" "github.com/olexsmir/onasty/internal/transport/http/apiv1"@@ -16,6 +17,9 @@ type Transport struct {
usersrv usersrv.UserServicer notesrv notesrv.NoteServicer + env config.Environment + domain string + corsAllowedOrigins []string corsMaxAge time.Duration ratelimitCfg ratelimit.Config@@ -24,6 +28,8 @@
func NewTransport( us usersrv.UserServicer, ns notesrv.NoteServicer, + env config.Environment, + domain string, corsAllowedOrigins []string, corsMaxAge time.Duration, ratelimitCfg ratelimit.Config,@@ -31,6 +37,8 @@ ) *Transport {
return &Transport{ usersrv: us, notesrv: ns, + env: env, + domain: domain, corsAllowedOrigins: corsAllowedOrigins, corsMaxAge: corsMaxAge, ratelimitCfg: ratelimitCfg,@@ -49,7 +57,7 @@ )
api := r.Group("/api") api.GET("/ping", t.pingHandler) - apiv1.NewAPIV1(t.usersrv, t.notesrv).Routes(api.Group("/v1")) + apiv1.NewAPIV1(t.usersrv, t.notesrv, t.env, t.domain).Routes(api.Group("/v1")) return r.Handler() }