feat: add global and burst rate limiters

Three layers of rate limiting, all disabled by default, opt-in via config:

1. Per-IP (existing): 30 req/min per IP
2. Global: server-wide limit across all IPs
   - Lock-free atomic counter for minimal overhead
   - Returns 503 when exceeded
   - Prevents pool exhaustion from distributed attacks
3. Burst: per-IP burst + sustained windows
   - Blocks rapid-fire abuse within seconds
   - Returns 429 with X-RateLimit-Reason header
   - Example: 5 req/5s burst, 60 req/min sustained

Config:
[global_rate_limit]
requests = 0  # disabled by default
window = "1m"

[burst_rate_limit]
burst = 0  # disabled by default
burst_window = "5s"
sustained = 0
sustained_window = "1m"

Env overrides: GLOBAL_RATE_LIMIT_REQUESTS, GLOBAL_RATE_LIMIT_WINDOW,
BURST_RATE_LIMIT_BURST, BURST_RATE_LIMIT_BURST_WINDOW,
BURST_RATE_LIMIT_SUSTAINED, BURST_RATE_LIMIT_SUSTAINED_WINDOW

Full test coverage: concurrent lock-free test, window expiry, disabled states,
IP isolation, burst vs sustained distinction.
This commit is contained in:
Franz Kafka 2026-03-21 18:35:31 +00:00
parent 91ab76758c
commit 13040268d6
7 changed files with 657 additions and 18 deletions

View file

@ -0,0 +1,241 @@
package middleware
import (
"net/http"
"strconv"
"sync"
"sync/atomic"
"time"
"log/slog"
)
// GlobalRateLimitConfig controls server-wide rate limiting.
// Applied on top of per-IP rate limiting to prevent overall abuse.
type GlobalRateLimitConfig struct {
// Requests is the max total requests across all IPs per window.
Requests int
// Window is the time window duration (e.g. "1m").
Window time.Duration
}
// GlobalRateLimit returns a middleware that limits total server-wide requests.
// Uses a lock-free atomic counter for minimal overhead.
// Set requests to 0 to disable.
func GlobalRateLimit(cfg GlobalRateLimitConfig, logger *slog.Logger) func(http.Handler) http.Handler {
requests := cfg.Requests
if requests <= 0 {
return func(next http.Handler) http.Handler {
return next
}
}
window := cfg.Window
if window <= 0 {
window = time.Minute
}
if logger == nil {
logger = slog.Default()
}
limiter := &globalLimiter{
requests: int64(requests),
window: window,
logger: logger,
}
go limiter.resetLoop()
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !limiter.allow() {
retryAfter := int(limiter.window.Seconds())
w.Header().Set("Retry-After", strconv.Itoa(retryAfter))
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusServiceUnavailable)
_, _ = w.Write([]byte("503 Service Unavailable — global rate limit exceeded\n"))
logger.Warn("global rate limit exceeded", "ip", extractIP(r))
return
}
next.ServeHTTP(w, r)
})
}
}
type globalLimiter struct {
requests int64
count atomic.Int64
window time.Duration
logger *slog.Logger
}
func (l *globalLimiter) allow() bool {
for {
current := l.count.Load()
if current >= l.requests {
return false
}
if l.count.CompareAndSwap(current, current+1) {
return true
}
}
}
func (l *globalLimiter) resetLoop() {
ticker := time.NewTicker(l.window)
defer ticker.Stop()
for range ticker.C {
old := l.count.Swap(0)
if old > 0 {
l.logger.Debug("global rate limit window reset", "previous_count", old)
}
}
}
// BurstRateLimitConfig controls per-IP burst + sustained rate limiting.
// More aggressive than the standard per-IP limiter — designed to stop rapid-fire abuse.
type BurstRateLimitConfig struct {
// Burst is the max requests allowed in a short burst window.
Burst int
// BurstWindow is the burst window duration (e.g. "5s").
BurstWindow time.Duration
// Sustained is the max requests allowed in the sustained window.
Sustained int
// SustainedWindow is the sustained window duration (e.g. "1m").
SustainedWindow time.Duration
}
// BurstRateLimit returns a middleware that enforces both burst and sustained limits per IP.
// Set burst to 0 to disable entirely.
func BurstRateLimit(cfg BurstRateLimitConfig, logger *slog.Logger) func(http.Handler) http.Handler {
if cfg.Burst <= 0 && cfg.Sustained <= 0 {
return func(next http.Handler) http.Handler {
return next
}
}
if cfg.Burst <= 0 {
cfg.Burst = cfg.Sustained
}
if cfg.BurstWindow <= 0 {
cfg.BurstWindow = 5 * time.Second
}
if cfg.Sustained <= 0 {
cfg.Sustained = cfg.Burst * 6
}
if cfg.SustainedWindow <= 0 {
cfg.SustainedWindow = time.Minute
}
if logger == nil {
logger = slog.Default()
}
limiter := &burstLimiter{
burst: cfg.Burst,
burstWindow: cfg.BurstWindow,
sustained: cfg.Sustained,
sustainedWindow: cfg.SustainedWindow,
clients: make(map[string]*burstBucket),
logger: logger,
}
go limiter.cleanup()
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ip := extractIP(r)
reason := limiter.allow(ip)
if reason != "" {
retryAfter := int(cfg.BurstWindow.Seconds())
w.Header().Set("Retry-After", strconv.Itoa(retryAfter))
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-RateLimit-Reason", reason)
w.WriteHeader(http.StatusTooManyRequests)
_, _ = w.Write([]byte("429 Too Many Requests — " + reason + "\n"))
logger.Debug("burst rate limited", "ip", ip, "reason", reason)
return
}
next.ServeHTTP(w, r)
})
}
}
type burstBucket struct {
burstCount int
sustainedCount int
burstExpireAt time.Time
sustainedExpireAt time.Time
}
type burstLimiter struct {
burst int
burstWindow time.Duration
sustained int
sustainedWindow time.Duration
clients map[string]*burstBucket
mu sync.Mutex
logger *slog.Logger
}
// allow returns empty string if allowed, or a reason string if blocked.
func (l *burstLimiter) allow(ip string) string {
l.mu.Lock()
defer l.mu.Unlock()
now := time.Now()
b, ok := l.clients[ip]
if ok && now.After(b.sustainedExpireAt) {
delete(l.clients, ip)
ok = false
}
if !ok {
l.clients[ip] = &burstBucket{
burstCount: 1,
sustainedCount: 1,
burstExpireAt: now.Add(l.burstWindow),
sustainedExpireAt: now.Add(l.sustainedWindow),
}
return ""
}
if now.After(b.burstExpireAt) {
b.burstCount = 0
b.burstExpireAt = now.Add(l.burstWindow)
}
b.burstCount++
b.sustainedCount++
if b.burstCount > l.burst {
return "burst limit exceeded"
}
if b.sustainedCount > l.sustained {
return "sustained limit exceeded"
}
return ""
}
func (l *burstLimiter) cleanup() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for range ticker.C {
l.mu.Lock()
now := time.Now()
for ip, b := range l.clients {
if now.After(b.sustainedExpireAt) {
delete(l.clients, ip)
}
}
l.mu.Unlock()
}
}