onasty/internal/transport/http/ratelimit/ratelimit_test.go (view raw)
Smirnov Oleksandr
Smirnov Oleksandr
ss2316544@gmail.com refactor: fix annoyances (#97)..., 1 year ago
ss2316544@gmail.com refactor: fix annoyances (#97)..., 1 year ago
| 1 | package ratelimit |
| 2 | |
| 3 | import ( |
| 4 | "net/http" |
| 5 | "net/http/httptest" |
| 6 | "testing" |
| 7 | "time" |
| 8 | |
| 9 | "github.com/gin-gonic/gin" |
| 10 | "github.com/stretchr/testify/assert" |
| 11 | ) |
| 12 | |
| 13 | func TestRateLimiter_getVisitor(t *testing.T) { |
| 14 | limiter := newLimiter(10, 20, time.Second) |
| 15 | ip := visitorIP("127.0.0.1") |
| 16 | |
| 17 | visitor := limiter.getVisitor(ip) |
| 18 | assert.NotNil(t, visitor) |
| 19 | |
| 20 | visitorAgain := limiter.getVisitor(ip) |
| 21 | assert.Equal(t, visitor, visitorAgain) |
| 22 | |
| 23 | assert.Len(t, limiter.visitors, 1) |
| 24 | } |
| 25 | |
| 26 | // TODO: rewrite to use "testing/synctest" when it gets merged |
| 27 | func TestRateLimiter_cleanupVisitors(t *testing.T) { |
| 28 | limiter := newLimiter(10, 20, time.Second/2) |
| 29 | limiter.getVisitor("192.168.9.1") |
| 30 | assert.Len(t, limiter.visitors, 1) |
| 31 | |
| 32 | time.Sleep(time.Second) |
| 33 | limiter.cleanupVisitors() |
| 34 | assert.Empty(t, limiter.visitors) |
| 35 | } |
| 36 | |
| 37 | func TestMiddleware(t *testing.T) { |
| 38 | gin.SetMode(gin.TestMode) |
| 39 | tests := map[string]struct { |
| 40 | config Config |
| 41 | requests int |
| 42 | expectedCode int |
| 43 | }{ |
| 44 | "allows requests with in limit": { |
| 45 | config: Config{ |
| 46 | RPS: 2, |
| 47 | Burst: 2, |
| 48 | TTL: time.Minute, |
| 49 | }, |
| 50 | requests: 1, |
| 51 | expectedCode: http.StatusOK, |
| 52 | }, |
| 53 | "blocks requests over limit": { |
| 54 | config: Config{ |
| 55 | RPS: 1, |
| 56 | Burst: 1, |
| 57 | TTL: time.Minute, |
| 58 | }, |
| 59 | requests: 2, |
| 60 | expectedCode: http.StatusTooManyRequests, |
| 61 | }, |
| 62 | "allows burst requests": { |
| 63 | config: Config{ |
| 64 | RPS: 1, |
| 65 | Burst: 3, |
| 66 | TTL: time.Minute, |
| 67 | }, |
| 68 | requests: 3, |
| 69 | expectedCode: http.StatusOK, |
| 70 | }, |
| 71 | } |
| 72 | |
| 73 | for name, tt := range tests { |
| 74 | t.Run(name, func(t *testing.T) { |
| 75 | handler := MiddlewareWithConfig(tt.config) |
| 76 | var lastCode int |
| 77 | |
| 78 | for range tt.requests { |
| 79 | w := httptest.NewRecorder() |
| 80 | c, _ := gin.CreateTestContext(w) |
| 81 | c.Request = httptest.NewRequest(http.MethodGet, "/", nil) |
| 82 | |
| 83 | handler(c) |
| 84 | lastCode = w.Code |
| 85 | } |
| 86 | |
| 87 | assert.Equal(t, tt.expectedCode, lastCode) |
| 88 | }) |
| 89 | } |
| 90 | } |