@@ -16,6 +16,14 @@ JWT_SIGNING_KEY=supersecret
JWT_ACCESS_TOKEN_TTL=30m JWT_REFRESH_TOKEN_TTL=360d +GOOGLE_CLIENTID= +GOOGLE_SECRET= +GOOGLE_REDIRECTURL=http://localhost:8000/api/v1/oauth/google/callback + +GITHUB_CLIENTID= +GITHUB_SECRET= +GITHUB_REDIRECTURL=http://localhost:8000/api/v1/oauth/github/callback + POSTGRES_USERNAME=onasty POSTGRES_PASSWORD=qwerty POSTGRES_HOST=postgres
@@ -17,6 +17,7 @@ "github.com/olexsmir/onasty/internal/hasher"
"github.com/olexsmir/onasty/internal/jwtutil" "github.com/olexsmir/onasty/internal/logger" "github.com/olexsmir/onasty/internal/metrics" + "github.com/olexsmir/onasty/internal/oauth" "github.com/olexsmir/onasty/internal/service/notesrv" "github.com/olexsmir/onasty/internal/service/usersrv" "github.com/olexsmir/onasty/internal/store/psql/noterepo"@@ -79,6 +80,17 @@ userPasswordHasher := hasher.NewSHA256Hasher(cfg.PasswordSalt)
notePasswordHasher := hasher.NewSHA256Hasher(cfg.NotePasswordSalt) jwtTokenizer := jwtutil.NewJWTUtil(cfg.JwtSigningKey, cfg.JwtAccessTokenTTL) + googleOauth := oauth.NewGoogleProvider( + cfg.GoogleClientID, + cfg.GoogleSecret, + cfg.GoogleRedirectURL, + ) + githubOauth := oauth.NewGithubProvider( + cfg.GitHubClientID, + cfg.GitHubSecret, + cfg.GitHubRedirectURL, + ) + mailermq := mailermq.New(nc) sessionrepo := sessionrepo.New(psqlDB)@@ -94,6 +106,8 @@ userPasswordHasher,
jwtTokenizer, mailermq, usercache, + googleOauth, + githubOauth, cfg.JwtRefreshTokenTTL, cfg.VerificationTokenTTL, )@@ -115,7 +129,12 @@ rateLimiterConfig,
) // http server - srv := httpserver.NewServer(handler.Handler(), httpConfig(cfg.HTTPPort, cfg)) + srv := httpserver.NewServer(handler.Handler(), httpserver.Config{ + Port: cfg.HTTPPort, + ReadTimeout: cfg.HTTPReadTimeout, + WriteTimeout: cfg.HTTPWriteTimeout, + MaxHeaderSizeMb: cfg.HTTPHeaderMaxSizeMb, + }) go func() { slog.Info("starting http server", "port", cfg.HTTPPort) if err := srv.Start(); !errors.Is(err, http.ErrServerClosed) {@@ -125,7 +144,7 @@ }()
// metrics if cfg.MetricsEnabled { - mSrv := httpserver.NewServer(metrics.Handler(), httpConfig(cfg.MetricsPort, cfg)) + mSrv := httpserver.NewDefaultServer(metrics.Handler(), cfg.MetricsPort) go func() { slog.Info("starting metrics server", "port", cfg.MetricsPort) if err := mSrv.Start(); !errors.Is(err, http.ErrServerClosed) {@@ -153,12 +172,3 @@ }
return nil } - -func httpConfig(port string, cfg *config.Config) httpserver.Config { - return httpserver.Config{ - Port: port, - ReadTimeout: cfg.HTTPReadTimeout, - WriteTimeout: cfg.HTTPWriteTimeout, - MaxHeaderSizeMb: cfg.HTTPHeaderMaxSizeMb, - } -}
@@ -104,6 +104,8 @@
sessionrepo := sessionrepo.New(e.postgresDB) vertokrepo := vertokrepo.New(e.postgresDB) + oauthProvider := newOauthProviderMock() + userepo := userepo.New(e.postgresDB) usercache := usercache.New(e.redisDB, cfg.CacheUsersTTL) usersrv := usersrv.New(@@ -114,6 +116,8 @@ e.hasher,
e.jwtTokenizer, newMailerMockService(), usercache, + oauthProvider, + oauthProvider, cfg.JwtRefreshTokenTTL, cfg.VerificationTokenTTL, )
@@ -0,0 +1,28 @@
+package e2e_test + +import ( + "context" + + "github.com/olexsmir/onasty/internal/oauth" +) + +var _ oauth.Provider = (*oauthProviderMock)(nil) + +type oauthProviderMock struct{} + +func newOauthProviderMock() *oauthProviderMock { + return &oauthProviderMock{} +} + +func (o *oauthProviderMock) GetAuthURL(_ string) string { + return "https://example.com/oauth/authorize" +} + +func (o *oauthProviderMock) ExchangeCode(_ context.Context, _ string) (oauth.UserInfo, error) { + return oauth.UserInfo{ + Provider: "google", + ProviderID: "1234567890", + Email: "testing@mail.org", + EmailVerified: false, + }, nil +}
@@ -17,10 +17,12 @@ github.com/stretchr/testify v1.10.0
github.com/testcontainers/testcontainers-go v0.36.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.36.0 github.com/testcontainers/testcontainers-go/modules/redis v0.36.0 + golang.org/x/oauth2 v0.29.0 golang.org/x/time v0.11.0 ) require ( + cloud.google.com/go/compute/metadata v0.3.0 // indirect dario.cat/mergo v1.0.1 // indirect github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect
@@ -1,3 +1,5 @@
+cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= +cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU=@@ -378,6 +380,8 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/oauth2 v0.29.0 h1:WdYw2tdTK1S8olAzWHdgeqfy+Mtm9XNhv/xJsY65d98= +golang.org/x/oauth2 v0.29.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -12,7 +12,7 @@ AppEnv string
AppURL string NatsURL string - HTTPPort string + HTTPPort int HTTPWriteTimeout time.Duration HTTPReadTimeout time.Duration HTTPHeaderMaxSizeMb int@@ -32,10 +32,18 @@ JwtSigningKey string
JwtAccessTokenTTL time.Duration JwtRefreshTokenTTL time.Duration + GoogleClientID string + GoogleSecret string + GoogleRedirectURL string + + GitHubClientID string + GitHubSecret string + GitHubRedirectURL string + VerificationTokenTTL time.Duration MetricsEnabled bool - MetricsPort string + MetricsPort int LogLevel string LogFormat string@@ -52,7 +60,7 @@ AppEnv: getenvOrDefault("APP_ENV", "debug"),
AppURL: getenvOrDefault("APP_URL", ""), NatsURL: getenvOrDefault("NATS_URL", ""), - HTTPPort: getenvOrDefault("HTTP_PORT", "3000"), + HTTPPort: mustGetenvOrDefaultInt("HTTP_PORT", 3000), HTTPWriteTimeout: mustParseDuration(getenvOrDefault("HTTP_WRITE_TIMEOUT", "10s")), HTTPReadTimeout: mustParseDuration(getenvOrDefault("HTTP_READ_TIMEOUT", "10s")), HTTPHeaderMaxSizeMb: mustGetenvOrDefaultInt("HTTP_HEADER_MAX_SIZE_MB", 1),@@ -76,11 +84,19 @@ JwtRefreshTokenTTL: mustParseDuration(
getenvOrDefault("JWT_REFRESH_TOKEN_TTL", "24h"), ), + GoogleClientID: getenvOrDefault("GOOGLE_CLIENTID", ""), + GoogleSecret: getenvOrDefault("GOOGLE_SECRET", ""), + GoogleRedirectURL: getenvOrDefault("GOOGLE_REDIRECTURL", ""), + + GitHubClientID: getenvOrDefault("GITHUB_CLIENTID", ""), + GitHubSecret: getenvOrDefault("GITHUB_SECRET", ""), + GitHubRedirectURL: getenvOrDefault("GITHUB_REDIRECTURL", ""), + VerificationTokenTTL: mustParseDuration( getenvOrDefault("VERIFICATION_TOKEN_TTL", "24h"), ), - MetricsPort: getenvOrDefault("METRICS_PORT", "3001"), + MetricsPort: mustGetenvOrDefaultInt("METRICS_PORT", 3001), MetricsEnabled: getenvOrDefault("METRICS_ENABLED", "true") == "true", LogLevel: getenvOrDefault("LOG_LEVEL", "debug"),
@@ -0,0 +1,80 @@
+package oauth + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "strconv" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/github" +) + +var _ Provider = (*GitHubProvider)(nil) + +const githubUserInfoEndpoint = "https://api.github.com/user" + +type GitHubProvider struct { + config oauth2.Config +} + +func NewGithubProvider(clientID, secret, redirectURL string) GitHubProvider { + return GitHubProvider{ + config: oauth2.Config{ + ClientID: clientID, + ClientSecret: secret, + RedirectURL: redirectURL, + Endpoint: github.Endpoint, + Scopes: []string{ + "user:email", + }, + }, + } +} + +func (g GitHubProvider) GetAuthURL(state string) string { + return g.config.AuthCodeURL(state) +} + +func (g GitHubProvider) ExchangeCode(ctx context.Context, code string) (UserInfo, error) { + tok, err := g.config.Exchange(ctx, code) + if err != nil { + return UserInfo{}, err + } + + client := g.config.Client(ctx, tok) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, githubUserInfoEndpoint, nil) + if err != nil { + return UserInfo{}, err + } + + resp, err := client.Do(req) + if err != nil { + return UserInfo{}, err + } + + defer resp.Body.Close() + + b, err := io.ReadAll(resp.Body) + if err != nil { + return UserInfo{}, err + } + + var data struct { + ID int `json:"id"` + Email string `json:"email"` + } + + if err := json.NewDecoder(bytes.NewReader(b)).Decode(&data); err != nil { + return UserInfo{}, err + } + + return UserInfo{ + Provider: "github", + ProviderID: strconv.Itoa(data.ID), + Email: data.Email, + EmailVerified: true, + }, nil +}
@@ -0,0 +1,87 @@
+package oauth + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +func TestGitHubProvider_GetAuthURL(t *testing.T) { + provider := NewGithubProvider("client.id", "secret", "http://localhost/callback") + url := provider.GetAuthURL("test") + + assert.Contains(t, url, "client_id=client.id") + assert.Contains(t, url, "state=test") + assert.Contains(t, url, "scope=user%3Aemail") +} + +type mockClient func(*http.Request) (*http.Response, error) + +func (m mockClient) RoundTrip(req *http.Request) (*http.Response, error) { + return m(req) +} + +func TestGitHubProvider_ExchangeCode(t *testing.T) { + userID := "123123" + userEmail := "test@testing.org" + userLogin := "testing" + + resp := fmt.Sprintf(`{"id":%s, "email":"%s", "login":"%s"}`, userID, userEmail, userLogin) + client := &http.Client{ //nolint:exhaustruct + Transport: mockClient(func(req *http.Request) (*http.Response, error) { + if req.Method == http.MethodPost { + return &http.Response{ //nolint:exhaustruct + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser( + strings.NewReader(`{"access_token":"fake", + "token_type":"bearer", + "expires_in":3600}`), + ), + }, nil + } + return &http.Response{ //nolint:exhaustruct + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(resp)), + }, nil + }), + } + + provider := NewGithubProvider("client.id", "secret", "http://localhost") + ctx := context.WithValue(context.TODO(), oauth2.HTTPClient, client) + + info, err := provider.ExchangeCode(ctx, "") + require.NoError(t, err) + assert.Equal(t, "github", info.Provider) + assert.Equal(t, userID, info.ProviderID) + assert.Equal(t, userEmail, info.Email) +} + +func TestGitHubProvider_ExchangeCode_tokenExcahnge_error(t *testing.T) { + client := &http.Client{ //nolint:exhaustruct + Transport: mockClient(func(req *http.Request) (*http.Response, error) { + if req.Method == http.MethodPost { + return &http.Response{ //nolint:exhaustruct + StatusCode: http.StatusBadRequest, + Body: io.NopCloser(strings.NewReader("")), + }, nil + } + return nil, errors.New("unexpected request") + }), + } + + provider := NewGithubProvider("client.id", "secret", "http://localhost") + ctx := context.WithValue(context.TODO(), oauth2.HTTPClient, client) + + _, err := provider.ExchangeCode(ctx, "") + require.Error(t, err) +}
@@ -0,0 +1,80 @@
+package oauth + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" +) + +var _ Provider = (*GoogleProvider)(nil) + +const googleUserInfoEndpoint = "https://www.googleapis.com/oauth2/v3/userinfo" + +type GoogleProvider struct { + config oauth2.Config +} + +func NewGoogleProvider(clientID, secret, redirectURL string) GoogleProvider { + return GoogleProvider{ + config: oauth2.Config{ + ClientID: clientID, + ClientSecret: secret, + RedirectURL: redirectURL, + Endpoint: google.Endpoint, + Scopes: []string{ + "https://www.googleapis.com/auth/userinfo.email", + }, + }, + } +} + +func (g GoogleProvider) GetAuthURL(state string) string { + return g.config.AuthCodeURL(state) +} + +func (g GoogleProvider) ExchangeCode(ctx context.Context, code string) (UserInfo, error) { + tok, err := g.config.Exchange(ctx, code) + if err != nil { + return UserInfo{}, err + } + + client := g.config.Client(ctx, tok) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, googleUserInfoEndpoint, nil) + if err != nil { + return UserInfo{}, err + } + + resp, err := client.Do(req) + if err != nil { + return UserInfo{}, err + } + + defer resp.Body.Close() + + b, err := io.ReadAll(resp.Body) + if err != nil { + return UserInfo{}, err + } + + var data struct { + Sub string `json:"sub"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + } + + if err := json.NewDecoder(bytes.NewReader(b)).Decode(&data); err != nil { + return UserInfo{}, err + } + + return UserInfo{ + Provider: "google", + ProviderID: data.Sub, + Email: data.Email, + EmailVerified: data.EmailVerified, + }, nil +}
@@ -0,0 +1,61 @@
+package oauth + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +func TestGoogleProvider_GetAuthURL(t *testing.T) { + provider := NewGoogleProvider("client.id", "secret", "http://localhost/callback") + authURL := provider.GetAuthURL("test") + + assert.Contains(t, authURL, "client_id=client.id") + assert.Contains(t, authURL, "state=test") + assert.Contains(t, authURL, "scope="+ + url.QueryEscape("https://www.googleapis.com/auth/userinfo.email")) +} + +func TestGoogleProvider_ExchangeCode(t *testing.T) { + sub := "1234567890" + email := "testemail@mail.com" + resp := fmt.Sprintf(`{"sub":"%s", "email":"%s","email_verified":true}`, sub, email) + client := &http.Client{ //nolint:exhaustruct + Transport: mockClient(func(req *http.Request) (*http.Response, error) { + if req.Method == http.MethodPost { + return &http.Response{ //nolint:exhaustruct + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser( + strings.NewReader(`{"access_token":"fake", + "token_type":"bearer", + "expires_in":3600}`), + ), + }, nil + } + return &http.Response{ //nolint:exhaustruct + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(resp)), + }, nil + }), + } + + provider := NewGoogleProvider("client.id", "secret", "http://localhost") + ctx := context.WithValue(context.TODO(), oauth2.HTTPClient, client) + + info, err := provider.ExchangeCode(ctx, "") + require.NoError(t, err) + assert.Equal(t, "google", info.Provider) + assert.Equal(t, sub, info.ProviderID) + assert.Equal(t, email, info.Email) + assert.True(t, info.EmailVerified) +}
@@ -0,0 +1,24 @@
+package oauth + +import "context" + +// Provider is an OAuth interface. +type Provider interface { + // GetAuthURL return the provider's authorization page URL. + GetAuthURL(state string) string + + // ExchangeCode exchanges the provided authorization code for user information. + ExchangeCode(ctx context.Context, code string) (UserInfo, error) +} + +// UserInfo represents the user information returned by the OAuth provider. +type UserInfo struct { + // Provider is the name of the OAuth provider + Provider string + // ProviderID is user ID assigned by the provider + ProviderID string + // Email is user's email address returned by the provider + Email string + // EmailVerified indicates whether the email was verified by the provider + EmailVerified bool +}
@@ -0,0 +1,104 @@
+package usersrv + +import ( + "context" + "errors" + "log/slog" + "strings" + "time" + + "github.com/gofrs/uuid/v5" + "github.com/olexsmir/onasty/internal/dtos" + "github.com/olexsmir/onasty/internal/models" + "github.com/olexsmir/onasty/internal/oauth" +) + +var ErrProviderNotSupported = errors.New("oauth2 provider not supported") + +const ( + googleProvider = "google" + githubProvider = "github" +) + +func (u *UserSrv) GetOAuthURL(providerName string) (string, error) { + switch providerName { + case googleProvider: + return u.googleOauth.GetAuthURL(""), nil + case githubProvider: + return u.githubOauth.GetAuthURL(""), nil + default: + return "", ErrProviderNotSupported + } +} + +func (u *UserSrv) HandleOAuthLogin( + ctx context.Context, + providerName, code string, +) (dtos.Tokens, error) { + userInfo, err := u.getUserInfoBasedOnProvider(ctx, providerName, code) + if err != nil { + return dtos.Tokens{}, err + } + + userID, err := u.getUserByOAuthIDOrCreateOne(ctx, userInfo) + if err != nil { + return dtos.Tokens{}, err + } + + if err = u.userstore.LinkOAuthIdentity(ctx, userID, userInfo.Provider, userInfo.ProviderID); err != nil { + slog.ErrorContext(ctx, "failed to link user identity", "user_id", userID, "err", err) + return dtos.Tokens{}, err + } + + tokens, err := u.issueTokens(ctx, userID) + + return tokens, err +} + +func (u *UserSrv) getUserInfoBasedOnProvider( + ctx context.Context, + providerName, code string, +) (oauth.UserInfo, error) { + var userInfo oauth.UserInfo + var err error + + switch providerName { + case googleProvider: + userInfo, err = u.googleOauth.ExchangeCode(ctx, code) + case githubProvider: + userInfo, err = u.githubOauth.ExchangeCode(ctx, code) + default: + return oauth.UserInfo{}, ErrProviderNotSupported + } + + return userInfo, err +} + +func getUsernameFromEmail(email string) string { + p := strings.Split(email, "@") + return p[0] +} + +func (u *UserSrv) getUserByOAuthIDOrCreateOne( + ctx context.Context, + info oauth.UserInfo, +) (uuid.UUID, error) { + user, err := u.userstore.GetByOAuthID(ctx, info.Provider, info.ProviderID) + if err != nil { + if errors.Is(err, models.ErrUserNotFound) { + uid, cerr := u.userstore.Create(ctx, models.User{ + ID: uuid.Nil, + Username: getUsernameFromEmail(info.Email), + Email: info.Email, + Activated: true, + Password: "", + CreatedAt: time.Now(), + LastLoginAt: time.Now(), + }) + return uid, cerr + } + return uuid.Nil, err + } + + return user.ID, nil +}
@@ -12,6 +12,7 @@ "github.com/olexsmir/onasty/internal/events/mailermq"
"github.com/olexsmir/onasty/internal/hasher" "github.com/olexsmir/onasty/internal/jwtutil" "github.com/olexsmir/onasty/internal/models" + "github.com/olexsmir/onasty/internal/oauth" "github.com/olexsmir/onasty/internal/store/psql/sessionrepo" "github.com/olexsmir/onasty/internal/store/psql/userepo" "github.com/olexsmir/onasty/internal/store/psql/vertokrepo"@@ -25,6 +26,9 @@ RefreshTokens(ctx context.Context, refreshToken string) (dtos.Tokens, error)
Logout(ctx context.Context, userID uuid.UUID) error ChangePassword(ctx context.Context, userID uuid.UUID, inp dtos.ChangeUserPassword) error + + GetOAuthURL(providerName string) (string, error) + HandleOAuthLogin(ctx context.Context, providerName, code string) (dtos.Tokens, error) Verify(ctx context.Context, verificationKey string) error ResendVerificationEmail(ctx context.Context, credentials dtos.SignIn) error@@ -45,6 +49,8 @@ hasher hasher.Hasher
jwtTokenizer jwtutil.JWTTokenizer mailermq mailermq.Mailer cache usercache.UserCacheer + googleOauth oauth.Provider + githubOauth oauth.Provider refreshTokenTTL time.Duration verificationTokenTTL time.Duration@@ -58,6 +64,7 @@ hasher hasher.Hasher,
jwtTokenizer jwtutil.JWTTokenizer, mailermq mailermq.Mailer, cache usercache.UserCacheer, + googleOauth, githubOauth oauth.Provider, refreshTokenTTL, verificationTokenTTL time.Duration, ) *UserSrv { return &UserSrv{@@ -68,6 +75,8 @@ hasher: hasher,
jwtTokenizer: jwtTokenizer, mailermq: mailermq, cache: cache, + googleOauth: googleOauth, + githubOauth: githubOauth, refreshTokenTTL: refreshTokenTTL, verificationTokenTTL: verificationTokenTTL, }@@ -135,19 +144,8 @@ if !user.IsActivated() {
return dtos.Tokens{}, models.ErrUserIsNotActivated } - tokens, err := u.createTokens(user.ID) - if err != nil { - return dtos.Tokens{}, err - } - - if err := u.sessionstore.Set(ctx, user.ID, tokens.Refresh, time.Now().Add(u.refreshTokenTTL)); err != nil { - return dtos.Tokens{}, err - } - - return dtos.Tokens{ - Access: tokens.Access, - Refresh: tokens.Refresh, - }, nil + tokens, err := u.issueTokens(ctx, user.ID) + return tokens, err } func (u *UserSrv) Logout(ctx context.Context, userID uuid.UUID) error {@@ -301,3 +299,16 @@ Access: accessToken,
Refresh: refreshToken, }, err } + +func (u UserSrv) issueTokens(ctx context.Context, userID uuid.UUID) (dtos.Tokens, error) { + toks, err := u.createTokens(userID) + if err != nil { + return dtos.Tokens{}, err + } + + if err := u.sessionstore.Set(ctx, userID, toks.Refresh, time.Now().Add(u.refreshTokenTTL)); err != nil { + return dtos.Tokens{}, err + } + + return toks, nil +}
@@ -29,6 +29,9 @@ // SetPassword sets new password for user by their id
// password should be hashed SetPassword(ctx context.Context, userID uuid.UUID, newPassword string) error + GetByOAuthID(ctx context.Context, provider, providerID string) (models.User, error) + LinkOAuthIdentity(ctx context.Context, userID uuid.UUID, provider, providerID string) error + CheckIfUserExists(ctx context.Context, userID uuid.UUID) (bool, error) CheckIfUserIsActivated(ctx context.Context, userID uuid.UUID) (bool, error) }@@ -111,6 +114,41 @@ return uuid.Nil, models.ErrUserNotFound
} return id, err +} + +func (r *UserRepo) GetByOAuthID( + ctx context.Context, + provider, providerID string, +) (models.User, error) { + query := `--sql + select u.id, u.username, u.email, u.password, u.activated, u.created_at, u.last_login_at + from users u + join oauth_identities oi on u.id = oi.user_id + where oi.provider = $1 + and oi.provider_id = $2 + limit 1` + + var user models.User + err := r.db.QueryRow(ctx, query, provider, providerID). + Scan(&user.ID, &user.Username, &user.Email, &user.Password, &user.Activated, &user.CreatedAt, &user.LastLoginAt) + if errors.Is(err, pgx.ErrNoRows) { + return models.User{}, models.ErrUserNotFound + } + + return user, err +} + +func (r *UserRepo) LinkOAuthIdentity( + ctx context.Context, + userID uuid.UUID, + provider, providerID string, +) error { + query := `--sql + insert into oauth_identities (user_id, provider, provider_id) + values ($1, $2, $3)` + + _, err := r.db.Exec(ctx, query, userID, provider, providerID) + return err } func (r *UserRepo) MarkUserAsActivated(ctx context.Context, id uuid.UUID) error {
@@ -36,6 +36,12 @@ {
authorized.POST("/logout", a.logOutHandler) authorized.POST("/change-password", a.changePasswordHandler) } + + oauth := r.Group("/oauth") + { + oauth.GET("/:provider", a.oauthLoginHandler) + oauth.GET("/:provider/callback", a.oauthCallbackHandler) + } } note := r.Group("/note", a.couldBeAuthorizedMiddleware)
@@ -153,3 +153,30 @@ }
c.Status(http.StatusOK) } + +func (a *APIV1) oauthLoginHandler(c *gin.Context) { + url, err := a.usersrv.GetOAuthURL(c.Param("provider")) + if err != nil { + errorResponse(c, err) + return + } + + c.Redirect(http.StatusSeeOther, url) +} + +func (a *APIV1) oauthCallbackHandler(c *gin.Context) { + tokens, err := a.usersrv.HandleOAuthLogin( + c.Request.Context(), + c.Param("provider"), + c.Query("code"), + ) + if err != nil { + errorResponse(c, err) + return + } + + c.JSON(http.StatusOK, signInResponse{ + AccessToken: tokens.Access, + RefreshToken: tokens.Refresh, + }) +}
@@ -7,6 +7,7 @@ "net/http"
"github.com/gin-gonic/gin" "github.com/olexsmir/onasty/internal/models" + "github.com/olexsmir/onasty/internal/service/usersrv" ) var ErrUnauthorized = errors.New("unauthorized")@@ -16,7 +17,8 @@ Message string `json:"message"`
} func errorResponse(c *gin.Context, err error) { - if errors.Is(err, models.ErrUserEmailIsAlreadyInUse) || + if errors.Is(err, usersrv.ErrProviderNotSupported) || + errors.Is(err, models.ErrUserEmailIsAlreadyInUse) || errors.Is(err, models.ErrUsernameIsAlreadyInUse) || errors.Is(err, models.ErrUserIsAlreadyVerified) || errors.Is(err, models.ErrUserIsNotActivated) ||
@@ -3,6 +3,7 @@
import ( "context" "net/http" + "strconv" "time" )@@ -12,7 +13,7 @@ }
type Config struct { // Port http server port - Port string + Port int // ReadTimeout read timeout ReadTimeout time.Duration@@ -25,13 +26,28 @@ MaxHeaderSizeMb int
} func NewServer(handler http.Handler, cfg Config) *Server { + p := strconv.Itoa(cfg.Port) return &Server{ http: &http.Server{ - Addr: ":" + cfg.Port, + Addr: ":" + p, Handler: handler, ReadTimeout: cfg.ReadTimeout, WriteTimeout: cfg.WriteTimeout, MaxHeaderBytes: cfg.MaxHeaderSizeMb << 20, + }, + } +} + +// NewDefaultServer returns http server with default config +func NewDefaultServer(handler http.Handler, port int) *Server { + p := strconv.Itoa(port) + return &Server{ + http: &http.Server{ + Addr: ":" + p, + Handler: handler, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + MaxHeaderBytes: 1 << 20, }, } }
@@ -1,6 +1,9 @@
package main -import "os" +import ( + "os" + "strconv" +) type Config struct { AppURL string@@ -14,7 +17,7 @@ LogFormat string
LogShowLine bool MetricsEnabled bool - MetricsPort string + MetricsPort int } func NewConfig() *Config {@@ -27,7 +30,7 @@ MailgunAPIKey: getenvOrDefault("MAILGUN_API_KEY", ""),
LogLevel: getenvOrDefault("LOG_LEVEL", "debug"), LogFormat: getenvOrDefault("LOG_FORMAT", "json"), LogShowLine: getenvOrDefault("LOG_SHOW_LINE", "true") == "true", - MetricsPort: getenvOrDefault("METRICS_PORT", ""), + MetricsPort: mustGetenvOrDefaultInt("METRICS_PORT", 8001), MetricsEnabled: getenvOrDefault("METRICS_ENABLED", "true") == "true", } }@@ -38,3 +41,14 @@ return v
} return def } + +func mustGetenvOrDefaultInt(key string, def int) int { + if v, ok := os.LookupEnv(key); ok { + r, err := strconv.Atoi(v) + if err != nil { + panic(err) + } + return r + } + return def +}
@@ -0,0 +1,2 @@
+ALTER TABLE users + ALTER COLUMN PASSWORD SET NOT NULL;
@@ -0,0 +1,2 @@
+ALTER TABLE users + ALTER COLUMN PASSWORD DROP NOT NULL;
@@ -0,0 +1,3 @@
+DROP TABLE oauth_identities; + +DROP TYPE provider_enum;
@@ -0,0 +1,13 @@
+CREATE TYPE provider_enum AS ENUM ( + 'google', + 'github' +); + +CREATE TABLE oauth_identities ( + id uuid PRIMARY KEY DEFAULT uuid_generate_v4 (), + user_id uuid REFERENCES users (id) ON DELETE CASCADE, + provider provider_enum NOT NULL, + provider_id varchar(50), + created_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (provider, provider_id) +);