25 files changed,
676 insertions(+),
34 deletions(-)
Author:
Smirnov Oleksandr
ss2316544@gmail.com
Committed by:
GitHub
noreply@github.com
Committed at:
2025-05-09 15:56:51 +0300
Parent:
040e383
jump to
M
.env.example
··· 16 16 JWT_ACCESS_TOKEN_TTL=30m 17 17 JWT_REFRESH_TOKEN_TTL=360d 18 18 19 +GOOGLE_CLIENTID= 20 +GOOGLE_SECRET= 21 +GOOGLE_REDIRECTURL=http://localhost:8000/api/v1/oauth/google/callback 22 + 23 +GITHUB_CLIENTID= 24 +GITHUB_SECRET= 25 +GITHUB_REDIRECTURL=http://localhost:8000/api/v1/oauth/github/callback 26 + 19 27 POSTGRES_USERNAME=onasty 20 28 POSTGRES_PASSWORD=qwerty 21 29 POSTGRES_HOST=postgres
M
cmd/server/main.go
··· 17 17 "github.com/olexsmir/onasty/internal/jwtutil" 18 18 "github.com/olexsmir/onasty/internal/logger" 19 19 "github.com/olexsmir/onasty/internal/metrics" 20 + "github.com/olexsmir/onasty/internal/oauth" 20 21 "github.com/olexsmir/onasty/internal/service/notesrv" 21 22 "github.com/olexsmir/onasty/internal/service/usersrv" 22 23 "github.com/olexsmir/onasty/internal/store/psql/noterepo" ··· 79 80 notePasswordHasher := hasher.NewSHA256Hasher(cfg.NotePasswordSalt) 80 81 jwtTokenizer := jwtutil.NewJWTUtil(cfg.JwtSigningKey, cfg.JwtAccessTokenTTL) 81 82 83 + googleOauth := oauth.NewGoogleProvider( 84 + cfg.GoogleClientID, 85 + cfg.GoogleSecret, 86 + cfg.GoogleRedirectURL, 87 + ) 88 + githubOauth := oauth.NewGithubProvider( 89 + cfg.GitHubClientID, 90 + cfg.GitHubSecret, 91 + cfg.GitHubRedirectURL, 92 + ) 93 + 82 94 mailermq := mailermq.New(nc) 83 95 84 96 sessionrepo := sessionrepo.New(psqlDB) ··· 94 106 jwtTokenizer, 95 107 mailermq, 96 108 usercache, 109 + googleOauth, 110 + githubOauth, 97 111 cfg.JwtRefreshTokenTTL, 98 112 cfg.VerificationTokenTTL, 99 113 ) ··· 115 129 ) 116 130 117 131 // http server 118 - srv := httpserver.NewServer(handler.Handler(), httpConfig(cfg.HTTPPort, cfg)) 132 + srv := httpserver.NewServer(handler.Handler(), httpserver.Config{ 133 + Port: cfg.HTTPPort, 134 + ReadTimeout: cfg.HTTPReadTimeout, 135 + WriteTimeout: cfg.HTTPWriteTimeout, 136 + MaxHeaderSizeMb: cfg.HTTPHeaderMaxSizeMb, 137 + }) 119 138 go func() { 120 139 slog.Info("starting http server", "port", cfg.HTTPPort) 121 140 if err := srv.Start(); !errors.Is(err, http.ErrServerClosed) { ··· 125 144 126 145 // metrics 127 146 if cfg.MetricsEnabled { 128 - mSrv := httpserver.NewServer(metrics.Handler(), httpConfig(cfg.MetricsPort, cfg)) 147 + mSrv := httpserver.NewDefaultServer(metrics.Handler(), cfg.MetricsPort) 129 148 go func() { 130 149 slog.Info("starting metrics server", "port", cfg.MetricsPort) 131 150 if err := mSrv.Start(); !errors.Is(err, http.ErrServerClosed) { ··· 153 172 154 173 return nil 155 174 } 156 - 157 -func httpConfig(port string, cfg *config.Config) httpserver.Config { 158 - return httpserver.Config{ 159 - Port: port, 160 - ReadTimeout: cfg.HTTPReadTimeout, 161 - WriteTimeout: cfg.HTTPWriteTimeout, 162 - MaxHeaderSizeMb: cfg.HTTPHeaderMaxSizeMb, 163 - } 164 -}
M
e2e/e2e_test.go
··· 104 104 sessionrepo := sessionrepo.New(e.postgresDB) 105 105 vertokrepo := vertokrepo.New(e.postgresDB) 106 106 107 + oauthProvider := newOauthProviderMock() 108 + 107 109 userepo := userepo.New(e.postgresDB) 108 110 usercache := usercache.New(e.redisDB, cfg.CacheUsersTTL) 109 111 usersrv := usersrv.New( ··· 114 116 e.jwtTokenizer, 115 117 newMailerMockService(), 116 118 usercache, 119 + oauthProvider, 120 + oauthProvider, 117 121 cfg.JwtRefreshTokenTTL, 118 122 cfg.VerificationTokenTTL, 119 123 )
A
e2e/oauth_provider_mock_test.go
··· 1 +package e2e_test 2 + 3 +import ( 4 + "context" 5 + 6 + "github.com/olexsmir/onasty/internal/oauth" 7 +) 8 + 9 +var _ oauth.Provider = (*oauthProviderMock)(nil) 10 + 11 +type oauthProviderMock struct{} 12 + 13 +func newOauthProviderMock() *oauthProviderMock { 14 + return &oauthProviderMock{} 15 +} 16 + 17 +func (o *oauthProviderMock) GetAuthURL(_ string) string { 18 + return "https://example.com/oauth/authorize" 19 +} 20 + 21 +func (o *oauthProviderMock) ExchangeCode(_ context.Context, _ string) (oauth.UserInfo, error) { 22 + return oauth.UserInfo{ 23 + Provider: "google", 24 + ProviderID: "1234567890", 25 + Email: "testing@mail.org", 26 + EmailVerified: false, 27 + }, nil 28 +}
M
go.mod
··· 17 17 github.com/testcontainers/testcontainers-go v0.36.0 18 18 github.com/testcontainers/testcontainers-go/modules/postgres v0.36.0 19 19 github.com/testcontainers/testcontainers-go/modules/redis v0.36.0 20 + golang.org/x/oauth2 v0.29.0 20 21 golang.org/x/time v0.11.0 21 22 ) 22 23 23 24 require ( 25 + cloud.google.com/go/compute/metadata v0.3.0 // indirect 24 26 dario.cat/mergo v1.0.1 // indirect 25 27 github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect 26 28 github.com/Microsoft/go-winio v0.6.2 // indirect
M
go.sum
··· 1 +cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= 2 +cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= 1 3 dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= 2 4 dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= 3 5 github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU= ··· 378 380 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= 379 381 golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= 380 382 golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= 383 +golang.org/x/oauth2 v0.29.0 h1:WdYw2tdTK1S8olAzWHdgeqfy+Mtm9XNhv/xJsY65d98= 384 +golang.org/x/oauth2 v0.29.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= 381 385 golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 382 386 golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 383 387 golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
M
internal/config/config.go
··· 12 12 AppURL string 13 13 NatsURL string 14 14 15 - HTTPPort string 15 + HTTPPort int 16 16 HTTPWriteTimeout time.Duration 17 17 HTTPReadTimeout time.Duration 18 18 HTTPHeaderMaxSizeMb int ··· 32 32 JwtAccessTokenTTL time.Duration 33 33 JwtRefreshTokenTTL time.Duration 34 34 35 + GoogleClientID string 36 + GoogleSecret string 37 + GoogleRedirectURL string 38 + 39 + GitHubClientID string 40 + GitHubSecret string 41 + GitHubRedirectURL string 42 + 35 43 VerificationTokenTTL time.Duration 36 44 37 45 MetricsEnabled bool 38 - MetricsPort string 46 + MetricsPort int 39 47 40 48 LogLevel string 41 49 LogFormat string ··· 52 60 AppURL: getenvOrDefault("APP_URL", ""), 53 61 NatsURL: getenvOrDefault("NATS_URL", ""), 54 62 55 - HTTPPort: getenvOrDefault("HTTP_PORT", "3000"), 63 + HTTPPort: mustGetenvOrDefaultInt("HTTP_PORT", 3000), 56 64 HTTPWriteTimeout: mustParseDuration(getenvOrDefault("HTTP_WRITE_TIMEOUT", "10s")), 57 65 HTTPReadTimeout: mustParseDuration(getenvOrDefault("HTTP_READ_TIMEOUT", "10s")), 58 66 HTTPHeaderMaxSizeMb: mustGetenvOrDefaultInt("HTTP_HEADER_MAX_SIZE_MB", 1), ··· 76 84 getenvOrDefault("JWT_REFRESH_TOKEN_TTL", "24h"), 77 85 ), 78 86 87 + GoogleClientID: getenvOrDefault("GOOGLE_CLIENTID", ""), 88 + GoogleSecret: getenvOrDefault("GOOGLE_SECRET", ""), 89 + GoogleRedirectURL: getenvOrDefault("GOOGLE_REDIRECTURL", ""), 90 + 91 + GitHubClientID: getenvOrDefault("GITHUB_CLIENTID", ""), 92 + GitHubSecret: getenvOrDefault("GITHUB_SECRET", ""), 93 + GitHubRedirectURL: getenvOrDefault("GITHUB_REDIRECTURL", ""), 94 + 79 95 VerificationTokenTTL: mustParseDuration( 80 96 getenvOrDefault("VERIFICATION_TOKEN_TTL", "24h"), 81 97 ), 82 98 83 - MetricsPort: getenvOrDefault("METRICS_PORT", "3001"), 99 + MetricsPort: mustGetenvOrDefaultInt("METRICS_PORT", 3001), 84 100 MetricsEnabled: getenvOrDefault("METRICS_ENABLED", "true") == "true", 85 101 86 102 LogLevel: getenvOrDefault("LOG_LEVEL", "debug"),
A
internal/oauth/github.go
··· 1 +package oauth 2 + 3 +import ( 4 + "bytes" 5 + "context" 6 + "encoding/json" 7 + "io" 8 + "net/http" 9 + "strconv" 10 + 11 + "golang.org/x/oauth2" 12 + "golang.org/x/oauth2/github" 13 +) 14 + 15 +var _ Provider = (*GitHubProvider)(nil) 16 + 17 +const githubUserInfoEndpoint = "https://api.github.com/user" 18 + 19 +type GitHubProvider struct { 20 + config oauth2.Config 21 +} 22 + 23 +func NewGithubProvider(clientID, secret, redirectURL string) GitHubProvider { 24 + return GitHubProvider{ 25 + config: oauth2.Config{ 26 + ClientID: clientID, 27 + ClientSecret: secret, 28 + RedirectURL: redirectURL, 29 + Endpoint: github.Endpoint, 30 + Scopes: []string{ 31 + "user:email", 32 + }, 33 + }, 34 + } 35 +} 36 + 37 +func (g GitHubProvider) GetAuthURL(state string) string { 38 + return g.config.AuthCodeURL(state) 39 +} 40 + 41 +func (g GitHubProvider) ExchangeCode(ctx context.Context, code string) (UserInfo, error) { 42 + tok, err := g.config.Exchange(ctx, code) 43 + if err != nil { 44 + return UserInfo{}, err 45 + } 46 + 47 + client := g.config.Client(ctx, tok) 48 + req, err := http.NewRequestWithContext(ctx, http.MethodGet, githubUserInfoEndpoint, nil) 49 + if err != nil { 50 + return UserInfo{}, err 51 + } 52 + 53 + resp, err := client.Do(req) 54 + if err != nil { 55 + return UserInfo{}, err 56 + } 57 + 58 + defer resp.Body.Close() 59 + 60 + b, err := io.ReadAll(resp.Body) 61 + if err != nil { 62 + return UserInfo{}, err 63 + } 64 + 65 + var data struct { 66 + ID int `json:"id"` 67 + Email string `json:"email"` 68 + } 69 + 70 + if err := json.NewDecoder(bytes.NewReader(b)).Decode(&data); err != nil { 71 + return UserInfo{}, err 72 + } 73 + 74 + return UserInfo{ 75 + Provider: "github", 76 + ProviderID: strconv.Itoa(data.ID), 77 + Email: data.Email, 78 + EmailVerified: true, 79 + }, nil 80 +}
A
internal/oauth/github_test.go
··· 1 +package oauth 2 + 3 +import ( 4 + "context" 5 + "errors" 6 + "fmt" 7 + "io" 8 + "net/http" 9 + "strings" 10 + "testing" 11 + 12 + "github.com/stretchr/testify/assert" 13 + "github.com/stretchr/testify/require" 14 + "golang.org/x/oauth2" 15 +) 16 + 17 +func TestGitHubProvider_GetAuthURL(t *testing.T) { 18 + provider := NewGithubProvider("client.id", "secret", "http://localhost/callback") 19 + url := provider.GetAuthURL("test") 20 + 21 + assert.Contains(t, url, "client_id=client.id") 22 + assert.Contains(t, url, "state=test") 23 + assert.Contains(t, url, "scope=user%3Aemail") 24 +} 25 + 26 +type mockClient func(*http.Request) (*http.Response, error) 27 + 28 +func (m mockClient) RoundTrip(req *http.Request) (*http.Response, error) { 29 + return m(req) 30 +} 31 + 32 +func TestGitHubProvider_ExchangeCode(t *testing.T) { 33 + userID := "123123" 34 + userEmail := "test@testing.org" 35 + userLogin := "testing" 36 + 37 + resp := fmt.Sprintf(`{"id":%s, "email":"%s", "login":"%s"}`, userID, userEmail, userLogin) 38 + client := &http.Client{ //nolint:exhaustruct 39 + Transport: mockClient(func(req *http.Request) (*http.Response, error) { 40 + if req.Method == http.MethodPost { 41 + return &http.Response{ //nolint:exhaustruct 42 + StatusCode: http.StatusOK, 43 + Header: http.Header{"Content-Type": []string{"application/json"}}, 44 + Body: io.NopCloser( 45 + strings.NewReader(`{"access_token":"fake", 46 + "token_type":"bearer", 47 + "expires_in":3600}`), 48 + ), 49 + }, nil 50 + } 51 + return &http.Response{ //nolint:exhaustruct 52 + StatusCode: http.StatusOK, 53 + Header: http.Header{"Content-Type": []string{"application/json"}}, 54 + Body: io.NopCloser(strings.NewReader(resp)), 55 + }, nil 56 + }), 57 + } 58 + 59 + provider := NewGithubProvider("client.id", "secret", "http://localhost") 60 + ctx := context.WithValue(context.TODO(), oauth2.HTTPClient, client) 61 + 62 + info, err := provider.ExchangeCode(ctx, "") 63 + require.NoError(t, err) 64 + assert.Equal(t, "github", info.Provider) 65 + assert.Equal(t, userID, info.ProviderID) 66 + assert.Equal(t, userEmail, info.Email) 67 +} 68 + 69 +func TestGitHubProvider_ExchangeCode_tokenExcahnge_error(t *testing.T) { 70 + client := &http.Client{ //nolint:exhaustruct 71 + Transport: mockClient(func(req *http.Request) (*http.Response, error) { 72 + if req.Method == http.MethodPost { 73 + return &http.Response{ //nolint:exhaustruct 74 + StatusCode: http.StatusBadRequest, 75 + Body: io.NopCloser(strings.NewReader("")), 76 + }, nil 77 + } 78 + return nil, errors.New("unexpected request") 79 + }), 80 + } 81 + 82 + provider := NewGithubProvider("client.id", "secret", "http://localhost") 83 + ctx := context.WithValue(context.TODO(), oauth2.HTTPClient, client) 84 + 85 + _, err := provider.ExchangeCode(ctx, "") 86 + require.Error(t, err) 87 +}
A
internal/oauth/google.go
··· 1 +package oauth 2 + 3 +import ( 4 + "bytes" 5 + "context" 6 + "encoding/json" 7 + "io" 8 + "net/http" 9 + 10 + "golang.org/x/oauth2" 11 + "golang.org/x/oauth2/google" 12 +) 13 + 14 +var _ Provider = (*GoogleProvider)(nil) 15 + 16 +const googleUserInfoEndpoint = "https://www.googleapis.com/oauth2/v3/userinfo" 17 + 18 +type GoogleProvider struct { 19 + config oauth2.Config 20 +} 21 + 22 +func NewGoogleProvider(clientID, secret, redirectURL string) GoogleProvider { 23 + return GoogleProvider{ 24 + config: oauth2.Config{ 25 + ClientID: clientID, 26 + ClientSecret: secret, 27 + RedirectURL: redirectURL, 28 + Endpoint: google.Endpoint, 29 + Scopes: []string{ 30 + "https://www.googleapis.com/auth/userinfo.email", 31 + }, 32 + }, 33 + } 34 +} 35 + 36 +func (g GoogleProvider) GetAuthURL(state string) string { 37 + return g.config.AuthCodeURL(state) 38 +} 39 + 40 +func (g GoogleProvider) ExchangeCode(ctx context.Context, code string) (UserInfo, error) { 41 + tok, err := g.config.Exchange(ctx, code) 42 + if err != nil { 43 + return UserInfo{}, err 44 + } 45 + 46 + client := g.config.Client(ctx, tok) 47 + req, err := http.NewRequestWithContext(ctx, http.MethodGet, googleUserInfoEndpoint, nil) 48 + if err != nil { 49 + return UserInfo{}, err 50 + } 51 + 52 + resp, err := client.Do(req) 53 + if err != nil { 54 + return UserInfo{}, err 55 + } 56 + 57 + defer resp.Body.Close() 58 + 59 + b, err := io.ReadAll(resp.Body) 60 + if err != nil { 61 + return UserInfo{}, err 62 + } 63 + 64 + var data struct { 65 + Sub string `json:"sub"` 66 + Email string `json:"email"` 67 + EmailVerified bool `json:"email_verified"` 68 + } 69 + 70 + if err := json.NewDecoder(bytes.NewReader(b)).Decode(&data); err != nil { 71 + return UserInfo{}, err 72 + } 73 + 74 + return UserInfo{ 75 + Provider: "google", 76 + ProviderID: data.Sub, 77 + Email: data.Email, 78 + EmailVerified: data.EmailVerified, 79 + }, nil 80 +}
A
internal/oauth/google_test.go
··· 1 +package oauth 2 + 3 +import ( 4 + "context" 5 + "fmt" 6 + "io" 7 + "net/http" 8 + "net/url" 9 + "strings" 10 + "testing" 11 + 12 + "github.com/stretchr/testify/assert" 13 + "github.com/stretchr/testify/require" 14 + "golang.org/x/oauth2" 15 +) 16 + 17 +func TestGoogleProvider_GetAuthURL(t *testing.T) { 18 + provider := NewGoogleProvider("client.id", "secret", "http://localhost/callback") 19 + authURL := provider.GetAuthURL("test") 20 + 21 + assert.Contains(t, authURL, "client_id=client.id") 22 + assert.Contains(t, authURL, "state=test") 23 + assert.Contains(t, authURL, "scope="+ 24 + url.QueryEscape("https://www.googleapis.com/auth/userinfo.email")) 25 +} 26 + 27 +func TestGoogleProvider_ExchangeCode(t *testing.T) { 28 + sub := "1234567890" 29 + email := "testemail@mail.com" 30 + resp := fmt.Sprintf(`{"sub":"%s", "email":"%s","email_verified":true}`, sub, email) 31 + client := &http.Client{ //nolint:exhaustruct 32 + Transport: mockClient(func(req *http.Request) (*http.Response, error) { 33 + if req.Method == http.MethodPost { 34 + return &http.Response{ //nolint:exhaustruct 35 + StatusCode: http.StatusOK, 36 + Header: http.Header{"Content-Type": []string{"application/json"}}, 37 + Body: io.NopCloser( 38 + strings.NewReader(`{"access_token":"fake", 39 + "token_type":"bearer", 40 + "expires_in":3600}`), 41 + ), 42 + }, nil 43 + } 44 + return &http.Response{ //nolint:exhaustruct 45 + StatusCode: http.StatusOK, 46 + Header: http.Header{"Content-Type": []string{"application/json"}}, 47 + Body: io.NopCloser(strings.NewReader(resp)), 48 + }, nil 49 + }), 50 + } 51 + 52 + provider := NewGoogleProvider("client.id", "secret", "http://localhost") 53 + ctx := context.WithValue(context.TODO(), oauth2.HTTPClient, client) 54 + 55 + info, err := provider.ExchangeCode(ctx, "") 56 + require.NoError(t, err) 57 + assert.Equal(t, "google", info.Provider) 58 + assert.Equal(t, sub, info.ProviderID) 59 + assert.Equal(t, email, info.Email) 60 + assert.True(t, info.EmailVerified) 61 +}
A
internal/oauth/oauth.go
··· 1 +package oauth 2 + 3 +import "context" 4 + 5 +// Provider is an OAuth interface. 6 +type Provider interface { 7 + // GetAuthURL return the provider's authorization page URL. 8 + GetAuthURL(state string) string 9 + 10 + // ExchangeCode exchanges the provided authorization code for user information. 11 + ExchangeCode(ctx context.Context, code string) (UserInfo, error) 12 +} 13 + 14 +// UserInfo represents the user information returned by the OAuth provider. 15 +type UserInfo struct { 16 + // Provider is the name of the OAuth provider 17 + Provider string 18 + // ProviderID is user ID assigned by the provider 19 + ProviderID string 20 + // Email is user's email address returned by the provider 21 + Email string 22 + // EmailVerified indicates whether the email was verified by the provider 23 + EmailVerified bool 24 +}
A
internal/service/usersrv/oauth.go
··· 1 +package usersrv 2 + 3 +import ( 4 + "context" 5 + "errors" 6 + "log/slog" 7 + "strings" 8 + "time" 9 + 10 + "github.com/gofrs/uuid/v5" 11 + "github.com/olexsmir/onasty/internal/dtos" 12 + "github.com/olexsmir/onasty/internal/models" 13 + "github.com/olexsmir/onasty/internal/oauth" 14 +) 15 + 16 +var ErrProviderNotSupported = errors.New("oauth2 provider not supported") 17 + 18 +const ( 19 + googleProvider = "google" 20 + githubProvider = "github" 21 +) 22 + 23 +func (u *UserSrv) GetOAuthURL(providerName string) (string, error) { 24 + switch providerName { 25 + case googleProvider: 26 + return u.googleOauth.GetAuthURL(""), nil 27 + case githubProvider: 28 + return u.githubOauth.GetAuthURL(""), nil 29 + default: 30 + return "", ErrProviderNotSupported 31 + } 32 +} 33 + 34 +func (u *UserSrv) HandleOAuthLogin( 35 + ctx context.Context, 36 + providerName, code string, 37 +) (dtos.Tokens, error) { 38 + userInfo, err := u.getUserInfoBasedOnProvider(ctx, providerName, code) 39 + if err != nil { 40 + return dtos.Tokens{}, err 41 + } 42 + 43 + userID, err := u.getUserByOAuthIDOrCreateOne(ctx, userInfo) 44 + if err != nil { 45 + return dtos.Tokens{}, err 46 + } 47 + 48 + if err = u.userstore.LinkOAuthIdentity(ctx, userID, userInfo.Provider, userInfo.ProviderID); err != nil { 49 + slog.ErrorContext(ctx, "failed to link user identity", "user_id", userID, "err", err) 50 + return dtos.Tokens{}, err 51 + } 52 + 53 + tokens, err := u.issueTokens(ctx, userID) 54 + 55 + return tokens, err 56 +} 57 + 58 +func (u *UserSrv) getUserInfoBasedOnProvider( 59 + ctx context.Context, 60 + providerName, code string, 61 +) (oauth.UserInfo, error) { 62 + var userInfo oauth.UserInfo 63 + var err error 64 + 65 + switch providerName { 66 + case googleProvider: 67 + userInfo, err = u.googleOauth.ExchangeCode(ctx, code) 68 + case githubProvider: 69 + userInfo, err = u.githubOauth.ExchangeCode(ctx, code) 70 + default: 71 + return oauth.UserInfo{}, ErrProviderNotSupported 72 + } 73 + 74 + return userInfo, err 75 +} 76 + 77 +func getUsernameFromEmail(email string) string { 78 + p := strings.Split(email, "@") 79 + return p[0] 80 +} 81 + 82 +func (u *UserSrv) getUserByOAuthIDOrCreateOne( 83 + ctx context.Context, 84 + info oauth.UserInfo, 85 +) (uuid.UUID, error) { 86 + user, err := u.userstore.GetByOAuthID(ctx, info.Provider, info.ProviderID) 87 + if err != nil { 88 + if errors.Is(err, models.ErrUserNotFound) { 89 + uid, cerr := u.userstore.Create(ctx, models.User{ 90 + ID: uuid.Nil, 91 + Username: getUsernameFromEmail(info.Email), 92 + Email: info.Email, 93 + Activated: true, 94 + Password: "", 95 + CreatedAt: time.Now(), 96 + LastLoginAt: time.Now(), 97 + }) 98 + return uid, cerr 99 + } 100 + return uuid.Nil, err 101 + } 102 + 103 + return user.ID, nil 104 +}
M
internal/service/usersrv/usersrv.go
··· 12 12 "github.com/olexsmir/onasty/internal/hasher" 13 13 "github.com/olexsmir/onasty/internal/jwtutil" 14 14 "github.com/olexsmir/onasty/internal/models" 15 + "github.com/olexsmir/onasty/internal/oauth" 15 16 "github.com/olexsmir/onasty/internal/store/psql/sessionrepo" 16 17 "github.com/olexsmir/onasty/internal/store/psql/userepo" 17 18 "github.com/olexsmir/onasty/internal/store/psql/vertokrepo" ··· 25 26 Logout(ctx context.Context, userID uuid.UUID) error 26 27 27 28 ChangePassword(ctx context.Context, userID uuid.UUID, inp dtos.ChangeUserPassword) error 29 + 30 + GetOAuthURL(providerName string) (string, error) 31 + HandleOAuthLogin(ctx context.Context, providerName, code string) (dtos.Tokens, error) 28 32 29 33 Verify(ctx context.Context, verificationKey string) error 30 34 ResendVerificationEmail(ctx context.Context, credentials dtos.SignIn) error ··· 45 49 jwtTokenizer jwtutil.JWTTokenizer 46 50 mailermq mailermq.Mailer 47 51 cache usercache.UserCacheer 52 + googleOauth oauth.Provider 53 + githubOauth oauth.Provider 48 54 49 55 refreshTokenTTL time.Duration 50 56 verificationTokenTTL time.Duration ··· 58 64 jwtTokenizer jwtutil.JWTTokenizer, 59 65 mailermq mailermq.Mailer, 60 66 cache usercache.UserCacheer, 67 + googleOauth, githubOauth oauth.Provider, 61 68 refreshTokenTTL, verificationTokenTTL time.Duration, 62 69 ) *UserSrv { 63 70 return &UserSrv{ ··· 68 75 jwtTokenizer: jwtTokenizer, 69 76 mailermq: mailermq, 70 77 cache: cache, 78 + googleOauth: googleOauth, 79 + githubOauth: githubOauth, 71 80 refreshTokenTTL: refreshTokenTTL, 72 81 verificationTokenTTL: verificationTokenTTL, 73 82 } ··· 135 144 return dtos.Tokens{}, models.ErrUserIsNotActivated 136 145 } 137 146 138 - tokens, err := u.createTokens(user.ID) 139 - if err != nil { 140 - return dtos.Tokens{}, err 141 - } 142 - 143 - if err := u.sessionstore.Set(ctx, user.ID, tokens.Refresh, time.Now().Add(u.refreshTokenTTL)); err != nil { 144 - return dtos.Tokens{}, err 145 - } 146 - 147 - return dtos.Tokens{ 148 - Access: tokens.Access, 149 - Refresh: tokens.Refresh, 150 - }, nil 147 + tokens, err := u.issueTokens(ctx, user.ID) 148 + return tokens, err 151 149 } 152 150 153 151 func (u *UserSrv) Logout(ctx context.Context, userID uuid.UUID) error { ··· 301 299 Refresh: refreshToken, 302 300 }, err 303 301 } 302 + 303 +func (u UserSrv) issueTokens(ctx context.Context, userID uuid.UUID) (dtos.Tokens, error) { 304 + toks, err := u.createTokens(userID) 305 + if err != nil { 306 + return dtos.Tokens{}, err 307 + } 308 + 309 + if err := u.sessionstore.Set(ctx, userID, toks.Refresh, time.Now().Add(u.refreshTokenTTL)); err != nil { 310 + return dtos.Tokens{}, err 311 + } 312 + 313 + return toks, nil 314 +}
M
internal/store/psql/userepo/userepo.go
··· 29 29 // password should be hashed 30 30 SetPassword(ctx context.Context, userID uuid.UUID, newPassword string) error 31 31 32 + GetByOAuthID(ctx context.Context, provider, providerID string) (models.User, error) 33 + LinkOAuthIdentity(ctx context.Context, userID uuid.UUID, provider, providerID string) error 34 + 32 35 CheckIfUserExists(ctx context.Context, userID uuid.UUID) (bool, error) 33 36 CheckIfUserIsActivated(ctx context.Context, userID uuid.UUID) (bool, error) 34 37 } ··· 111 114 } 112 115 113 116 return id, err 117 +} 118 + 119 +func (r *UserRepo) GetByOAuthID( 120 + ctx context.Context, 121 + provider, providerID string, 122 +) (models.User, error) { 123 + query := `--sql 124 + select u.id, u.username, u.email, u.password, u.activated, u.created_at, u.last_login_at 125 + from users u 126 + join oauth_identities oi on u.id = oi.user_id 127 + where oi.provider = $1 128 + and oi.provider_id = $2 129 + limit 1` 130 + 131 + var user models.User 132 + err := r.db.QueryRow(ctx, query, provider, providerID). 133 + Scan(&user.ID, &user.Username, &user.Email, &user.Password, &user.Activated, &user.CreatedAt, &user.LastLoginAt) 134 + if errors.Is(err, pgx.ErrNoRows) { 135 + return models.User{}, models.ErrUserNotFound 136 + } 137 + 138 + return user, err 139 +} 140 + 141 +func (r *UserRepo) LinkOAuthIdentity( 142 + ctx context.Context, 143 + userID uuid.UUID, 144 + provider, providerID string, 145 +) error { 146 + query := `--sql 147 + insert into oauth_identities (user_id, provider, provider_id) 148 + values ($1, $2, $3)` 149 + 150 + _, err := r.db.Exec(ctx, query, userID, provider, providerID) 151 + return err 114 152 } 115 153 116 154 func (r *UserRepo) MarkUserAsActivated(ctx context.Context, id uuid.UUID) error {
M
internal/transport/http/apiv1/apiv1.go
··· 36 36 authorized.POST("/logout", a.logOutHandler) 37 37 authorized.POST("/change-password", a.changePasswordHandler) 38 38 } 39 + 40 + oauth := r.Group("/oauth") 41 + { 42 + oauth.GET("/:provider", a.oauthLoginHandler) 43 + oauth.GET("/:provider/callback", a.oauthCallbackHandler) 44 + } 39 45 } 40 46 41 47 note := r.Group("/note", a.couldBeAuthorizedMiddleware)
M
internal/transport/http/apiv1/auth.go
··· 153 153 154 154 c.Status(http.StatusOK) 155 155 } 156 + 157 +func (a *APIV1) oauthLoginHandler(c *gin.Context) { 158 + url, err := a.usersrv.GetOAuthURL(c.Param("provider")) 159 + if err != nil { 160 + errorResponse(c, err) 161 + return 162 + } 163 + 164 + c.Redirect(http.StatusSeeOther, url) 165 +} 166 + 167 +func (a *APIV1) oauthCallbackHandler(c *gin.Context) { 168 + tokens, err := a.usersrv.HandleOAuthLogin( 169 + c.Request.Context(), 170 + c.Param("provider"), 171 + c.Query("code"), 172 + ) 173 + if err != nil { 174 + errorResponse(c, err) 175 + return 176 + } 177 + 178 + c.JSON(http.StatusOK, signInResponse{ 179 + AccessToken: tokens.Access, 180 + RefreshToken: tokens.Refresh, 181 + }) 182 +}
M
internal/transport/http/apiv1/response.go
··· 7 7 8 8 "github.com/gin-gonic/gin" 9 9 "github.com/olexsmir/onasty/internal/models" 10 + "github.com/olexsmir/onasty/internal/service/usersrv" 10 11 ) 11 12 12 13 var ErrUnauthorized = errors.New("unauthorized") ··· 16 17 } 17 18 18 19 func errorResponse(c *gin.Context, err error) { 19 - if errors.Is(err, models.ErrUserEmailIsAlreadyInUse) || 20 + if errors.Is(err, usersrv.ErrProviderNotSupported) || 21 + errors.Is(err, models.ErrUserEmailIsAlreadyInUse) || 20 22 errors.Is(err, models.ErrUsernameIsAlreadyInUse) || 21 23 errors.Is(err, models.ErrUserIsAlreadyVerified) || 22 24 errors.Is(err, models.ErrUserIsNotActivated) ||
M
internal/transport/http/httpserver/httpserver.go
··· 3 3 import ( 4 4 "context" 5 5 "net/http" 6 + "strconv" 6 7 "time" 7 8 ) 8 9 ··· 12 13 13 14 type Config struct { 14 15 // Port http server port 15 - Port string 16 + Port int 16 17 17 18 // ReadTimeout read timeout 18 19 ReadTimeout time.Duration ··· 25 26 } 26 27 27 28 func NewServer(handler http.Handler, cfg Config) *Server { 29 + p := strconv.Itoa(cfg.Port) 28 30 return &Server{ 29 31 http: &http.Server{ 30 - Addr: ":" + cfg.Port, 32 + Addr: ":" + p, 31 33 Handler: handler, 32 34 ReadTimeout: cfg.ReadTimeout, 33 35 WriteTimeout: cfg.WriteTimeout, 34 36 MaxHeaderBytes: cfg.MaxHeaderSizeMb << 20, 37 + }, 38 + } 39 +} 40 + 41 +// NewDefaultServer returns http server with default config 42 +func NewDefaultServer(handler http.Handler, port int) *Server { 43 + p := strconv.Itoa(port) 44 + return &Server{ 45 + http: &http.Server{ 46 + Addr: ":" + p, 47 + Handler: handler, 48 + ReadTimeout: 10 * time.Second, 49 + WriteTimeout: 10 * time.Second, 50 + MaxHeaderBytes: 1 << 20, 35 51 }, 36 52 } 37 53 }
M
mailer/config.go
··· 1 1 package main 2 2 3 -import "os" 3 +import ( 4 + "os" 5 + "strconv" 6 +) 4 7 5 8 type Config struct { 6 9 AppURL string ··· 14 17 LogShowLine bool 15 18 16 19 MetricsEnabled bool 17 - MetricsPort string 20 + MetricsPort int 18 21 } 19 22 20 23 func NewConfig() *Config { ··· 27 30 LogLevel: getenvOrDefault("LOG_LEVEL", "debug"), 28 31 LogFormat: getenvOrDefault("LOG_FORMAT", "json"), 29 32 LogShowLine: getenvOrDefault("LOG_SHOW_LINE", "true") == "true", 30 - MetricsPort: getenvOrDefault("METRICS_PORT", ""), 33 + MetricsPort: mustGetenvOrDefaultInt("METRICS_PORT", 8001), 31 34 MetricsEnabled: getenvOrDefault("METRICS_ENABLED", "true") == "true", 32 35 } 33 36 } ··· 38 41 } 39 42 return def 40 43 } 44 + 45 +func mustGetenvOrDefaultInt(key string, def int) int { 46 + if v, ok := os.LookupEnv(key); ok { 47 + r, err := strconv.Atoi(v) 48 + if err != nil { 49 + panic(err) 50 + } 51 + return r 52 + } 53 + return def 54 +}
A
migrations/20250503143510_user_oauth.down.sql
··· 1 +ALTER TABLE users 2 + ALTER COLUMN PASSWORD SET NOT NULL;
A
migrations/20250503143510_user_oauth.up.sql
··· 1 +ALTER TABLE users 2 + ALTER COLUMN PASSWORD DROP NOT NULL;
A
migrations/20250503144157_oauth_identities.down.sql
··· 1 +DROP TABLE oauth_identities; 2 + 3 +DROP TYPE provider_enum;
A
migrations/20250503144157_oauth_identities.up.sql
··· 1 +CREATE TYPE provider_enum AS ENUM ( 2 + 'google', 3 + 'github' 4 +); 5 + 6 +CREATE TABLE oauth_identities ( 7 + id uuid PRIMARY KEY DEFAULT uuid_generate_v4 (), 8 + user_id uuid REFERENCES users (id) ON DELETE CASCADE, 9 + provider provider_enum NOT NULL, 10 + provider_id varchar(50), 11 + created_at timestamptz NOT NULL DEFAULT now(), 12 + UNIQUE (provider, provider_id) 13 +);