all repos

onasty @ e2c4fd283ecd1a92895e34c43cf02a3a94d2b194

a one-time notes service
7 files changed, 178 insertions(+), 8 deletions(-)
feat: rate limiting (#36)

* feat(ratelimit): add rate limiting module

* feat(ratelimit): setup rate limiter

* fixup! feat(ratelimit): setup rate limiter

* fixup! fixup! feat(ratelimit): setup rate limiter

* fixup! feat(ratelimit): setup rate limiter

* fix(e2e): set ratelimiter config in tests so they would pass

* refactor(ratelimit): remove unused code

* fix(ratelimit): now the middleware shouldn't panic

* fix(ratelimit): actually handle if user isnt found

* refactor(ratelimit): use rw mutex

* chore: update .env.example
Author: Smirnov Oleksandr ss2316544@gmail.com
Committed by: GitHub noreply@github.com
Committed at: 2024-10-18 23:58:44 +0300
Parent: 47d33af
M .env.example

@@ -24,3 +24,7 @@ MAILGUN_FROM=onasty@mail.com

MAILGUN_DOMAI='<domain>' MAILGUN_API_KEY='<token>' VERIFICATION_TOKEN_TTL=48h + +RATELIMITER_RPS=100 +RATELIMITER_BURST=10 +RATELIMITER_TTL=3m
M cmd/server/main.go

@@ -25,6 +25,7 @@ "github.com/olexsmir/onasty/internal/store/psql/vertokrepo"

"github.com/olexsmir/onasty/internal/store/psqlutil" httptransport "github.com/olexsmir/onasty/internal/transport/http" "github.com/olexsmir/onasty/internal/transport/http/httpserver" + "github.com/olexsmir/onasty/internal/transport/http/ratelimit" ) func main() {

@@ -83,7 +84,17 @@

noterepo := noterepo.New(psqlDB) notesrv := notesrv.New(noterepo) - handler := httptransport.NewTransport(usersrv, notesrv) + rateLimiterConfig := ratelimit.Config{ + RPS: cfg.RateLimiterRPS, + TTL: cfg.RateLimiterTTL, + Burst: cfg.RateLimiterBurst, + } + + handler := httptransport.NewTransport( + usersrv, + notesrv, + rateLimiterConfig, + ) // http server srv := httpserver.NewServer(cfg.ServerPort, handler.Handler())
M e2e/e2e_test.go

@@ -26,6 +26,7 @@ "github.com/olexsmir/onasty/internal/store/psql/userepo"

"github.com/olexsmir/onasty/internal/store/psql/vertokrepo" "github.com/olexsmir/onasty/internal/store/psqlutil" httptransport "github.com/olexsmir/onasty/internal/transport/http" + "github.com/olexsmir/onasty/internal/transport/http/ratelimit" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/testcontainers/testcontainers-go"

@@ -117,7 +118,14 @@

noterepo := noterepo.New(e.postgresDB) notesrv := notesrv.New(noterepo) - handler := httptransport.NewTransport(usersrv, notesrv) + // for testing purposes, it's ok to have high values ig + ratelimitCfg := ratelimit.Config{ + RPS: 1000, + TTL: time.Millisecond, + Burst: 1000, + } + + handler := httptransport.NewTransport(usersrv, notesrv, ratelimitCfg) e.router = handler.Handler() }
M go.mod

@@ -15,6 +15,7 @@ github.com/prometheus/client_golang v1.20.5

github.com/stretchr/testify v1.9.0 github.com/testcontainers/testcontainers-go v0.33.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.33.0 + golang.org/x/time v0.5.0 ) require (
M internal/config/config.go

@@ -3,6 +3,7 @@

import ( "errors" "os" + "strconv" "time" )

@@ -28,6 +29,10 @@

LogLevel string LogFormat string LogShowLine bool + + RateLimiterRPS int + RateLimiterBurst int + RateLimiterTTL time.Duration } func NewConfig() *Config {

@@ -39,17 +44,17 @@ PostgresDSN: getenvOrDefault("POSTGRESQL_DSN", ""),

PasswordSalt: getenvOrDefault("PASSWORD_SALT", ""), JwtSigningKey: getenvOrDefault("JWT_SIGNING_KEY", ""), - JwtAccessTokenTTL: mustParseDurationOrPanic( + JwtAccessTokenTTL: mustParseDuration( getenvOrDefault("JWT_ACCESS_TOKEN_TTL", "15m"), ), - JwtRefreshTokenTTL: mustParseDurationOrPanic( + JwtRefreshTokenTTL: mustParseDuration( getenvOrDefault("JWT_REFRESH_TOKEN_TTL", "24h"), ), MailgunFrom: getenvOrDefault("MAILGUN_FROM", ""), MailgunDomain: getenvOrDefault("MAILGUN_DOMAIN", ""), MailgunAPIKey: getenvOrDefault("MAILGUN_API_KEY", ""), - VerificationTokenTTL: mustParseDurationOrPanic( + VerificationTokenTTL: mustParseDuration( getenvOrDefault("VERIFICATION_TOKEN_TTL", "24h"), ),

@@ -59,6 +64,10 @@

LogLevel: getenvOrDefault("LOG_LEVEL", "debug"), LogFormat: getenvOrDefault("LOG_FORMAT", "json"), LogShowLine: getenvOrDefault("LOG_SHOW_LINE", "true") == "true", + + RateLimiterRPS: mustGetenvOrDefaultInt("RATELIMITER_RPS", 100), + RateLimiterBurst: mustGetenvOrDefaultInt("RATELIMITER_BURST", 10), + RateLimiterTTL: mustParseDuration(getenvOrDefault("RATELIMITER_TTL", "1m")), } }

@@ -73,7 +82,18 @@ }

return def } -func mustParseDurationOrPanic(dur string) time.Duration { +func mustGetenvOrDefaultInt(key string, def int) int { + if v, ok := os.LookupEnv(key); ok { + r, err := strconv.Atoi(v) + if err != nil { + panic(err) + } + return r + } + return def +} + +func mustParseDuration(dur string) time.Duration { d, err := time.ParseDuration(dur) if err != nil { panic(errors.Join(errors.New("cannot time.ParseDuration"), err)) //nolint:err113
M internal/transport/http/http.go

@@ -7,21 +7,26 @@ "github.com/gin-gonic/gin"

"github.com/olexsmir/onasty/internal/service/notesrv" "github.com/olexsmir/onasty/internal/service/usersrv" "github.com/olexsmir/onasty/internal/transport/http/apiv1" + "github.com/olexsmir/onasty/internal/transport/http/ratelimit" "github.com/olexsmir/onasty/internal/transport/http/reqid" ) type Transport struct { usersrv usersrv.UserServicer notesrv notesrv.NoteServicer + + ratelimitCfg ratelimit.Config } func NewTransport( us usersrv.UserServicer, ns notesrv.NoteServicer, + ratelimitCfg ratelimit.Config, ) *Transport { return &Transport{ - usersrv: us, - notesrv: ns, + usersrv: us, + notesrv: ns, + ratelimitCfg: ratelimitCfg, } }

@@ -31,6 +36,7 @@ r.Use(

gin.Recovery(), reqid.Middleware(), t.logger(), + ratelimit.MiddlewareWithConfig(t.ratelimitCfg), ) api := r.Group("/api")
A internal/transport/http/ratelimit/ratelimit.go

@@ -0,0 +1,120 @@

+// thanks to https://www.alexedwards.net/blog/how-to-rate-limit-http-requests + +package ratelimit + +import ( + "net/http" + "sync" + "time" + + "github.com/gin-gonic/gin" + "golang.org/x/time/rate" +) + +type ( + rateLimiter struct { + mu sync.RWMutex + + visitors map[visitorIP]*visitor + + // limit is the maximum number of requests per second + limit rate.Limit + + // ttl is the time after which a visitor is forgotten + ttl time.Duration + + // burst is the maximum number of requests that can be made in a short amount of time + burst int + } + + visitorIP string + visitor struct { + limiter *rate.Limiter + lastSeen time.Time + } +) + +func newLimiter(rps, burst int, ttl time.Duration) *rateLimiter { + return &rateLimiter{ //nolint:exhaustruct + visitors: make(map[visitorIP]*visitor), + limit: rate.Limit(rps), + burst: burst, + ttl: ttl, + } +} + +// Retrieve and return the rate limiter for the current visitor if it +// already exists. Otherwise create a new rate limiter and add it to +// the visitors map, using the IP address as the key. +func (r *rateLimiter) getVisitor(ip visitorIP) *rate.Limiter { + r.mu.RLock() + v, exists := r.visitors[ip] + r.mu.RUnlock() + + if !exists { + limit := rate.NewLimiter(r.limit, r.burst) + + r.mu.Lock() + r.visitors[ip] = &visitor{ + limiter: limit, + lastSeen: time.Now(), + } + r.mu.Unlock() + + return limit + } + + r.mu.Lock() + v.lastSeen = time.Now() + r.mu.Unlock() + + return v.limiter +} + +// Every minute check the map for visitors that haven't been seen for +// more than 3 minutes and delete the entries. +func (r *rateLimiter) cleanupVisitors() { + for { + time.Sleep(time.Minute) + + r.mu.Lock() + for ip, v := range r.visitors { + if time.Since(v.lastSeen) > r.ttl { + delete(r.visitors, ip) + } + } + r.mu.Unlock() + } +} + +type Config struct { + // RPS is the maximum number of requests per second + RPS int + + // TTL is the time after which a visitor is forgotten + TTL time.Duration + + // Burst is the maximum number of requests that can be made in a short amount of time + Burst int +} + +// MiddlewareWithConfig returns a new rate limiting middleware with the given config +func MiddlewareWithConfig(c Config) gin.HandlerFunc { + lmt := newLimiter(c.RPS, c.Burst, c.TTL) + go lmt.cleanupVisitors() + + return func(c *gin.Context) { + visitor := lmt.getVisitor(visitorIP(c.ClientIP())) + if visitor == nil { + c.AbortWithStatus(http.StatusInternalServerError) + return + } + + if !visitor.Allow() { + c.AbortWithStatus(http.StatusTooManyRequests) + return + } + + c.Next() + } +}