kafka/internal/middleware/ratelimit.go
Franz Kafka 13040268d6 feat: add global and burst rate limiters
Three layers of rate limiting, all disabled by default, opt-in via config:

1. Per-IP (existing): 30 req/min per IP
2. Global: server-wide limit across all IPs
   - Lock-free atomic counter for minimal overhead
   - Returns 503 when exceeded
   - Prevents pool exhaustion from distributed attacks
3. Burst: per-IP burst + sustained windows
   - Blocks rapid-fire abuse within seconds
   - Returns 429 with X-RateLimit-Reason header
   - Example: 5 req/5s burst, 60 req/min sustained

Config:
[global_rate_limit]
requests = 0  # disabled by default
window = "1m"

[burst_rate_limit]
burst = 0  # disabled by default
burst_window = "5s"
sustained = 0
sustained_window = "1m"

Env overrides: GLOBAL_RATE_LIMIT_REQUESTS, GLOBAL_RATE_LIMIT_WINDOW,
BURST_RATE_LIMIT_BURST, BURST_RATE_LIMIT_BURST_WINDOW,
BURST_RATE_LIMIT_SUSTAINED, BURST_RATE_LIMIT_SUSTAINED_WINDOW

Full test coverage: concurrent lock-free test, window expiry, disabled states,
IP isolation, burst vs sustained distinction.
2026-03-21 18:35:31 +00:00

137 lines
2.8 KiB
Go

package middleware
import (
"net"
"net/http"
"strconv"
"strings"
"sync"
"time"
"log/slog"
)
// RateLimitConfig controls per-IP rate limiting using a fixed-window counter
// (each IP's count resets when its window expires; see ipLimiter.allow).
// The zero value is usable: RateLimit substitutes a default for any field
// that is zero or negative.
type RateLimitConfig struct {
	// Requests is the max number of requests allowed per window.
	// Values <= 0 fall back to the default of 30.
	Requests int
	// Window is the time window duration (e.g. "1m").
	// Values <= 0 fall back to the default of 1 minute.
	Window time.Duration
	// CleanupInterval is how often stale entries are purged (default: 5m).
	CleanupInterval time.Duration
}
// RateLimit returns a middleware that throttles requests per client IP
// using an in-memory fixed-window counter. When a client exceeds the
// limit, the middleware replies with HTTP 429 and a Retry-After header
// set to the window length in seconds.
//
// Zero or negative config fields fall back to defaults: 30 requests per
// 1-minute window, with stale entries purged every 5 minutes. A nil
// logger falls back to slog.Default().
//
// NOTE(review): the cleanup goroutine has no stop mechanism, so every
// call to RateLimit leaks one goroutine for the process lifetime. That
// is acceptable for middleware built once at startup, but callers must
// not construct this per-request.
func RateLimit(cfg RateLimitConfig, logger *slog.Logger) func(http.Handler) http.Handler {
	limit, win, sweep := cfg.Requests, cfg.Window, cfg.CleanupInterval
	if limit <= 0 {
		limit = 30
	}
	if win <= 0 {
		win = time.Minute
	}
	if sweep <= 0 {
		sweep = 5 * time.Minute
	}
	if logger == nil {
		logger = slog.Default()
	}

	lim := &ipLimiter{
		requests: limit,
		window:   win,
		clients:  make(map[string]*bucket),
		logger:   logger,
	}
	go lim.cleanup(sweep)

	// The window never changes after construction, so the Retry-After
	// value can be rendered once up front.
	retryAfter := strconv.Itoa(int(win.Seconds()))

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			addr := extractIP(r)
			if lim.allow(addr) {
				next.ServeHTTP(w, r)
				return
			}
			w.Header().Set("Retry-After", retryAfter)
			w.Header().Set("Content-Type", "text/plain; charset=utf-8")
			w.WriteHeader(http.StatusTooManyRequests)
			_, _ = w.Write([]byte("429 Too Many Requests\n"))
			logger.Debug("rate limited", "ip", addr)
		})
	}
}
// bucket tracks a single client's request count within its current window.
type bucket struct {
	hits    int       // requests counted in the current window
	resetAt time.Time // instant at which the window expires
}

// ipLimiter is a fixed-window request counter keyed by client IP.
// clients is guarded by mu; requests, window, and logger are written
// once at construction and only read afterwards.
type ipLimiter struct {
	requests int
	window   time.Duration
	clients  map[string]*bucket
	mu       sync.Mutex
	logger   *slog.Logger
}

// allow records one request from ip and reports whether it falls within
// the configured limit. An absent or expired bucket starts a fresh
// window, and the first request of a window is always admitted.
func (lim *ipLimiter) allow(ip string) bool {
	lim.mu.Lock()
	defer lim.mu.Unlock()

	now := time.Now()
	if b, ok := lim.clients[ip]; ok && !now.After(b.resetAt) {
		// Live window: count the request and compare against the cap.
		b.hits++
		return b.hits <= lim.requests
	}
	// New client or stale window: restart with this request counted.
	lim.clients[ip] = &bucket{hits: 1, resetAt: now.Add(lim.window)}
	return true
}

// cleanup periodically deletes buckets whose window has expired so the
// clients map does not grow without bound. It never returns; the caller
// runs it on a dedicated goroutine.
func (lim *ipLimiter) cleanup(interval time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for range ticker.C {
		now := time.Now()
		lim.mu.Lock()
		for ip, b := range lim.clients {
			if now.After(b.resetAt) {
				delete(lim.clients, ip)
			}
		}
		lim.mu.Unlock()
	}
}
func extractIP(r *http.Request) string {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
parts := strings.SplitN(xff, ",", 2)
return strings.TrimSpace(parts[0])
}
if rip := r.Header.Get("X-Real-IP"); rip != "" {
return strings.TrimSpace(rip)
}
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}
return host
}