7 files changed,
178 insertions(+),
8 deletions(-)
Author:
Smirnov Oleksandr
ss2316544@gmail.com
Committed by:
GitHub
noreply@github.com
Committed at:
2024-10-18 23:58:44 +0300
Parent:
47d33af
M
cmd/server/main.go
··· 25 25 "github.com/olexsmir/onasty/internal/store/psqlutil" 26 26 httptransport "github.com/olexsmir/onasty/internal/transport/http" 27 27 "github.com/olexsmir/onasty/internal/transport/http/httpserver" 28 + "github.com/olexsmir/onasty/internal/transport/http/ratelimit" 28 29 ) 29 30 30 31 func main() { ··· 83 84 noterepo := noterepo.New(psqlDB) 84 85 notesrv := notesrv.New(noterepo) 85 86 86 - handler := httptransport.NewTransport(usersrv, notesrv) 87 + rateLimiterConfig := ratelimit.Config{ 88 + RPS: cfg.RateLimiterRPS, 89 + TTL: cfg.RateLimiterTTL, 90 + Burst: cfg.RateLimiterBurst, 91 + } 92 + 93 + handler := httptransport.NewTransport( 94 + usersrv, 95 + notesrv, 96 + rateLimiterConfig, 97 + ) 87 98 88 99 // http server 89 100 srv := httpserver.NewServer(cfg.ServerPort, handler.Handler())
M
e2e/e2e_test.go
··· 26 26 "github.com/olexsmir/onasty/internal/store/psql/vertokrepo" 27 27 "github.com/olexsmir/onasty/internal/store/psqlutil" 28 28 httptransport "github.com/olexsmir/onasty/internal/transport/http" 29 + "github.com/olexsmir/onasty/internal/transport/http/ratelimit" 29 30 "github.com/stretchr/testify/require" 30 31 "github.com/stretchr/testify/suite" 31 32 "github.com/testcontainers/testcontainers-go" ··· 117 118 noterepo := noterepo.New(e.postgresDB) 118 119 notesrv := notesrv.New(noterepo) 119 120 120 - handler := httptransport.NewTransport(usersrv, notesrv) 121 + // for testing purposes, it's ok to have high values ig 122 + ratelimitCfg := ratelimit.Config{ 123 + RPS: 1000, 124 + TTL: time.Millisecond, 125 + Burst: 1000, 126 + } 127 + 128 + handler := httptransport.NewTransport(usersrv, notesrv, ratelimitCfg) 121 129 e.router = handler.Handler() 122 130 } 123 131
M
internal/config/config.go
··· 3 3 import ( 4 4 "errors" 5 5 "os" 6 + "strconv" 6 7 "time" 7 8 ) 8 9 ··· 28 29 LogLevel string 29 30 LogFormat string 30 31 LogShowLine bool 32 + 33 + RateLimiterRPS int 34 + RateLimiterBurst int 35 + RateLimiterTTL time.Duration 31 36 } 32 37 33 38 func NewConfig() *Config { ··· 39 44 PasswordSalt: getenvOrDefault("PASSWORD_SALT", ""), 40 45 41 46 JwtSigningKey: getenvOrDefault("JWT_SIGNING_KEY", ""), 42 - JwtAccessTokenTTL: mustParseDurationOrPanic( 47 + JwtAccessTokenTTL: mustParseDuration( 43 48 getenvOrDefault("JWT_ACCESS_TOKEN_TTL", "15m"), 44 49 ), 45 - JwtRefreshTokenTTL: mustParseDurationOrPanic( 50 + JwtRefreshTokenTTL: mustParseDuration( 46 51 getenvOrDefault("JWT_REFRESH_TOKEN_TTL", "24h"), 47 52 ), 48 53 49 54 MailgunFrom: getenvOrDefault("MAILGUN_FROM", ""), 50 55 MailgunDomain: getenvOrDefault("MAILGUN_DOMAIN", ""), 51 56 MailgunAPIKey: getenvOrDefault("MAILGUN_API_KEY", ""), 52 - VerificationTokenTTL: mustParseDurationOrPanic( 57 + VerificationTokenTTL: mustParseDuration( 53 58 getenvOrDefault("VERIFICATION_TOKEN_TTL", "24h"), 54 59 ), 55 60 ··· 59 64 LogLevel: getenvOrDefault("LOG_LEVEL", "debug"), 60 65 LogFormat: getenvOrDefault("LOG_FORMAT", "json"), 61 66 LogShowLine: getenvOrDefault("LOG_SHOW_LINE", "true") == "true", 67 + 68 + RateLimiterRPS: mustGetenvOrDefaultInt("RATELIMITER_RPS", 100), 69 + RateLimiterBurst: mustGetenvOrDefaultInt("RATELIMITER_BURST", 10), 70 + RateLimiterTTL: mustParseDuration(getenvOrDefault("RATELIMITER_TTL", "1m")), 62 71 } 63 72 } 64 73 ··· 73 82 return def 74 83 } 75 84 76 -func mustParseDurationOrPanic(dur string) time.Duration { 85 +func mustGetenvOrDefaultInt(key string, def int) int { 86 + if v, ok := os.LookupEnv(key); ok { 87 + r, err := strconv.Atoi(v) 88 + if err != nil { 89 + panic(err) 90 + } 91 + return r 92 + } 93 + return def 94 +} 95 + 96 +func mustParseDuration(dur string) time.Duration { 77 97 d, err := time.ParseDuration(dur) 78 98 if err != nil { 79 99 panic(errors.Join(errors.New("cannot time.ParseDuration"), err)) //nolint:err113
M
internal/transport/http/http.go
··· 7 7 "github.com/olexsmir/onasty/internal/service/notesrv" 8 8 "github.com/olexsmir/onasty/internal/service/usersrv" 9 9 "github.com/olexsmir/onasty/internal/transport/http/apiv1" 10 + "github.com/olexsmir/onasty/internal/transport/http/ratelimit" 10 11 "github.com/olexsmir/onasty/internal/transport/http/reqid" 11 12 ) 12 13 13 14 type Transport struct { 14 15 usersrv usersrv.UserServicer 15 16 notesrv notesrv.NoteServicer 17 + 18 + ratelimitCfg ratelimit.Config 16 19 } 17 20 18 21 func NewTransport( 19 22 us usersrv.UserServicer, 20 23 ns notesrv.NoteServicer, 24 + ratelimitCfg ratelimit.Config, 21 25 ) *Transport { 22 26 return &Transport{ 23 - usersrv: us, 24 - notesrv: ns, 27 + usersrv: us, 28 + notesrv: ns, 29 + ratelimitCfg: ratelimitCfg, 25 30 } 26 31 } 27 32 ··· 31 36 gin.Recovery(), 32 37 reqid.Middleware(), 33 38 t.logger(), 39 + ratelimit.MiddlewareWithConfig(t.ratelimitCfg), 34 40 ) 35 41 36 42 api := r.Group("/api")
A
internal/transport/http/ratelimit/ratelimit.go
··· 1 +// thanks to https://www.alexedwards.net/blog/how-to-rate-limit-http-requests 2 + 3 +package ratelimit 4 + 5 +import ( 6 + "net/http" 7 + "sync" 8 + "time" 9 + 10 + "github.com/gin-gonic/gin" 11 + "golang.org/x/time/rate" 12 +) 13 + 14 +type ( 15 + rateLimiter struct { 16 + mu sync.RWMutex 17 + 18 + visitors map[visitorIP]*visitor 19 + 20 + // limit is the maximum number of requests per second 21 + limit rate.Limit 22 + 23 + // ttl is the time after which a visitor is forgotten 24 + ttl time.Duration 25 + 26 + // burst is the maximum number of requests that can be made in a short amount of time 27 + burst int 28 + } 29 + 30 + visitorIP string 31 + visitor struct { 32 + limiter *rate.Limiter 33 + lastSeen time.Time 34 + } 35 +) 36 + 37 +func newLimiter(rps, burst int, ttl time.Duration) *rateLimiter { 38 + return &rateLimiter{ //nolint:exhaustruct 39 + visitors: make(map[visitorIP]*visitor), 40 + limit: rate.Limit(rps), 41 + burst: burst, 42 + ttl: ttl, 43 + } 44 +} 45 + 46 +// Retrieve and return the rate limiter for the current visitor if it 47 +// already exists. Otherwise create a new rate limiter and add it to 48 +// the visitors map, using the IP address as the key. 49 +func (r *rateLimiter) getVisitor(ip visitorIP) *rate.Limiter { 50 + r.mu.RLock() 51 + v, exists := r.visitors[ip] 52 + r.mu.RUnlock() 53 + 54 + if !exists { 55 + limit := rate.NewLimiter(r.limit, r.burst) 56 + 57 + r.mu.Lock() 58 + r.visitors[ip] = &visitor{ 59 + limiter: limit, 60 + lastSeen: time.Now(), 61 + } 62 + r.mu.Unlock() 63 + 64 + return limit 65 + } 66 + 67 + r.mu.Lock() 68 + v.lastSeen = time.Now() 69 + r.mu.Unlock() 70 + 71 + return v.limiter 72 +} 73 + 74 +// Every minute check the map for visitors that haven't been seen for 75 +// more than 3 minutes and delete the entries. 76 +func (r *rateLimiter) cleanupVisitors() { 77 + for { 78 + time.Sleep(time.Minute) 79 + 80 + r.mu.Lock() 81 + for ip, v := range r.visitors { 82 + if time.Since(v.lastSeen) > r.ttl { 83 + delete(r.visitors, ip) 84 + } 85 + } 86 + r.mu.Unlock() 87 + } 88 +} 89 + 90 +type Config struct { 91 + // RPS is the maximum number of requests per second 92 + RPS int 93 + 94 + // TTL is the time after which a visitor is forgotten 95 + TTL time.Duration 96 + 97 + // Burst is the maximum number of requests that can be made in a short amount of time 98 + Burst int 99 +} 100 + 101 +// MiddlewareWithConfig returns a new rate limiting middleware with the given config 102 +func MiddlewareWithConfig(c Config) gin.HandlerFunc { 103 + lmt := newLimiter(c.RPS, c.Burst, c.TTL) 104 + go lmt.cleanupVisitors() 105 + 106 + return func(c *gin.Context) { 107 + visitor := lmt.getVisitor(visitorIP(c.ClientIP())) 108 + if visitor == nil { 109 + c.AbortWithStatus(http.StatusInternalServerError) 110 + return 111 + } 112 + 113 + if !visitor.Allow() { 114 + c.AbortWithStatus(http.StatusTooManyRequests) 115 + return 116 + } 117 + 118 + c.Next() 119 + } 120 +}