10 files changed,
74 insertions(+),
17 deletions(-)
Author:
Smirnov Oleksandr
ss2316544@gmail.com
Committed by:
GitHub
noreply@github.com
Committed at:
2025-06-07 16:45:46 +0300
Parent:
9f4dd4b
jump to
M
cmd/api/main.go
··· 57 57 slog.SetDefault(logger) 58 58 59 59 // semi dev mode 60 - if !cfg.IsDevMode() { 60 + if !cfg.AppEnv.IsDevMode() { 61 61 gin.SetMode(gin.ReleaseMode) 62 62 } 63 63 ··· 129 129 handler := httptransport.NewTransport( 130 130 usersrv, 131 131 notesrv, 132 + cfg.AppEnv, 133 + cfg.AppURL, 132 134 cfg.CORSAllowedOrigins, 133 135 cfg.CORSMaxAge, 134 136 rateLimiterConfig,
M
internal/config/config.go
··· 8 8 "time" 9 9 ) 10 10 11 +type Environment string 12 + 13 +func (e Environment) IsDevMode() bool { 14 + return e == "debug" || e == "test" 15 +} 16 + 11 17 type Config struct { 12 - AppEnv string 18 + AppEnv Environment 13 19 AppURL string 14 20 NatsURL string 15 21 ··· 61 67 62 68 func NewConfig() *Config { 63 69 return &Config{ 64 - AppEnv: getenvOrDefault("APP_ENV", "debug"), 70 + AppEnv: Environment(getenvOrDefault("APP_ENV", "debug")), 65 71 AppURL: getenvOrDefault("APP_URL", ""), 66 72 NatsURL: getenvOrDefault("NATS_URL", ""), 67 73 ··· 114 120 RateLimiterBurst: mustGetenvOrDefaultInt("RATELIMITER_BURST", 10), 115 121 RateLimiterTTL: mustParseDuration(getenvOrDefault("RATELIMITER_TTL", "1m")), 116 122 } 117 -} 118 - 119 -func (c *Config) IsDevMode() bool { 120 - return c.AppEnv == "debug" || c.AppEnv == "test" 121 123 } 122 124 123 125 func getenvOrDefault(key, def string) string {
M
internal/service/usersrv/oauth.go
··· 19 19 githubProvider = "github" 20 20 ) 21 21 22 -func (u *UserSrv) GetOAuthURL(providerName string) (string, error) { 22 +func (u *UserSrv) GetOAuthURL(providerName string) (dtos.OAuthRedirect, error) { 23 + state := uuid.Must(uuid.NewV4()).String() 24 + 23 25 switch providerName { 24 26 case googleProvider: 25 - return u.googleOauth.GetAuthURL(""), nil 27 + return dtos.OAuthRedirect{ 28 + URL: u.googleOauth.GetAuthURL(state), 29 + State: state, 30 + }, nil 26 31 case githubProvider: 27 - return u.githubOauth.GetAuthURL(""), nil 32 + return dtos.OAuthRedirect{ 33 + URL: u.githubOauth.GetAuthURL(state), 34 + State: state, 35 + }, nil 28 36 default: 29 - return "", ErrProviderNotSupported 37 + return dtos.OAuthRedirect{}, ErrProviderNotSupported 30 38 } 31 39 } 32 40
M
internal/service/usersrv/usersrv.go
··· 30 30 RequestPasswordReset(ctx context.Context, inp dtos.RequestResetPassword) error 31 31 ResetPassword(ctx context.Context, inp dtos.ResetPassword) error 32 32 33 - GetOAuthURL(providerName string) (string, error) 33 + GetOAuthURL(providerName string) (dtos.OAuthRedirect, error) 34 34 HandleOAuthLogin(ctx context.Context, providerName, code string) (dtos.Tokens, error) 35 35 36 36 Verify(ctx context.Context, verificationKey string) error
M
internal/store/psql/userepo/userepo.go
··· 157 157 provider, providerID string, 158 158 ) error { 159 159 query := `--sql 160 - insert into oauth_identities (user_id, provider, provider_id) 161 - values ($1, $2, $3)` 160 +insert into oauth_identities (user_id, provider, provider_id) 161 +values ($1, $2, $3) 162 +on conflict (provider, provider_id) do update 163 +set user_id = $1, 164 + provider = $2, 165 + provider_id = $3` 162 166 163 167 _, err := r.db.Exec(ctx, query, userID, provider, providerID) 164 168 return err
M
internal/transport/http/apiv1/apiv1.go
··· 2 2 3 3 import ( 4 4 "github.com/gin-gonic/gin" 5 + "github.com/olexsmir/onasty/internal/config" 5 6 "github.com/olexsmir/onasty/internal/service/notesrv" 6 7 "github.com/olexsmir/onasty/internal/service/usersrv" 7 8 ) ··· 9 10 type APIV1 struct { 10 11 usersrv usersrv.UserServicer 11 12 notesrv notesrv.NoteServicer 13 + env config.Environment 14 + domain string 12 15 } 13 16 14 17 func NewAPIV1( 15 18 us usersrv.UserServicer, 16 19 ns notesrv.NoteServicer, 20 + env config.Environment, 21 + domain string, 17 22 ) *APIV1 { 18 23 return &APIV1{ 19 24 usersrv: us, 20 25 notesrv: ns, 26 + env: env, 27 + domain: domain, 21 28 } 22 29 } 23 30
M
internal/transport/http/apiv1/auth.go
··· 198 198 c.Status(http.StatusOK) 199 199 } 200 200 201 +const oatuhStateCookie = "oauth_state" 202 + 201 203 func (a *APIV1) oauthLoginHandler(c *gin.Context) { 202 - url, err := a.usersrv.GetOAuthURL(c.Param("provider")) 204 + redirectInfo, err := a.usersrv.GetOAuthURL(c.Param("provider")) 203 205 if err != nil { 204 206 errorResponse(c, err) 205 207 return 206 208 } 207 209 208 - c.Redirect(http.StatusSeeOther, url) 210 + c.SetCookie( 211 + oatuhStateCookie, 212 + redirectInfo.State, 213 + int(time.Minute.Seconds()), 214 + "/", 215 + a.domain, 216 + !a.env.IsDevMode(), 217 + true, 218 + ) 219 + 220 + c.Redirect(http.StatusSeeOther, redirectInfo.URL) 209 221 } 210 222 211 223 func (a *APIV1) oauthCallbackHandler(c *gin.Context) { 224 + state := c.Query("state") 225 + storedState, err := c.Cookie(oatuhStateCookie) 226 + if err != nil || state != storedState { 227 + newError(c, http.StatusBadRequest, "invalid oauth state") 228 + return 229 + } 230 + 212 231 tokens, err := a.usersrv.HandleOAuthLogin( 213 232 c.Request.Context(), 214 233 c.Param("provider"),
M
internal/transport/http/http.go
··· 5 5 "time" 6 6 7 7 "github.com/gin-gonic/gin" 8 + "github.com/olexsmir/onasty/internal/config" 8 9 "github.com/olexsmir/onasty/internal/service/notesrv" 9 10 "github.com/olexsmir/onasty/internal/service/usersrv" 10 11 "github.com/olexsmir/onasty/internal/transport/http/apiv1" ··· 16 17 usersrv usersrv.UserServicer 17 18 notesrv notesrv.NoteServicer 18 19 20 + env config.Environment 21 + domain string 22 + 19 23 corsAllowedOrigins []string 20 24 corsMaxAge time.Duration 21 25 ratelimitCfg ratelimit.Config ··· 24 28 func NewTransport( 25 29 us usersrv.UserServicer, 26 30 ns notesrv.NoteServicer, 31 + env config.Environment, 32 + domain string, 27 33 corsAllowedOrigins []string, 28 34 corsMaxAge time.Duration, 29 35 ratelimitCfg ratelimit.Config, ··· 31 37 return &Transport{ 32 38 usersrv: us, 33 39 notesrv: ns, 40 + env: env, 41 + domain: domain, 34 42 corsAllowedOrigins: corsAllowedOrigins, 35 43 corsMaxAge: corsMaxAge, 36 44 ratelimitCfg: ratelimitCfg, ··· 49 57 50 58 api := r.Group("/api") 51 59 api.GET("/ping", t.pingHandler) 52 - apiv1.NewAPIV1(t.usersrv, t.notesrv).Routes(api.Group("/v1")) 60 + apiv1.NewAPIV1(t.usersrv, t.notesrv, t.env, t.domain).Routes(api.Group("/v1")) 53 61 54 62 return r.Handler() 55 63 }