package middleware

import (
	"log/slog"
	"net"
	"net/http"
	"net/netip"
	"strconv"
	"strings"
	"sync"
	"time"
)

// RateLimitConfig controls per-IP rate limiting using a fixed-window counter.
type RateLimitConfig struct {
	// Requests is the max number of requests allowed per client per window.
	// Values <= 0 select the default of 30.
	Requests int

	// Window is the length of each counting window (e.g. time.Minute).
	// Values <= 0 select the default of one minute.
	Window time.Duration

	// CleanupInterval is the minimum time between purges of expired
	// per-client entries. Values <= 0 select the default of 5 minutes.
	CleanupInterval time.Duration

	// TrustedProxies lists the networks of reverse proxies whose
	// X-Forwarded-For and X-Real-IP headers are believed.
	//
	// When empty, those headers are honored from every peer (the historical
	// behavior). Any client can then choose its own rate-limit key by sending
	// a forged header, which bypasses the limit and grows the client table
	// without bound. Set this whenever clients can reach the server directly.
	//
	// When non-empty, the headers are consulted only if the connection's peer
	// address is inside one of these networks. X-Forwarded-For is then read
	// right-to-left, and the first hop that is not a trusted proxy is used.
	// Peers whose address is not an IP (e.g. Unix sockets) are never trusted.
	TrustedProxies []netip.Prefix
}

