onasty/internal/transport/http/ratelimit/ratelimit_test.go (view raw)
| 1 | package ratelimit |
| 2 | |
| 3 | import ( |
| 4 | "net/http" |
| 5 | "net/http/httptest" |
| 6 | "testing" |
| 7 | "testing/synctest" |
| 8 | "time" |
| 9 | |
| 10 | "github.com/gin-gonic/gin" |
| 11 | "github.com/stretchr/testify/assert" |
| 12 | ) |
| 13 | |
| 14 | func TestRateLimiter_getVisitor(t *testing.T) { |
| 15 | limiter := newLimiter(10, 20, time.Second) |
| 16 | ip := visitorIP("127.0.0.1") |
| 17 | |
| 18 | visitor := limiter.getVisitor(ip) |
| 19 | assert.NotNil(t, visitor) |
| 20 | |
| 21 | visitorAgain := limiter.getVisitor(ip) |
| 22 | assert.Equal(t, visitor, visitorAgain) |
| 23 | |
| 24 | assert.Len(t, limiter.visitors, 1) |
| 25 | } |
| 26 | |
| 27 | func TestRateLimiter_cleanupVisitors(t *testing.T) { |
| 28 | synctest.Test(t, func(t *testing.T) { |
| 29 | limiter := newLimiter(10, 20, time.Minute) |
| 30 | limiter.getVisitor("192.168.9.1") |
| 31 | assert.Len(t, limiter.visitors, 1) |
| 32 | |
| 33 | time.Sleep(61 * time.Second) |
| 34 | |
| 35 | limiter.cleanupVisitors() |
| 36 | assert.Empty(t, limiter.visitors) |
| 37 | }) |
| 38 | } |
| 39 | |
| 40 | func TestMiddleware(t *testing.T) { |
| 41 | gin.SetMode(gin.TestMode) |
| 42 | tests := map[string]struct { |
| 43 | config Config |
| 44 | requests int |
| 45 | expectedCode int |
| 46 | }{ |
| 47 | "allows requests with in limit": { |
| 48 | config: Config{ |
| 49 | RPS: 2, |
| 50 | Burst: 2, |
| 51 | TTL: time.Minute, |
| 52 | }, |
| 53 | requests: 1, |
| 54 | expectedCode: http.StatusOK, |
| 55 | }, |
| 56 | "blocks requests over limit": { |
| 57 | config: Config{ |
| 58 | RPS: 1, |
| 59 | Burst: 1, |
| 60 | TTL: time.Minute, |
| 61 | }, |
| 62 | requests: 2, |
| 63 | expectedCode: http.StatusTooManyRequests, |
| 64 | }, |
| 65 | "allows burst requests": { |
| 66 | config: Config{ |
| 67 | RPS: 1, |
| 68 | Burst: 3, |
| 69 | TTL: time.Minute, |
| 70 | }, |
| 71 | requests: 3, |
| 72 | expectedCode: http.StatusOK, |
| 73 | }, |
| 74 | } |
| 75 | |
| 76 | for name, tt := range tests { |
| 77 | t.Run(name, func(t *testing.T) { |
| 78 | handler := MiddlewareWithConfig(tt.config) |
| 79 | var lastCode int |
| 80 | |
| 81 | for range tt.requests { |
| 82 | w := httptest.NewRecorder() |
| 83 | c, _ := gin.CreateTestContext(w) |
| 84 | c.Request = httptest.NewRequest(http.MethodGet, "/", nil) |
| 85 | |
| 86 | handler(c) |
| 87 | lastCode = w.Code |
| 88 | } |
| 89 | |
| 90 | assert.Equal(t, tt.expectedCode, lastCode) |
| 91 | }) |
| 92 | } |
| 93 | } |