all repos

onasty @ e2c4fd283ecd1a92895e34c43cf02a3a94d2b194

a one-time notes service
7 files changed, 178 insertions(+), 8 deletions(-)
feat: rate limiting (#36)

* feat(ratelimit): add rate limiting module

* feat(ratelimit): setup rate limiter

* fixup! feat(ratelimit): setup rate limiter

* fixup! fixup! feat(ratelimit): setup rate limiter

* fixup! feat(ratelimit): setup rate limiter

* fix(e2e): set ratelimiter config in tests so they would pass

* refactor(ratelimit): remove unused code

* fix(ratelimit): now the middleware shouldn't panic

* fix(ratelimit): actually handle if user isnt found

* refactor(ratelimit): use rw mutex

* chore: update .env.example
Author: Smirnov Oleksandr ss2316544@gmail.com
Committed by: GitHub noreply@github.com
Committed at: 2024-10-18 23:58:44 +0300
Parent: 47d33af
M .env.example
···
        24
        24
         MAILGUN_DOMAI='<domain>'

      
        25
        25
         MAILGUN_API_KEY='<token>'

      
        26
        26
         VERIFICATION_TOKEN_TTL=48h

      
        
        27
        +

      
        
        28
        +RATELIMITER_RPS=100

      
        
        29
        +RATELIMITER_BURST=10

      
        
        30
        +RATELIMITER_TTL=3m

      
M cmd/server/main.go
···
        25
        25
         	"github.com/olexsmir/onasty/internal/store/psqlutil"

      
        26
        26
         	httptransport "github.com/olexsmir/onasty/internal/transport/http"

      
        27
        27
         	"github.com/olexsmir/onasty/internal/transport/http/httpserver"

      
        
        28
        +	"github.com/olexsmir/onasty/internal/transport/http/ratelimit"

      
        28
        29
         )

      
        29
        30
         

      
        30
        31
         func main() {

      ···
        83
        84
         	noterepo := noterepo.New(psqlDB)

      
        84
        85
         	notesrv := notesrv.New(noterepo)

      
        85
        86
         

      
        86
        
        -	handler := httptransport.NewTransport(usersrv, notesrv)

      
        
        87
        +	rateLimiterConfig := ratelimit.Config{

      
        
        88
        +		RPS:   cfg.RateLimiterRPS,

      
        
        89
        +		TTL:   cfg.RateLimiterTTL,

      
        
        90
        +		Burst: cfg.RateLimiterBurst,

      
        
        91
        +	}

      
        
        92
        +

      
        
        93
        +	handler := httptransport.NewTransport(

      
        
        94
        +		usersrv,

      
        
        95
        +		notesrv,

      
        
        96
        +		rateLimiterConfig,

      
        
        97
        +	)

      
        87
        98
         

      
        88
        99
         	// http server

      
        89
        100
         	srv := httpserver.NewServer(cfg.ServerPort, handler.Handler())

      
M e2e/e2e_test.go
···
        26
        26
         	"github.com/olexsmir/onasty/internal/store/psql/vertokrepo"

      
        27
        27
         	"github.com/olexsmir/onasty/internal/store/psqlutil"

      
        28
        28
         	httptransport "github.com/olexsmir/onasty/internal/transport/http"

      
        
        29
        +	"github.com/olexsmir/onasty/internal/transport/http/ratelimit"

      
        29
        30
         	"github.com/stretchr/testify/require"

      
        30
        31
         	"github.com/stretchr/testify/suite"

      
        31
        32
         	"github.com/testcontainers/testcontainers-go"

      ···
        117
        118
         	noterepo := noterepo.New(e.postgresDB)

      
        118
        119
         	notesrv := notesrv.New(noterepo)

      
        119
        120
         

      
        120
        
        -	handler := httptransport.NewTransport(usersrv, notesrv)

      
        
        121
        +	// for testing purposes, it's ok to have high values ig

      
        
        122
        +	ratelimitCfg := ratelimit.Config{

      
        
        123
        +		RPS:   1000,

      
        
        124
        +		TTL:   time.Millisecond,

      
        
        125
        +		Burst: 1000,

      
        
        126
        +	}

      
        
        127
        +

      
        
        128
        +	handler := httptransport.NewTransport(usersrv, notesrv, ratelimitCfg)

      
        121
        129
         	e.router = handler.Handler()

      
        122
        130
         }

      
        123
        131
         

      