// RateLimit returns a middleware that limits requests per client IP address.
//
// Each client gets a counter that allows cfg.Requests requests during a window
// of cfg.Window. The window starts at the client's first request and is reset
// by the first request after it expires (a fixed window, not a sliding one).
//
// Requests over the limit receive HTTP 429 with a Retry-After header. The
// header gives the whole seconds, rounded up, until that client's window
// resets.
//
// State is kept in memory. Expired entries are purged lazily while handling
// requests, at most once per cfg.CleanupInterval, so no background goroutine
// outlives the middleware.
//
// See RateLimitConfig.TrustedProxies for how the client address is chosen.
// A nil logger falls back to slog.Default().
func RateLimit(cfg RateLimitConfig, logger *slog.Logger) func(http.Handler) http.Handler {
	requests := cfg.Requests
	if requests <= 0 {
		requests = 30
	}
	window := cfg.Window
	if window <= 0 {
		window = time.Minute
	}
	cleanup := cfg.CleanupInterval
	if cleanup <= 0 {
		cleanup = 5 * time.Minute
	}
	if logger == nil {
		logger = slog.Default()
	}
	// Private copy: the caller mutating cfg.TrustedProxies later must not
	// race with request handling.
	trusted := append([]netip.Prefix(nil), cfg.TrustedProxies...)

	limiter := &ipLimiter{
		requests:   requests,
		window:     window,
		clients:    make(map[string]*bucket),
		logger:     logger,
		sweepEvery: cleanup,
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ip := clientIP(r, trusted)
			ok, wait := limiter.allowAt(ip, time.Now())
			if !ok {
				w.Header().Set("Retry-After", strconv.FormatInt(retryAfterSeconds(wait), 10))
				w.Header().Set("Content-Type", "text/plain; charset=utf-8")
				w.WriteHeader(http.StatusTooManyRequests)
				// Best effort: the status line is already sent and the client
				// may have gone away; nothing useful can be done on failure.
				_, _ = w.Write([]byte("429 Too Many Requests\n"))
				logger.Debug("rate limited", "ip", ip)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}

// retryAfterSeconds converts d to the value of a Retry-After header. It
// rounds up so clients never retry too early, and never returns less than 1.
func retryAfterSeconds(d time.Duration) int64 {
	secs := int64(d / time.Second)
	if d%time.Second != 0 {
		secs++
	}
	return max(secs, 1)
}

// bucket is one client's request counter for its current window.
type bucket struct {
	count    int
	expireAt time.Time
}

// ipLimiter keeps a fixed-window counter per client key. clients and
// nextSweep are guarded by mu.
type ipLimiter struct {
	requests int
	window   time.Duration
	clients  map[string]*bucket
	mu       sync.Mutex
	logger   *slog.Logger

	// sweepEvery is the minimum time between purges of expired buckets;
	// zero or negative disables purging.
	sweepEvery time.Duration

	// nextSweep is when the next purge is due.
	nextSweep time.Time
}

// allow records a request from ip at the current time and reports whether it
// is within the limit. It is kept for callers that do not need the retry
// delay; see allowAt.
func (l *ipLimiter) allow(ip string) bool {
	ok, _ := l.allowAt(ip, time.Now())
	return ok
}

// allowAt records a request from ip at time now and reports whether it is
// within the limit. When it is not, wait is the time left until ip's window
// resets.
func (l *ipLimiter) allowAt(ip string, now time.Time) (ok bool, wait time.Duration) {
	l.mu.Lock()
	defer l.mu.Unlock()

	l.maybeSweep(now)

	b, found := l.clients[ip]
	if !found || now.After(b.expireAt) {
		l.clients[ip] = &bucket{count: 1, expireAt: now.Add(l.window)}
		return true, 0
	}
	// Stop counting once the limit is hit: further requests are rejected
	// until the window resets, and the counter cannot grow without bound.
	if b.count >= l.requests {
		return false, b.expireAt.Sub(now)
	}
	b.count++
	return true, 0
}

// maybeSweep deletes expired buckets if a purge is due. l.mu must be held.
func (l *ipLimiter) maybeSweep(now time.Time) {
	if l.sweepEvery <= 0 || now.Before(l.nextSweep) {
		return
	}
	l.nextSweep = now.Add(l.sweepEvery)
	for ip, b := range l.clients {
		if now.After(b.expireAt) {
			delete(l.clients, ip)
		}
	}
}

// extractIP returns the client address for r, trusting X-Forwarded-For and
// X-Real-IP from any peer. Those headers can be forged by the client; use
// clientIP with a non-empty trusted list where that matters.
func extractIP(r *http.Request) string {
	return clientIP(r, nil)
}

// clientIP returns the rate-limit key for r.
//
// With an empty trusted list, the first valid address among these wins:
//   - the leftmost X-Forwarded-For entry,
//   - X-Real-IP,
//   - the connection's peer address.
//
// With a non-empty list, the headers are only consulted when the peer itself
// is trusted. X-Forwarded-For is then scanned from the right for the first
// hop that is not a trusted proxy.
func clientIP(r *http.Request, trusted []netip.Prefix) string {
	peer := peerHost(r.RemoteAddr)
	if len(trusted) > 0 && !isTrusted(peer, trusted) {
		// The request did not arrive through a known proxy, so any
		// forwarding headers were written by the client itself.
		return peer
	}
	// A proxy may add its own X-Forwarded-For line rather than extending the
	// existing one, so treat all lines as one comma-separated chain.
	if chain := strings.Join(r.Header.Values("X-Forwarded-For"), ","); chain != "" {
		if ip, ok := forwardedFor(chain, trusted); ok {
			return ip
		}
	}
	if ip, ok := normalizeIP(r.Header.Get("X-Real-IP")); ok {
		return ip
	}
	return peer
}

// forwardedFor picks the client address out of an X-Forwarded-For chain.
// It reports false when the relevant entry is not a valid IP address.
func forwardedFor(chain string, trusted []netip.Prefix) (string, bool) {
	hops := strings.Split(chain, ",")
	if len(trusted) > 0 {
		// Each proxy appends the address it received the request from. Only
		// entries appended by trusted proxies are reliable, so the first
		// untrusted hop walking leftward is the real client.
		for i := len(hops) - 1; i >= 0; i-- {
			ip, ok := normalizeIP(hops[i])
			if !ok {
				return "", false
			}
			if !isTrusted(ip, trusted) {
				return ip, true
			}
		}
	}
	// Untrusted-peer-agnostic mode, or every hop is a trusted proxy: the
	// leftmost entry is the originating client.
	return normalizeIP(hops[0])
}

// normalizeIP parses s as an IP address (surrounding spaces and an optional
// ":port" suffix allowed) and returns its canonical text form. IPv4-mapped
// IPv6 addresses are converted to plain IPv4.
func normalizeIP(s string) (string, bool) {
	s = strings.TrimSpace(s)
	addr, err := netip.ParseAddr(s)
	if err != nil {
		ap, perr := netip.ParseAddrPort(s)
		if perr != nil {
			return "", false
		}
		addr = ap.Addr()
	}
	return addr.Unmap().String(), true
}

// peerHost returns the host part of an http.Request.RemoteAddr, normalized
// when it is an IP address. RemoteAddr is returned unchanged if it cannot be
// split into host and port.
func peerHost(remoteAddr string) string {
	host, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		return remoteAddr
	}
	if ip, ok := normalizeIP(host); ok {
		return ip
	}
	return host
}

// isTrusted reports whether ip parses as an address inside one of trusted.
func isTrusted(ip string, trusted []netip.Prefix) bool {
	addr, err := netip.ParseAddr(ip)
	if err != nil {
		return false
	}
	// Prefix.Contains never matches zoned addresses; the zone is irrelevant
	// to network membership.
	addr = addr.WithZone("")
	for _, p := range trusted {
		if p.Contains(addr) {
			return true
		}
	}
	return false
}