I'm new to rate limiting and want to use tollbooth to limit HTTP requests.
I also read the Token Bucket Algorithm page on Wikipedia.
For a simple test app, I want to limit the max number of concurrent requests to 10 regardless of request IP, and have a max burst size of 3 based on request IP.

NOTE: The 10 and 3 are just to make rate limiting easier to observe.
Below is my code, based on the examples on tollbooth's GitHub page:
```go
package main

import (
	"log"
	"net/http"
	"time"

	"github.com/didip/tollbooth/v7"
	"github.com/didip/tollbooth/v7/limiter"
)

func main() {
	// Allow 3 requests per second per client IP; idle per-IP entries expire after an hour.
	lmt := tollbooth.NewLimiter(3, &limiter.ExpirableOptions{DefaultExpirationTTL: time.Hour})
	http.Handle("/", tollbooth.LimitFuncHandler(lmt, HelloHandler))
	log.Fatal(http.ListenAndServe(":8080", nil))
}

func HelloHandler(w http.ResponseWriter, req *http.Request) {
	w.Write([]byte("Hello, World!"))
}
```
I test the code by running curl -i localhost:8080 several times in rapid succession, and I do get HTTP/1.1 429 Too Many Requests errors whenever I exceed the rate limit I set.
Below are my questions:

1. How do I use tollbooth to limit the max number of concurrent requests to something like 10? And does it even make sense to do so? I assume it does, because rate limiting based only on IPs sounds like the server could still run out of memory when too many IPs access it at once.
2. Am I approaching rate limiting correctly, or am I missing something? Perhaps this is something that's better handled by whatever load balancer is working with the app in the cloud?
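One common stdlib-only way to cap in-flight requests globally is a buffered channel used as a counting semaphore. The sketch below is an alternative technique, not a tollbooth feature; `LimitConcurrency` and `demo` are hypothetical names. It rejects with 429 instead of queueing, which matches the behavior described above.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"sync"
)

// LimitConcurrency (hypothetical helper, not part of tollbooth) admits at most
// n in-flight requests across all clients and answers 429 immediately when
// every slot is taken, rather than queueing the request.
func LimitConcurrency(n int, next http.Handler) http.Handler {
	sem := make(chan struct{}, n) // buffered channel used as a counting semaphore
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		select {
		case sem <- struct{}{}: // a slot was free: take it
			defer func() { <-sem }() // give the slot back when the handler returns
			next.ServeHTTP(w, r)
		default: // all n slots are busy
			http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
		}
	})
}

// demo fills all n slots with blocked requests, then reports the status of one
// extra request while they are in flight and of another after they finish.
func demo(n int) (whileBusy, afterDone int) {
	entered := make(chan struct{}, n+1)
	release := make(chan struct{})
	h := LimitConcurrency(n, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		entered <- struct{}{} // signal that this request holds a slot
		<-release             // stay in flight until the demo closes release
		w.Write([]byte("Hello, World!"))
	}))

	var wg sync.WaitGroup
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
		}()
		<-entered // wait until this request is inside the handler
	}

	busy := httptest.NewRecorder()
	h.ServeHTTP(busy, httptest.NewRequest(http.MethodGet, "/", nil))

	close(release) // let the blocked requests finish
	wg.Wait()      // every deferred <-sem has run by now

	free := httptest.NewRecorder()
	h.ServeHTTP(free, httptest.NewRequest(http.MethodGet, "/", nil))
	return busy.Code, free.Code
}

func main() {
	busy, free := demo(10)
	fmt.Println(busy, free) // 429 200
}
```

The channel's capacity is the limit, and the non-blocking `select` is what turns "wait for a slot" into "reject when full".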
UPDATE: Here's my working code based on Woody1193's answer:
```go
package main

import (
	"log"
	"net/http"
	"sync"
	"time"

	"github.com/didip/tollbooth/v7"
	"github.com/didip/tollbooth/v7/limiter"
)

func main() {
	// Per-IP limit: 3 requests per second; idle per-IP entries expire after an hour.
	ipLimiter := tollbooth.NewLimiter(3, &limiter.ExpirableOptions{DefaultExpirationTTL: time.Hour})
	// Global limit: at most 10 requests in flight across all clients.
	globalLimiter := NewConcurrentLimiter(10)
	http.Handle("/", globalLimiter.LimitConcurrentRequests(ipLimiter, HelloHandler))
	log.Fatal(http.ListenAndServe(":8080", nil))
}

func HelloHandler(w http.ResponseWriter, req *http.Request) {
	w.Write([]byte("Hello, World!"))
}

// ConcurrentLimiter counts the requests currently being served.
type ConcurrentLimiter struct {
	max     int
	current int
	mut     sync.Mutex
}

func NewConcurrentLimiter(limit int) *ConcurrentLimiter {
	return &ConcurrentLimiter{
		max: limit,
	}
}

// LimitConcurrentRequests checks the per-IP tollbooth limiter first, then
// rejects the request with 429 if max requests are already in flight.
func (cl *ConcurrentLimiter) LimitConcurrentRequests(lmt *limiter.Limiter,
	handler func(http.ResponseWriter, *http.Request)) http.Handler {
	middle := func(w http.ResponseWriter, r *http.Request) {
		cl.mut.Lock()
		if cl.current >= cl.max {
			cl.mut.Unlock()
			http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
			return
		}
		cl.current++ // claim a slot
		cl.mut.Unlock()

		// Release the slot when the handler returns.
		defer func() {
			cl.mut.Lock()
			cl.current--
			cl.mut.Unlock()
		}()

		// There's no rate-limit error, serve the next handler.
		handler(w, r)
	}
	return tollbooth.LimitHandler(lmt, http.HandlerFunc(middle))
}
```
Answer:
It appears that tollbooth doesn't offer the functionality you're looking for. However, you can roll your own:
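```go
// Needs "net/http", "sync", "github.com/didip/tollbooth/v7" and
// "github.com/didip/tollbooth/v7/limiter" among the file's imports.

// ConcurrentLimiter counts the requests currently being served.
type ConcurrentLimiter struct {
	max     int
	current int
	mut     sync.Mutex // the zero value is ready to use, so no setup is needed
}

func NewConcurrentLimiter(limit int) *ConcurrentLimiter {
	return &ConcurrentLimiter{
		max: limit,
	}
}

// LimitConcurrentRequests wraps next so that the tollbooth limiter lmt is
// checked first, and then at most max requests are handled at once.
func (cl *ConcurrentLimiter) LimitConcurrentRequests(lmt *limiter.Limiter,
	next http.Handler) http.Handler {
	middle := func(w http.ResponseWriter, r *http.Request) {
		cl.mut.Lock()
		if cl.current >= cl.max {
			cl.mut.Unlock()
			// Or whatever error response you prefer.
			http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
			return
		}
		cl.current++
		cl.mut.Unlock()

		// Free the slot once the next handler has returned.
		defer func() {
			cl.mut.Lock()
			cl.current--
			cl.mut.Unlock()
		}()

		// There's no rate-limit error, serve the next handler.
		next.ServeHTTP(w, r)
	}
	return tollbooth.LimitHandler(lmt, http.HandlerFunc(middle))
}
```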
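Then, in your setup you can do:

```go
lmt := tollbooth.NewLimiter(3, &limiter.ExpirableOptions{DefaultExpirationTTL: time.Hour})
http.Handle("/", NewConcurrentLimiter(10).LimitConcurrentRequests(lmt, http.HandlerFunc(HelloHandler)))
```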
This code works by keeping a count of how many requests the API is currently handling and returning an error once that count reaches the maximum. When a request finishes, the deferred function decrements the count, freeing its slot. The Mutex ensures that only one request at a time reads or updates the count, so concurrent requests cannot race on it.
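A quick way to check the counting logic without running a server is to pull the acquire/release steps out of the handler. The sketch below does that; `slots`, `tryAcquire`, and `release` are hypothetical names that mirror `ConcurrentLimiter`'s fields and the lock/check/increment and lock/decrement steps above.

```go
package main

import (
	"fmt"
	"sync"
)

// slots is the counting logic of ConcurrentLimiter pulled out of the HTTP
// handler (hypothetical names) so it can be exercised without a server.
type slots struct {
	mu      sync.Mutex
	max     int
	current int
}

// tryAcquire claims a slot if one is free and reports whether it succeeded.
func (s *slots) tryAcquire() bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.current >= s.max {
		return false
	}
	s.current++
	return true
}

// release frees a slot claimed by a successful tryAcquire.
func (s *slots) release() {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.current--
}

func main() {
	s := &slots{max: 10}
	admitted := 0
	for i := 0; i < 11; i++ {
		if s.tryAcquire() {
			admitted++
		}
	}
	fmt.Println(admitted)       // 10: the 11th request is rejected
	s.release()                 // one in-flight request finishes
	fmt.Println(s.tryAcquire()) // true: its slot can be reused
}
```

The increment must be `current++` (not an assignment like `current = 1`); otherwise the count never climbs past one and the limit is never reached.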
I had to inject the tollbooth.Limiter into the limiter I wrote because of the way tollbooth handles such functions (i.e., it doesn't operate as a middleware).
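If you would rather keep the two limiters independent instead of injecting one into the other, a common Go pattern is to express each as a `func(http.Handler) http.Handler` middleware and compose them. The sketch below uses a hypothetical `Chain` helper with stdlib-only stand-in middlewares that just record the order they run in. Adapting tollbooth to this shape would be an assumption on my part, e.g. `func(h http.Handler) http.Handler { return tollbooth.LimitHandler(lmt, h) }`, relying only on the `LimitHandler(lmt, http.Handler) http.Handler` call already used above.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
)

// Middleware is the conventional net/http middleware shape.
type Middleware func(http.Handler) http.Handler

// Chain (hypothetical helper) wraps h so that mws[0] runs first, then mws[1], and so on.
func Chain(h http.Handler, mws ...Middleware) http.Handler {
	for i := len(mws) - 1; i >= 0; i-- {
		h = mws[i](h)
	}
	return h
}

// tag is a stand-in middleware that records its name before calling next.
func tag(name string, trace *[]string) Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			*trace = append(*trace, name)
			next.ServeHTTP(w, r)
		})
	}
}

// runOrder reports the order in which the middlewares and the handler ran.
func runOrder() string {
	var trace []string
	final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		trace = append(trace, "handler")
		fmt.Fprint(w, "Hello, World!")
	})
	h := Chain(final, tag("perIP", &trace), tag("global", &trace))
	h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
	return strings.Join(trace, " -> ")
}

func main() {
	fmt.Println(runOrder()) // perIP -> global -> handler
}
```

Putting the per-IP check first means a client that is over its own rate is rejected before it can occupy one of the global concurrency slots.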