all repos

onasty @ 3b5e67f

a one-time notes service

onasty/internal/transport/http/ratelimit/ratelimit.go (view raw)

Smirnov Oleksandr Smirnov Oleksandr
ss2316544@gmail.com
refactor: fix annoyances (#97)..., 1 year ago
1
// thanks to https://www.alexedwards.net/blog/how-to-rate-limit-http-requests
2
3
package ratelimit
4
5
import (
6
	"net/http"
7
	"sync"
8
	"time"
9
10
	"github.com/gin-gonic/gin"
11
	"golang.org/x/time/rate"
12
)
13
14
type (
15
	rateLimiter struct {
16
		mu sync.RWMutex
17
18
		visitors map[visitorIP]*visitor
19
20
		// limit is the maximum number of requests per second
21
		limit rate.Limit
22
23
		// ttl is the time after which a visitor is forgotten
24
		ttl time.Duration
25
26
		// burst is the maximum number of requests that can be made in a short amount of time
27
		burst int
28
	}
29
30
	visitorIP string
31
	visitor   struct {
32
		limiter  *rate.Limiter
33
		lastSeen time.Time
34
	}
35
)
36
37
func newLimiter(rps, burst int, ttl time.Duration) *rateLimiter {
38
	return &rateLimiter{ //nolint:exhaustruct
39
		visitors: make(map[visitorIP]*visitor),
40
		limit:    rate.Limit(rps),
41
		burst:    burst,
42
		ttl:      ttl,
43
	}
44
}
45
46
// getVisitor Retrieve and return the rate limiter for the current visitor
47
// if it already exists. Otherwise create a new rate limiter and add it to
48
// the visitors map, using the IP address as the key.
49
func (r *rateLimiter) getVisitor(ip visitorIP) *rate.Limiter {
50
	r.mu.RLock()
51
	v, exists := r.visitors[ip]
52
	r.mu.RUnlock()
53
54
	if !exists {
55
		limit := rate.NewLimiter(r.limit, r.burst)
56
57
		r.mu.Lock()
58
		r.visitors[ip] = &visitor{
59
			limiter:  limit,
60
			lastSeen: time.Now(),
61
		}
62
		r.mu.Unlock()
63
64
		return limit
65
	}
66
67
	r.mu.Lock()
68
	v.lastSeen = time.Now()
69
	r.mu.Unlock()
70
71
	return v.limiter
72
}
73
74
// cleanUpVisitors checks the map of visitors that haven't been seed
75
// for more than [Config].TTL and delete those entries
76
func (r *rateLimiter) cleanupVisitors() {
77
	r.mu.Lock()
78
	defer r.mu.Unlock()
79
80
	for ip, v := range r.visitors {
81
		if time.Since(v.lastSeen) > r.ttl {
82
			delete(r.visitors, ip)
83
		}
84
	}
85
}
86
87
// cleanupVisitorsLoop runs [rateLimiter.cleanupVisitors] every minute
88
func (r *rateLimiter) cleanupVisitorsLoop() {
89
	for {
90
		time.Sleep(time.Minute)
91
		r.cleanupVisitors()
92
	}
93
}
94
95
// Config holds the settings for MiddlewareWithConfig. Every client IP gets its
// own limiter built from these values.
type Config struct {
	// RPS is the sustained number of requests per second allowed per client IP.
	RPS int

	// TTL is how long a client IP may stay idle before its limiter is forgotten.
	TTL time.Duration

	// Burst is the maximum number of requests a client IP may make in a short burst.
	Burst int
}
105
106
// MiddlewareWithConfig returns a new rate limiting middleware with the given config
107
func MiddlewareWithConfig(c Config) gin.HandlerFunc {
108
	lmt := newLimiter(c.RPS, c.Burst, c.TTL)
109
	go lmt.cleanupVisitorsLoop()
110
111
	return func(c *gin.Context) {
112
		visitor := lmt.getVisitor(visitorIP(c.ClientIP()))
113
		if visitor == nil {
114
			c.AbortWithStatus(http.StatusInternalServerError)
115
			return
116
		}
117
118
		if !visitor.Allow() {
119
			c.AbortWithStatus(http.StatusTooManyRequests)
120
			return
121
		}
122
123
		c.Next()
124
	}
125
}