all repos

onasty @ 86dba8ea0dc14c9a08dbb0baa9b0441168638328

a one-time notes service

onasty/internal/transport/http/ratelimit/ratelimit_test.go (view raw)

Smirnov Oleksandr
ss2316544@gmail.com
refactor: fix annoyances (#97)..., 1 year ago
1
package ratelimit
2
3
import (
4
	"net/http"
5
	"net/http/httptest"
6
	"testing"
7
	"time"
8
9
	"github.com/gin-gonic/gin"
10
	"github.com/stretchr/testify/assert"
11
)
12
13
// TestRateLimiter_getVisitor checks that getVisitor registers a visitor the
// first time an address is looked up, returns an equal visitor on a repeated
// lookup of the same address, and keeps exactly one entry for it.
func TestRateLimiter_getVisitor(t *testing.T) {
	rl := newLimiter(10, 20, time.Second)
	addr := visitorIP("127.0.0.1")

	first := rl.getVisitor(addr)
	assert.NotNil(t, first)

	// Looking the same address up again must not register a second entry.
	second := rl.getVisitor(addr)
	assert.Equal(t, first, second)
	assert.Len(t, rl.visitors, 1)
}
25
26
// TODO: rewrite to use "testing/synctest" when it gets merged
27
func TestRateLimiter_cleanupVisitors(t *testing.T) {
28
	limiter := newLimiter(10, 20, time.Second/2)
29
	limiter.getVisitor("192.168.9.1")
30
	assert.Len(t, limiter.visitors, 1)
31
32
	time.Sleep(time.Second)
33
	limiter.cleanupVisitors()
34
	assert.Empty(t, limiter.visitors)
35
}
36
37
func TestMiddleware(t *testing.T) {
38
	gin.SetMode(gin.TestMode)
39
	tests := map[string]struct {
40
		config       Config
41
		requests     int
42
		expectedCode int
43
	}{
44
		"allows requests with in limit": {
45
			config: Config{
46
				RPS:   2,
47
				Burst: 2,
48
				TTL:   time.Minute,
49
			},
50
			requests:     1,
51
			expectedCode: http.StatusOK,
52
		},
53
		"blocks requests over limit": {
54
			config: Config{
55
				RPS:   1,
56
				Burst: 1,
57
				TTL:   time.Minute,
58
			},
59
			requests:     2,
60
			expectedCode: http.StatusTooManyRequests,
61
		},
62
		"allows burst requests": {
63
			config: Config{
64
				RPS:   1,
65
				Burst: 3,
66
				TTL:   time.Minute,
67
			},
68
			requests:     3,
69
			expectedCode: http.StatusOK,
70
		},
71
	}
72
73
	for name, tt := range tests {
74
		t.Run(name, func(t *testing.T) {
75
			handler := MiddlewareWithConfig(tt.config)
76
			var lastCode int
77
78
			for range tt.requests {
79
				w := httptest.NewRecorder()
80
				c, _ := gin.CreateTestContext(w)
81
				c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
82
83
				handler(c)
84
				lastCode = w.Code
85
			}
86
87
			assert.Equal(t, tt.expectedCode, lastCode)
88
		})
89
	}
90
}