all repos

onasty @ fa8001ef7b978b20df8dba33a04af109a2b5b18e

a one-time notes service

onasty/internal/transport/http/ratelimit/ratelimit.go (view raw)

Smirnov Oleksandr Smirnov Oleksandr
ss2316544@gmail.com
feat: rate limiting (#36)..., 1 year ago
1
// Adapted from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests
2
3
package ratelimit
4
5
import (
6
	"net/http"
7
	"sync"
8
	"time"
9
10
	"github.com/gin-gonic/gin"
11
	"golang.org/x/time/rate"
12
)
13
14
type (
15
	rateLimiter struct {
16
		mu sync.RWMutex
17
18
		visitors map[visitorIP]*visitor
19
20
		// limit is the maximum number of requests per second
21
		limit rate.Limit
22
23
		// ttl is the time after which a visitor is forgotten
24
		ttl time.Duration
25
26
		// burst is the maximum number of requests that can be made in a short amount of time
27
		burst int
28
	}
29
30
	visitorIP string
31
	visitor   struct {
32
		limiter  *rate.Limiter
33
		lastSeen time.Time
34
	}
35
)
36
37
func newLimiter(rps, burst int, ttl time.Duration) *rateLimiter {
38
	return &rateLimiter{ //nolint:exhaustruct
39
		visitors: make(map[visitorIP]*visitor),
40
		limit:    rate.Limit(rps),
41
		burst:    burst,
42
		ttl:      ttl,
43
	}
44
}
45
46
// Retrieve and return the rate limiter for the current visitor if it
47
// already exists. Otherwise create a new rate limiter and add it to
48
// the visitors map, using the IP address as the key.
49
func (r *rateLimiter) getVisitor(ip visitorIP) *rate.Limiter {
50
	r.mu.RLock()
51
	v, exists := r.visitors[ip]
52
	r.mu.RUnlock()
53
54
	if !exists {
55
		limit := rate.NewLimiter(r.limit, r.burst)
56
57
		r.mu.Lock()
58
		r.visitors[ip] = &visitor{
59
			limiter:  limit,
60
			lastSeen: time.Now(),
61
		}
62
		r.mu.Unlock()
63
64
		return limit
65
	}
66
67
	r.mu.Lock()
68
	v.lastSeen = time.Now()
69
	r.mu.Unlock()
70
71
	return v.limiter
72
}
73
74
// Every minute check the map for visitors that haven't been seen for
75
// more than 3 minutes and delete the entries.
76
func (r *rateLimiter) cleanupVisitors() {
77
	for {
78
		time.Sleep(time.Minute)
79
80
		r.mu.Lock()
81
		for ip, v := range r.visitors {
82
			if time.Since(v.lastSeen) > r.ttl {
83
				delete(r.visitors, ip)
84
			}
85
		}
86
		r.mu.Unlock()
87
	}
88
}
89
90
// Config holds the settings for MiddlewareWithConfig. Each client gets its
// own token bucket whose refill rate is RPS and whose capacity is Burst.
type Config struct {
	// RPS is the sustained number of requests per second one client may
	// make, i.e. the rate at which its token bucket refills. Thanks to Burst,
	// a client can briefly exceed this rate.
	RPS int

	// TTL is how long a client may stay idle before its limiter is discarded.
	// A client that comes back after that starts again with a full bucket.
	// Idle clients are swept periodically, so eviction can lag TTL slightly.
	TTL time.Duration

	// Burst is the token bucket's capacity: the maximum number of requests a
	// client may make back-to-back before being held to RPS. Per
	// golang.org/x/time/rate, a Burst of 0 rejects every request unless the
	// limit is infinite.
	Burst int
}
100
101
// MiddlewareWithConfig returns a new rate limiting middleware with the given config
102
func MiddlewareWithConfig(c Config) gin.HandlerFunc {
103
	lmt := newLimiter(c.RPS, c.Burst, c.TTL)
104
	go lmt.cleanupVisitors()
105
106
	return func(c *gin.Context) {
107
		visitor := lmt.getVisitor(visitorIP(c.ClientIP()))
108
		if visitor == nil {
109
			c.AbortWithStatus(http.StatusInternalServerError)
110
			return
111
		}
112
113
		if !visitor.Allow() {
114
			c.AbortWithStatus(http.StatusTooManyRequests)
115
			return
116
		}
117
118
		c.Next()
119
	}
120
}