package middleware import ( "net/http" "strconv" "sync" "sync/atomic" "time" "log/slog" ) // GlobalRateLimitConfig controls server-wide rate limiting. // Applied on top of per-IP rate limiting to prevent overall abuse. type GlobalRateLimitConfig struct { // Requests is the max total requests across all IPs per window. Requests int // Window is the time window duration (e.g. "1m"). Window time.Duration } // GlobalRateLimit returns a middleware that limits total server-wide requests. // Uses a lock-free atomic counter for minimal overhead. // Set requests to 0 to disable. func GlobalRateLimit(cfg GlobalRateLimitConfig, logger *slog.Logger) func(http.Handler) http.Handler { requests := cfg.Requests if requests <= 0 { return func(next http.Handler) http.Handler { return next } } window := cfg.Window if window <= 0 { window = time.Minute } if logger == nil { logger = slog.Default() } limiter := &globalLimiter{ requests: int64(requests), window: window, logger: logger, } go limiter.resetLoop() return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !limiter.allow() { retryAfter := int(limiter.window.Seconds()) w.Header().Set("Retry-After", strconv.Itoa(retryAfter)) w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.WriteHeader(http.StatusServiceUnavailable) _, _ = w.Write([]byte("503 Service Unavailable — global rate limit exceeded\n")) logger.Warn("global rate limit exceeded", "ip", extractIP(r)) return } next.ServeHTTP(w, r) }) } } type globalLimiter struct { requests int64 count atomic.Int64 window time.Duration logger *slog.Logger } func (l *globalLimiter) allow() bool { for { current := l.count.Load() if current >= l.requests { return false } if l.count.CompareAndSwap(current, current+1) { return true } } } func (l *globalLimiter) resetLoop() { ticker := time.NewTicker(l.window) defer ticker.Stop() for range ticker.C { old := l.count.Swap(0) if old > 0 { l.logger.Debug("global rate limit window reset", "previous_count", old) } } } // BurstRateLimitConfig controls per-IP burst + sustained rate limiting. // More aggressive than the standard per-IP limiter — designed to stop rapid-fire abuse. type BurstRateLimitConfig struct { // Burst is the max requests allowed in a short burst window. Burst int // BurstWindow is the burst window duration (e.g. "5s"). BurstWindow time.Duration // Sustained is the max requests allowed in the sustained window. Sustained int // SustainedWindow is the sustained window duration (e.g. "1m"). SustainedWindow time.Duration } // BurstRateLimit returns a middleware that enforces both burst and sustained limits per IP. // Set burst to 0 to disable entirely. func BurstRateLimit(cfg BurstRateLimitConfig, logger *slog.Logger) func(http.Handler) http.Handler { if cfg.Burst <= 0 && cfg.Sustained <= 0 { return func(next http.Handler) http.Handler { return next } } if cfg.Burst <= 0 { cfg.Burst = cfg.Sustained } if cfg.BurstWindow <= 0 { cfg.BurstWindow = 5 * time.Second } if cfg.Sustained <= 0 { cfg.Sustained = cfg.Burst * 6 } if cfg.SustainedWindow <= 0 { cfg.SustainedWindow = time.Minute } if logger == nil { logger = slog.Default() } limiter := &burstLimiter{ burst: cfg.Burst, burstWindow: cfg.BurstWindow, sustained: cfg.Sustained, sustainedWindow: cfg.SustainedWindow, clients: make(map[string]*burstBucket), logger: logger, } go limiter.cleanup() return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ip := extractIP(r) reason := limiter.allow(ip) if reason != "" { retryAfter := int(cfg.BurstWindow.Seconds()) w.Header().Set("Retry-After", strconv.Itoa(retryAfter)) w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("X-RateLimit-Reason", reason) w.WriteHeader(http.StatusTooManyRequests) _, _ = w.Write([]byte("429 Too Many Requests — " + reason + "\n")) logger.Debug("burst rate limited", "ip", ip, "reason", reason) return } next.ServeHTTP(w, r) }) } } type burstBucket struct { burstCount int sustainedCount int burstExpireAt time.Time sustainedExpireAt time.Time } type burstLimiter struct { burst int burstWindow time.Duration sustained int sustainedWindow time.Duration clients map[string]*burstBucket mu sync.Mutex logger *slog.Logger } // allow returns empty string if allowed, or a reason string if blocked. func (l *burstLimiter) allow(ip string) string { l.mu.Lock() defer l.mu.Unlock() now := time.Now() b, ok := l.clients[ip] if ok && now.After(b.sustainedExpireAt) { delete(l.clients, ip) ok = false } if !ok { l.clients[ip] = &burstBucket{ burstCount: 1, sustainedCount: 1, burstExpireAt: now.Add(l.burstWindow), sustainedExpireAt: now.Add(l.sustainedWindow), } return "" } if now.After(b.burstExpireAt) { b.burstCount = 0 b.burstExpireAt = now.Add(l.burstWindow) } b.burstCount++ b.sustainedCount++ if b.burstCount > l.burst { return "burst limit exceeded" } if b.sustainedCount > l.sustained { return "sustained limit exceeded" } return "" } func (l *burstLimiter) cleanup() { ticker := time.NewTicker(30 * time.Second) defer ticker.Stop() for range ticker.C { l.mu.Lock() now := time.Now() for ip, b := range l.clients { if now.After(b.sustainedExpireAt) { delete(l.clients, ip) } } l.mu.Unlock() } }