onasty/internal/transport/http/ratelimit/ratelimit.go(view raw)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
// thanks to https://www.alexedwards.net/blog/how-to-rate-limit-http-requests
package ratelimit
import (
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
)
type (
rateLimiter struct {
mu sync.RWMutex
visitors map[visitorIP]*visitor
// limit is the maximum number of requests per second
limit rate.Limit
// ttl is the time after which a visitor is forgotten
ttl time.Duration
// burst is the maximum number of requests that can be made in a short amount of time
burst int
}
visitorIP string
visitor struct {
limiter *rate.Limiter
lastSeen time.Time
}
)
func newLimiter(rps, burst int, ttl time.Duration) *rateLimiter {
return &rateLimiter{ //nolint:exhaustruct
visitors: make(map[visitorIP]*visitor),
limit: rate.Limit(rps),
burst: burst,
ttl: ttl,
}
}
// Retrieve and return the rate limiter for the current visitor if it
// already exists. Otherwise create a new rate limiter and add it to
// the visitors map, using the IP address as the key.
func (r *rateLimiter) getVisitor(ip visitorIP) *rate.Limiter {
r.mu.RLock()
v, exists := r.visitors[ip]
r.mu.RUnlock()
if !exists {
limit := rate.NewLimiter(r.limit, r.burst)
r.mu.Lock()
r.visitors[ip] = &visitor{
limiter: limit,
lastSeen: time.Now(),
}
r.mu.Unlock()
return limit
}
r.mu.Lock()
v.lastSeen = time.Now()
r.mu.Unlock()
return v.limiter
}
// Every minute check the map for visitors that haven't been seen for
// more than 3 minutes and delete the entries.
func (r *rateLimiter) cleanupVisitors() {
for {
time.Sleep(time.Minute)
r.mu.Lock()
for ip, v := range r.visitors {
if time.Since(v.lastSeen) > r.ttl {
delete(r.visitors, ip)
}
}
r.mu.Unlock()
}
}
type Config struct {
// RPS is the maximum number of requests per second
RPS int
// TTL is the time after which a visitor is forgotten
TTL time.Duration
// Burst is the maximum number of requests that can be made in a short amount of time
Burst int
}
// MiddlewareWithConfig returns a new rate limiting middleware with the given config
func MiddlewareWithConfig(c Config) gin.HandlerFunc {
lmt := newLimiter(c.RPS, c.Burst, c.TTL)
go lmt.cleanupVisitors()
return func(c *gin.Context) {
visitor := lmt.getVisitor(visitorIP(c.ClientIP()))
if visitor == nil {
c.AbortWithStatus(http.StatusInternalServerError)
return
}
if !visitor.Allow() {
c.AbortWithStatus(http.StatusTooManyRequests)
return
}
c.Next()
}
}
|