Three layers of rate limiting, all disabled by default, opt-in via config:

1. Per-IP (existing): 30 req/min per IP
2. Global: server-wide limit across all IPs
   - Lock-free atomic counter for minimal overhead
   - Returns 503 when exceeded
   - Prevents pool exhaustion from distributed attacks
3. Burst: per-IP burst + sustained windows
   - Blocks rapid-fire abuse within seconds
   - Returns 429 with X-RateLimit-Reason header
   - Example: 5 req/5s burst, 60 req/min sustained

Config:

    [global_rate_limit]
    requests = 0        # disabled by default
    window = "1m"

    [burst_rate_limit]
    burst = 0           # disabled by default
    burst_window = "5s"
    sustained = 0
    sustained_window = "1m"

Env overrides: GLOBAL_RATE_LIMIT_REQUESTS, GLOBAL_RATE_LIMIT_WINDOW, BURST_RATE_LIMIT_BURST, BURST_RATE_LIMIT_BURST_WINDOW, BURST_RATE_LIMIT_SUSTAINED, BURST_RATE_LIMIT_SUSTAINED_WINDOW

Full test coverage: concurrent lock-free test, window expiry, disabled states, IP isolation, burst vs sustained distinction.
137 lines
2.8 KiB
Go
137 lines
2.8 KiB
Go
package middleware
|
|
|
|
import (
|
|
"net"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"log/slog"
|
|
)
|
|
|
|
// RateLimitConfig controls per-IP rate limiting using a fixed-window counter.
// (Each IP gets a counter that resets once its window expires; the
// implementation lives in ipLimiter.allow.)
type RateLimitConfig struct {
	// Requests is the max number of requests allowed per window.
	// Values <= 0 fall back to the default of 30.
	Requests int
	// Window is the time window duration (e.g. "1m").
	// Values <= 0 fall back to the default of one minute.
	Window time.Duration
	// CleanupInterval is how often stale entries are purged (default: 5m).
	CleanupInterval time.Duration
}
|
|
|
|
// RateLimit returns a middleware that limits requests per IP address.
|
|
// Uses an in-memory sliding window counter. When the limit is exceeded,
|
|
// responds with HTTP 429 and a Retry-After header.
|
|
func RateLimit(cfg RateLimitConfig, logger *slog.Logger) func(http.Handler) http.Handler {
|
|
requests := cfg.Requests
|
|
if requests <= 0 {
|
|
requests = 30
|
|
}
|
|
|
|
window := cfg.Window
|
|
if window <= 0 {
|
|
window = time.Minute
|
|
}
|
|
|
|
cleanup := cfg.CleanupInterval
|
|
if cleanup <= 0 {
|
|
cleanup = 5 * time.Minute
|
|
}
|
|
|
|
if logger == nil {
|
|
logger = slog.Default()
|
|
}
|
|
|
|
limiter := &ipLimiter{
|
|
requests: requests,
|
|
window: window,
|
|
clients: make(map[string]*bucket),
|
|
logger: logger,
|
|
}
|
|
|
|
go limiter.cleanup(cleanup)
|
|
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
ip := extractIP(r)
|
|
|
|
if !limiter.allow(ip) {
|
|
retryAfter := int(limiter.window.Seconds())
|
|
w.Header().Set("Retry-After", strconv.Itoa(retryAfter))
|
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
w.WriteHeader(http.StatusTooManyRequests)
|
|
_, _ = w.Write([]byte("429 Too Many Requests\n"))
|
|
logger.Debug("rate limited", "ip", ip)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// bucket tracks the request count for a single IP within its current window.
type bucket struct {
	// count is the number of requests seen since the window started
	// (keeps incrementing past the limit; only compared, never capped).
	count int
	// expireAt is when the current window ends; after this the bucket is stale.
	expireAt time.Time
}
|
|
|
|
// ipLimiter holds the shared per-IP fixed-window rate limiting state
// used by the RateLimit middleware.
type ipLimiter struct {
	// requests is the max number of requests allowed per window.
	requests int
	// window is how long each IP's counting window lasts.
	window time.Duration
	// clients maps client IP -> its current bucket; guarded by mu.
	clients map[string]*bucket
	// mu guards clients and every bucket it holds.
	mu sync.Mutex
	logger *slog.Logger
}
|
|
|
|
func (l *ipLimiter) allow(ip string) bool {
|
|
l.mu.Lock()
|
|
defer l.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
b, ok := l.clients[ip]
|
|
|
|
if !ok || now.After(b.expireAt) {
|
|
l.clients[ip] = &bucket{
|
|
count: 1,
|
|
expireAt: now.Add(l.window),
|
|
}
|
|
return true
|
|
}
|
|
|
|
b.count++
|
|
return b.count <= l.requests
|
|
}
|
|
|
|
func (l *ipLimiter) cleanup(interval time.Duration) {
|
|
ticker := time.NewTicker(interval)
|
|
defer ticker.Stop()
|
|
|
|
for range ticker.C {
|
|
l.mu.Lock()
|
|
now := time.Now()
|
|
for ip, b := range l.clients {
|
|
if now.After(b.expireAt) {
|
|
delete(l.clients, ip)
|
|
}
|
|
}
|
|
l.mu.Unlock()
|
|
}
|
|
}
|
|
|
|
// extractIP returns the client IP for r, preferring proxy-supplied headers
// over the raw connection address.
//
// Lookup order: first entry of X-Forwarded-For, then X-Real-IP, then the
// host part of r.RemoteAddr (or RemoteAddr verbatim when it has no port).
//
// SECURITY NOTE: both headers are client-controlled unless a trusted
// reverse proxy strips or overwrites them, so a direct client can spoof
// its identity and dodge per-IP rate limits. Only trust these headers when
// the server is deployed behind a proxy that sanitizes them.
func extractIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		// The header may carry a comma-separated proxy chain; the first
		// entry is the originating client.
		first, _, _ := strings.Cut(xff, ",")
		return strings.TrimSpace(first)
	}
	if rip := r.Header.Get("X-Real-IP"); rip != "" {
		return strings.TrimSpace(rip)
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// RemoteAddr had no port (or was malformed); use it as-is.
		return r.RemoteAddr
	}
	return host
}
|