all repos

onasty @ 51d3b53b0c8469cfd0ca12bb2a505d7fe77d228e

a one-time notes service
10 files changed, 74 insertions(+), 17 deletions(-)
fix: oauth state (#128)

* pass app env to handlers

* fix(oauth): set cookie with state on redirect

* feat(oauth): check oatuh state in callback

* fixup! pass app env to handlers

* fixup! fix(oauth): set cookie with state on redirect

* fix(userrepo): do update of user's identity if they were logged in before

* refactor: move oauth state cookie to sep var

* refactor: update age of the oauth cookie

* feat: set the domain to the oauth state cookie
Author: Smirnov Oleksandr ss2316544@gmail.com
Committed by: GitHub noreply@github.com
Committed at: 2025-06-07 16:45:46 +0300
Parent: 9f4dd4b
M cmd/api/main.go
···
        57
        57
         	slog.SetDefault(logger)

      
        58
        58
         

      
        59
        59
         	// semi dev mode

      
        60
        
        -	if !cfg.IsDevMode() {

      
        
        60
        +	if !cfg.AppEnv.IsDevMode() {

      
        61
        61
         		gin.SetMode(gin.ReleaseMode)

      
        62
        62
         	}

      
        63
        63
         

      ···
        129
        129
         	handler := httptransport.NewTransport(

      
        130
        130
         		usersrv,

      
        131
        131
         		notesrv,

      
        
        132
        +		cfg.AppEnv,

      
        
        133
        +		cfg.AppURL,

      
        132
        134
         		cfg.CORSAllowedOrigins,

      
        133
        135
         		cfg.CORSMaxAge,

      
        134
        136
         		rateLimiterConfig,

      
M e2e/e2e_test.go
···
        140
        140
         	handler := httptransport.NewTransport(

      
        141
        141
         		usersrv,

      
        142
        142
         		notesrv,

      
        
        143
        +		cfg.AppEnv,

      
        
        144
        +		cfg.AppURL,

      
        143
        145
         		cfg.CORSAllowedOrigins,

      
        144
        146
         		cfg.CORSMaxAge,

      
        145
        147
         		ratelimitCfg,

      
M internal/config/config.go
···
        8
        8
         	"time"

      
        9
        9
         )

      
        10
        10
         

      
        
        11
        +type Environment string

      
        
        12
        +

      
        
        13
        +func (e Environment) IsDevMode() bool {

      
        
        14
        +	return e == "debug" || e == "test"

      
        
        15
        +}

      
        
        16
        +

      
        11
        17
         type Config struct {

      
        12
        
        -	AppEnv  string

      
        
        18
        +	AppEnv  Environment

      
        13
        19
         	AppURL  string

      
        14
        20
         	NatsURL string

      
        15
        21
         

      ···
        61
        67
         

      
        62
        68
         func NewConfig() *Config {

      
        63
        69
         	return &Config{

      
        64
        
        -		AppEnv:  getenvOrDefault("APP_ENV", "debug"),

      
        
        70
        +		AppEnv:  Environment(getenvOrDefault("APP_ENV", "debug")),

      
        65
        71
         		AppURL:  getenvOrDefault("APP_URL", ""),

      
        66
        72
         		NatsURL: getenvOrDefault("NATS_URL", ""),

      
        67
        73
         

      ···
        114
        120
         		RateLimiterBurst: mustGetenvOrDefaultInt("RATELIMITER_BURST", 10),

      
        115
        121
         		RateLimiterTTL:   mustParseDuration(getenvOrDefault("RATELIMITER_TTL", "1m")),

      
        116
        122
         	}

      
        117
        
        -}

      
        118
        
        -

      
        119
        
        -func (c *Config) IsDevMode() bool {

      
        120
        
        -	return c.AppEnv == "debug" || c.AppEnv == "test"

      
        121
        123
         }

      
        122
        124
         

      
        123
        125
         func getenvOrDefault(key, def string) string {

      
M internal/dtos/user.go
···
        30
        30
         	NewPassword string

      
        31
        31
         }

      
        32
        32
         

      
        
        33
        +type OAuthRedirect struct {

      
        
        34
        +	URL   string

      
        
        35
        +	State string

      
        
        36
        +}

      
        
        37
        +

      
        33
        38
         type Tokens struct {

      
        34
        39
         	Access  string

      
        35
        40
         	Refresh string

      
M internal/service/usersrv/oauth.go
···
        19
        19
         	githubProvider = "github"

      
        20
        20
         )

      
        21
        21
         

      
        22
        
        -func (u *UserSrv) GetOAuthURL(providerName string) (string, error) {

      
        
        22
        +func (u *UserSrv) GetOAuthURL(providerName string) (dtos.OAuthRedirect, error) {

      
        
        23
        +	state := uuid.Must(uuid.NewV4()).String()

      
        
        24
        +

      
        23
        25
         	switch providerName {

      
        24
        26
         	case googleProvider:

      
        25
        
        -		return u.googleOauth.GetAuthURL(""), nil

      
        
        27
        +		return dtos.OAuthRedirect{

      
        
        28
        +			URL:   u.googleOauth.GetAuthURL(state),

      
        
        29
        +			State: state,

      
        
        30
        +		}, nil

      
        26
        31
         	case githubProvider:

      
        27
        
        -		return u.githubOauth.GetAuthURL(""), nil

      
        
        32
        +		return dtos.OAuthRedirect{

      
        
        33
        +			URL:   u.githubOauth.GetAuthURL(state),

      
        
        34
        +			State: state,

      
        
        35
        +		}, nil

      
        28
        36
         	default:

      
        29
        
        -		return "", ErrProviderNotSupported

      
        
        37
        +		return dtos.OAuthRedirect{}, ErrProviderNotSupported

      
        30
        38
         	}

      
        31
        39
         }

      
        32
        40
         

      
M internal/service/usersrv/usersrv.go
···
        30
        30
         	RequestPasswordReset(ctx context.Context, inp dtos.RequestResetPassword) error

      
        31
        31
         	ResetPassword(ctx context.Context, inp dtos.ResetPassword) error

      
        32
        32
         

      
        33
        
        -	GetOAuthURL(providerName string) (string, error)

      
        
        33
        +	GetOAuthURL(providerName string) (dtos.OAuthRedirect, error)

      
        34
        34
         	HandleOAuthLogin(ctx context.Context, providerName, code string) (dtos.Tokens, error)

      
        35
        35
         

      
        36
        36
         	Verify(ctx context.Context, verificationKey string) error

      
M internal/store/psql/userepo/userepo.go
···
        157
        157
         	provider, providerID string,

      
        158
        158
         ) error {

      
        159
        159
         	query := `--sql

      
        160
        
        -	insert into oauth_identities (user_id, provider, provider_id)

      
        161
        
        -	values ($1, $2, $3)`

      
        
        160
        +insert into oauth_identities (user_id, provider, provider_id)

      
        
        161
        +values ($1, $2, $3)

      
        
        162
        +on conflict (provider, provider_id) do update

      
        
        163
        +set user_id = $1,

      
        
        164
        +	provider = $2,

      
        
        165
        +	provider_id = $3`

      
        162
        166
         

      
        163
        167
         	_, err := r.db.Exec(ctx, query, userID, provider, providerID)

      
        164
        168
         	return err

      
M internal/transport/http/apiv1/apiv1.go
···
        2
        2
         

      
        3
        3
         import (

      
        4
        4
         	"github.com/gin-gonic/gin"

      
        
        5
        +	"github.com/olexsmir/onasty/internal/config"

      
        5
        6
         	"github.com/olexsmir/onasty/internal/service/notesrv"

      
        6
        7
         	"github.com/olexsmir/onasty/internal/service/usersrv"

      
        7
        8
         )

      ···
        9
        10
         type APIV1 struct {

      
        10
        11
         	usersrv usersrv.UserServicer

      
        11
        12
         	notesrv notesrv.NoteServicer

      
        
        13
        +	env     config.Environment

      
        
        14
        +	domain  string

      
        12
        15
         }

      
        13
        16
         

      
        14
        17
         func NewAPIV1(

      
        15
        18
         	us usersrv.UserServicer,

      
        16
        19
         	ns notesrv.NoteServicer,

      
        
        20
        +	env config.Environment,

      
        
        21
        +	domain string,

      
        17
        22
         ) *APIV1 {

      
        18
        23
         	return &APIV1{

      
        19
        24
         		usersrv: us,

      
        20
        25
         		notesrv: ns,

      
        
        26
        +		env:     env,

      
        
        27
        +		domain:  domain,

      
        21
        28
         	}

      
        22
        29
         }

      
        23
        30
         

      
M internal/transport/http/apiv1/auth.go
···
        198
        198
         	c.Status(http.StatusOK)

      
        199
        199
         }

      
        200
        200
         

      
        
        201
        +const oatuhStateCookie = "oauth_state"

      
        
        202
        +

      
        201
        203
         func (a *APIV1) oauthLoginHandler(c *gin.Context) {

      
        202
        
        -	url, err := a.usersrv.GetOAuthURL(c.Param("provider"))

      
        
        204
        +	redirectInfo, err := a.usersrv.GetOAuthURL(c.Param("provider"))

      
        203
        205
         	if err != nil {

      
        204
        206
         		errorResponse(c, err)

      
        205
        207
         		return

      
        206
        208
         	}

      
        207
        209
         

      
        208
        
        -	c.Redirect(http.StatusSeeOther, url)

      
        
        210
        +	c.SetCookie(

      
        
        211
        +		oatuhStateCookie,

      
        
        212
        +		redirectInfo.State,

      
        
        213
        +		int(time.Minute.Seconds()),

      
        
        214
        +		"/",

      
        
        215
        +		a.domain,

      
        
        216
        +		!a.env.IsDevMode(),

      
        
        217
        +		true,

      
        
        218
        +	)

      
        
        219
        +

      
        
        220
        +	c.Redirect(http.StatusSeeOther, redirectInfo.URL)

      
        209
        221
         }

      
        210
        222
         

      
        211
        223
         func (a *APIV1) oauthCallbackHandler(c *gin.Context) {

      
        
        224
        +	state := c.Query("state")

      
        
        225
        +	storedState, err := c.Cookie(oatuhStateCookie)

      
        
        226
        +	if err != nil || state != storedState {

      
        
        227
        +		newError(c, http.StatusBadRequest, "invalid oauth state")

      
        
        228
        +		return

      
        
        229
        +	}

      
        
        230
        +

      
        212
        231
         	tokens, err := a.usersrv.HandleOAuthLogin(

      
        213
        232
         		c.Request.Context(),

      
        214
        233
         		c.Param("provider"),

      
M internal/transport/http/http.go
···
        5
        5
         	"time"

      
        6
        6
         

      
        7
        7
         	"github.com/gin-gonic/gin"

      
        
        8
        +	"github.com/olexsmir/onasty/internal/config"

      
        8
        9
         	"github.com/olexsmir/onasty/internal/service/notesrv"

      
        9
        10
         	"github.com/olexsmir/onasty/internal/service/usersrv"

      
        10
        11
         	"github.com/olexsmir/onasty/internal/transport/http/apiv1"

      ···
        16
        17
         	usersrv usersrv.UserServicer

      
        17
        18
         	notesrv notesrv.NoteServicer

      
        18
        19
         

      
        
        20
        +	env    config.Environment

      
        
        21
        +	domain string

      
        
        22
        +

      
        19
        23
         	corsAllowedOrigins []string

      
        20
        24
         	corsMaxAge         time.Duration

      
        21
        25
         	ratelimitCfg       ratelimit.Config

      ···
        24
        28
         func NewTransport(

      
        25
        29
         	us usersrv.UserServicer,

      
        26
        30
         	ns notesrv.NoteServicer,

      
        
        31
        +	env config.Environment,

      
        
        32
        +	domain string,

      
        27
        33
         	corsAllowedOrigins []string,

      
        28
        34
         	corsMaxAge time.Duration,

      
        29
        35
         	ratelimitCfg ratelimit.Config,

      ···
        31
        37
         	return &Transport{

      
        32
        38
         		usersrv:            us,

      
        33
        39
         		notesrv:            ns,

      
        
        40
        +		env:                env,

      
        
        41
        +		domain:             domain,

      
        34
        42
         		corsAllowedOrigins: corsAllowedOrigins,

      
        35
        43
         		corsMaxAge:         corsMaxAge,

      
        36
        44
         		ratelimitCfg:       ratelimitCfg,

      ···
        49
        57
         

      
        50
        58
         	api := r.Group("/api")

      
        51
        59
         	api.GET("/ping", t.pingHandler)

      
        52
        
        -	apiv1.NewAPIV1(t.usersrv, t.notesrv).Routes(api.Group("/v1"))

      
        
        60
        +	apiv1.NewAPIV1(t.usersrv, t.notesrv, t.env, t.domain).Routes(api.Group("/v1"))

      
        53
        61
         

      
        54
        62
         	return r.Handler()

      
        55
        63
         }