M go.mod
···
        15
        15
         	github.com/stretchr/testify v1.9.0

      
        16
        16
         	github.com/testcontainers/testcontainers-go v0.33.0

      
        17
        17
         	github.com/testcontainers/testcontainers-go/modules/postgres v0.33.0

      
        
        18
        +	golang.org/x/time v0.5.0

      
        18
        19
         )

      
        19
        20
         

      
        20
        21
         require (

      
M internal/config/config.go
···
        3
        3
         import (

      
        4
        4
         	"errors"

      
        5
        5
         	"os"

      
        
        6
        +	"strconv"

      
        6
        7
         	"time"

      
        7
        8
         )

      
        8
        9
         

      ···
        28
        29
         	LogLevel    string

      
        29
        30
         	LogFormat   string

      
        30
        31
         	LogShowLine bool

      
        
        32
        +

      
        
        33
        +	RateLimiterRPS   int

      
        
        34
        +	RateLimiterBurst int

      
        
        35
        +	RateLimiterTTL   time.Duration

      
        31
        36
         }

      
        32
        37
         

      
        33
        38
         func NewConfig() *Config {

      ···
        39
        44
         		PasswordSalt: getenvOrDefault("PASSWORD_SALT", ""),

      
        40
        45
         

      
        41
        46
         		JwtSigningKey: getenvOrDefault("JWT_SIGNING_KEY", ""),

      
        42
        
        -		JwtAccessTokenTTL: mustParseDurationOrPanic(

      
        
        47
        +		JwtAccessTokenTTL: mustParseDuration(

      
        43
        48
         			getenvOrDefault("JWT_ACCESS_TOKEN_TTL", "15m"),

      
        44
        49
         		),

      
        45
        
        -		JwtRefreshTokenTTL: mustParseDurationOrPanic(

      
        
        50
        +		JwtRefreshTokenTTL: mustParseDuration(

      
        46
        51
         			getenvOrDefault("JWT_REFRESH_TOKEN_TTL", "24h"),

      
        47
        52
         		),

      
        48
        53
         

      
        49
        54
         		MailgunFrom:   getenvOrDefault("MAILGUN_FROM", ""),

      
        50
        55
         		MailgunDomain: getenvOrDefault("MAILGUN_DOMAIN", ""),

      
        51
        56
         		MailgunAPIKey: getenvOrDefault("MAILGUN_API_KEY", ""),

      
        52
        
        -		VerificationTokenTTL: mustParseDurationOrPanic(

      
        
        57
        +		VerificationTokenTTL: mustParseDuration(

      
        53
        58
         			getenvOrDefault("VERIFICATION_TOKEN_TTL", "24h"),

      
        54
        59
         		),

      
        55
        60
         

      ···
        59
        64
         		LogLevel:    getenvOrDefault("LOG_LEVEL", "debug"),

      
        60
        65
         		LogFormat:   getenvOrDefault("LOG_FORMAT", "json"),

      
        61
        66
         		LogShowLine: getenvOrDefault("LOG_SHOW_LINE", "true") == "true",

      
        
        67
        +

      
        
        68
        +		RateLimiterRPS:   mustGetenvOrDefaultInt("RATELIMITER_RPS", 100),

      
        
        69
        +		RateLimiterBurst: mustGetenvOrDefaultInt("RATELIMITER_BURST", 10),

      
        
        70
        +		RateLimiterTTL:   mustParseDuration(getenvOrDefault("RATELIMITER_TTL", "1m")),

      
        62
        71
         	}

      
        63
        72
         }

      
        64
        73
         

      ···
        73
        82
         	return def

      
        74
        83
         }

      
        75
        84
         

      
        76
        
        -func mustParseDurationOrPanic(dur string) time.Duration {

      
        
        85
        +func mustGetenvOrDefaultInt(key string, def int) int {

      
        
        86
        +	if v, ok := os.LookupEnv(key); ok {

      
        
        87
        +		r, err := strconv.Atoi(v)

      
        
        88
        +		if err != nil {

      
        
        89
        +			panic(err)

      
        
        90
        +		}

      
        
        91
        +		return r

      
        
        92
        +	}

      
        
        93
        +	return def

      
        
        94
        +}

      
        
        95
        +

      
        
        96
        +func mustParseDuration(dur string) time.Duration {

      
        77
        97
         	d, err := time.ParseDuration(dur)

      
        78
        98
         	if err != nil {

      
        79
        99
         		panic(errors.Join(errors.New("cannot time.ParseDuration"), err)) //nolint:err113

      
M internal/transport/http/http.go
···
        7
        7
         	"github.com/olexsmir/onasty/internal/service/notesrv"

      
        8
        8
         	"github.com/olexsmir/onasty/internal/service/usersrv"

      
        9
        9
         	"github.com/olexsmir/onasty/internal/transport/http/apiv1"

      
        
        10
        +	"github.com/olexsmir/onasty/internal/transport/http/ratelimit"

      
        10
        11
         	"github.com/olexsmir/onasty/internal/transport/http/reqid"

      
        11
        12
         )

      
        12
        13
         

      
        13
        14
         type Transport struct {

      
        14
        15
         	usersrv usersrv.UserServicer

      
        15
        16
         	notesrv notesrv.NoteServicer

      
        
        17
        +

      
        
        18
        +	ratelimitCfg ratelimit.Config

      
        16
        19
         }

      
        17
        20
         

      
        18
        21
         func NewTransport(

      
        19
        22
         	us usersrv.UserServicer,

      
        20
        23
         	ns notesrv.NoteServicer,

      
        
        24
        +	ratelimitCfg ratelimit.Config,

      
        21
        25
         ) *Transport {

      
        22
        26
         	return &Transport{

      
        23
        
        -		usersrv: us,

      
        24
        
        -		notesrv: ns,

      
        
        27
        +		usersrv:      us,

      
        
        28
        +		notesrv:      ns,

      
        
        29
        +		ratelimitCfg: ratelimitCfg,

      
        25
        30
         	}

      
        26
        31
         }

      
        27
        32
         

      ···
        31
        36
         		gin.Recovery(),

      
        32
        37
         		reqid.Middleware(),

      
        33
        38
         		t.logger(),

      
        
        39
        +		ratelimit.MiddlewareWithConfig(t.ratelimitCfg),

      
        34
        40
         	)

      
        35
        41
         

      
        36
        42
         	api := r.Group("/api")

      
A internal/transport/http/ratelimit/ratelimit.go
···
        
        1
        +// thanks to https://www.alexedwards.net/blog/how-to-rate-limit-http-requests

      
        
        2
        +

      
        
        3
        +package ratelimit

      
        
        4
        +

      
        
        5
        +import (

      
        
        6
        +	"net/http"

      
        
        7
        +	"sync"

      
        
        8
        +	"time"

      
        
        9
        +

      
        
        10
        +	"github.com/gin-gonic/gin"

      
        
        11
        +	"golang.org/x/time/rate"

      
        
        12
        +)

      
        
        13
        +

      
        
        14
        +type (

      
        
        15
        +	rateLimiter struct {

      
        
        16
        +		mu sync.RWMutex

      
        
        17
        +

      
        
        18
        +		visitors map[visitorIP]*visitor

      
        
        19
        +

      
        
        20
        +		// limit is the maximum number of requests per second

      
        
        21
        +		limit rate.Limit

      
        
        22
        +

      
        
        23
        +		// ttl is the time after which a visitor is forgotten

      
        
        24
        +		ttl time.Duration

      
        
        25
        +

      
        
        26
        +		// burst is the maximum number of requests that can be made in a short amount of time

      
        
        27
        +		burst int

      
        
        28
        +	}

      
        
        29
        +

      
        
        30
        +	visitorIP string

      
        
        31
        +	visitor   struct {

      
        
        32
        +		limiter  *rate.Limiter

      
        
        33
        +		lastSeen time.Time

      
        
        34
        +	}

      
        
        35
        +)

      
        
        36
        +

      
        
        37
        +func newLimiter(rps, burst int, ttl time.Duration) *rateLimiter {

      
        
        38
        +	return &rateLimiter{ //nolint:exhaustruct

      
        
        39
        +		visitors: make(map[visitorIP]*visitor),

      
        
        40
        +		limit:    rate.Limit(rps),

      
        
        41
        +		burst:    burst,

      
        
        42
        +		ttl:      ttl,

      
        
        43
        +	}

      
        
        44
        +}

      
        
        45
        +

      
        
        46
        +// Retrieve and return the rate limiter for the current visitor if it

      
        
        47
        +// already exists. Otherwise create a new rate limiter and add it to

      
        
        48
        +// the visitors map, using the IP address as the key.

      
        
        49
        +func (r *rateLimiter) getVisitor(ip visitorIP) *rate.Limiter {

      
        
        50
        +	r.mu.RLock()

      
        
        51
        +	v, exists := r.visitors[ip]

      
        
        52
        +	r.mu.RUnlock()

      
        
        53
        +

      
        
        54
        +	if !exists {

      
        
        55
        +		limit := rate.NewLimiter(r.limit, r.burst)

      
        
        56
        +

      
        
        57
        +		r.mu.Lock()

      
        
        58
        +		r.visitors[ip] = &visitor{

      
        
        59
        +			limiter:  limit,

      
        
        60
        +			lastSeen: time.Now(),

      
        
        61
        +		}

      
        
        62
        +		r.mu.Unlock()

      
        
        63
        +

      
        
        64
        +		return limit

      
        
        65
        +	}

      
        
        66
        +

      
        
        67
        +	r.mu.Lock()

      
        
        68
        +	v.lastSeen = time.Now()

      
        
        69
        +	r.mu.Unlock()

      
        
        70
        +

      
        
        71
        +	return v.limiter

      
        
        72
        +}

      
        
        73
        +

      
        
        74
        +// Every minute check the map for visitors that haven't been seen for

      
        
        75
        +// more than 3 minutes and delete the entries.

      
        
        76
        +func (r *rateLimiter) cleanupVisitors() {

      
        
        77
        +	for {

      
        
        78
        +		time.Sleep(time.Minute)

      
        
        79
        +

      
        
        80
        +		r.mu.Lock()

      
        
        81
        +		for ip, v := range r.visitors {

      
        
        82
        +			if time.Since(v.lastSeen) > r.ttl {

      
        
        83
        +				delete(r.visitors, ip)

      
        
        84
        +			}

      
        
        85
        +		}

      
        
        86
        +		r.mu.Unlock()

      
        
        87
        +	}

      
        
        88
        +}

      
        
        89
        +

      
        
        90
        +type Config struct {

      
        
        91
        +	// RPS is the maximum number of requests per second

      
        
        92
        +	RPS int

      
        
        93
        +

      
        
        94
        +	// TTL is the time after which a visitor is forgotten

      
        
        95
        +	TTL time.Duration

      
        
        96
        +

      
        
        97
        +	// Burst is the maximum number of requests that can be made in a short amount of time

      
        
        98
        +	Burst int

      
        
        99
        +}

      
        
        100
        +

      
        
        101
        +// MiddlewareWithConfig returns a new rate limiting middleware with the given config

      
        
        102
        +func MiddlewareWithConfig(c Config) gin.HandlerFunc {

      
        
        103
        +	lmt := newLimiter(c.RPS, c.Burst, c.TTL)

      
        
        104
        +	go lmt.cleanupVisitors()

      
        
        105
        +

      
        
        106
        +	return func(c *gin.Context) {

      
        
        107
        +		visitor := lmt.getVisitor(visitorIP(c.ClientIP()))

      
        
        108
        +		if visitor == nil {

      
        
        109
        +			c.AbortWithStatus(http.StatusInternalServerError)

      
        
        110
        +			return

      
        
        111
        +		}

      
        
        112
        +

      
        
        113
        +		if !visitor.Allow() {

      
        
        114
        +			c.AbortWithStatus(http.StatusTooManyRequests)

      
        
        115
        +			return

      
        
        116
        +		}

      
        
        117
        +

      
        
        118
        +		c.Next()

      
        
        119
        +	}

      
        
        120
        +}