From 13040268d65705a0727eaea41e173198cb9af55a Mon Sep 17 00:00:00 2001 From: Franz Kafka Date: Sat, 21 Mar 2026 18:35:31 +0000 Subject: [PATCH] feat: add global and burst rate limiters Three layers of rate limiting, all disabled by default, opt-in via config: 1. Per-IP (existing): 30 req/min per IP 2. Global: server-wide limit across all IPs - Lock-free atomic counter for minimal overhead - Returns 503 when exceeded - Prevents pool exhaustion from distributed attacks 3. Burst: per-IP burst + sustained windows - Blocks rapid-fire abuse within seconds - Returns 429 with X-RateLimit-Reason header - Example: 5 req/5s burst, 60 req/min sustained Config: [global_rate_limit] requests = 0 # disabled by default window = "1m" [burst_rate_limit] burst = 0 # disabled by default burst_window = "5s" sustained = 0 sustained_window = "1m" Env overrides: GLOBAL_RATE_LIMIT_REQUESTS, GLOBAL_RATE_LIMIT_WINDOW, BURST_RATE_LIMIT_BURST, BURST_RATE_LIMIT_BURST_WINDOW, BURST_RATE_LIMIT_SUSTAINED, BURST_RATE_LIMIT_SUSTAINED_WINDOW Full test coverage: concurrent lock-free test, window expiry, disabled states, IP isolation, burst vs sustained distinction. --- cmd/searxng-go/main.go | 12 +- config.example.toml | 25 ++ internal/config/config.go | 70 +++++- internal/middleware/ratelimit.go | 17 +- internal/middleware/ratelimit_burst_test.go | 170 +++++++++++++ internal/middleware/ratelimit_global.go | 241 +++++++++++++++++++ internal/middleware/ratelimit_global_test.go | 140 +++++++++++ 7 files changed, 657 insertions(+), 18 deletions(-) create mode 100644 internal/middleware/ratelimit_burst_test.go create mode 100644 internal/middleware/ratelimit_global.go create mode 100644 internal/middleware/ratelimit_global_test.go diff --git a/cmd/searxng-go/main.go b/cmd/searxng-go/main.go index f9ce301..d6fcaf3 100644 --- a/cmd/searxng-go/main.go +++ b/cmd/searxng-go/main.go @@ -73,7 +73,7 @@ func main() { var subFS fs.FS = staticFS mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.FS(subFS)))) - // Apply middleware: rate limiter → CORS → handler. + // Apply middleware: global rate limit → burst rate limit → per-IP rate limit → CORS → handler. var handler http.Handler = mux handler = middleware.CORS(middleware.CORSConfig{ AllowedOrigins: cfg.CORS.AllowedOrigins, @@ -87,6 +87,16 @@ func main() { Window: cfg.RateLimitWindow(), CleanupInterval: cfg.RateLimitCleanupInterval(), }, logger)(handler) + handler = middleware.GlobalRateLimit(middleware.GlobalRateLimitConfig{ + Requests: cfg.GlobalRateLimit.Requests, + Window: cfg.GlobalRateLimitWindow(), + }, logger)(handler) + handler = middleware.BurstRateLimit(middleware.BurstRateLimitConfig{ + Burst: cfg.BurstRateLimit.Burst, + BurstWindow: cfg.BurstWindow(), + Sustained: cfg.BurstRateLimit.Sustained, + SustainedWindow: cfg.SustainedWindow(), + }, logger)(handler) addr := fmt.Sprintf(":%d", cfg.Server.Port) logger.Info("searxng-go starting", diff --git a/config.example.toml b/config.example.toml index b7ad762..f5cfa0b 100644 --- a/config.example.toml +++ b/config.example.toml @@ -66,3 +66,28 @@ requests = 30 window = "1m" # How often to clean up stale IP entries (env: RATE_LIMIT_CLEANUP_INTERVAL) cleanup_interval = "5m" + +[global_rate_limit] +# Server-wide rate limit across ALL IPs. Prevents pool exhaustion from +# distributed attacks even when per-IP limits are bypassed via VPNs. +# Returns 503 when exceeded. Set to 0 to disable. +# Env: GLOBAL_RATE_LIMIT_REQUESTS +requests = 0 +# Env: GLOBAL_RATE_LIMIT_WINDOW +window = "1m" + +[burst_rate_limit] +# Per-IP burst + sustained rate limiting. More aggressive than the standard +# per-IP limiter. Blocks rapid-fire abuse even if the per-minute limit isn't hit. +# Returns 429 with X-RateLimit-Reason header. Set burst to 0 to disable. +# +# Example: burst=5, burst_window="5s" means max 5 requests in any 5-second span. +# sustained=60, sustained_window="1m" means max 60 requests per minute. +# Env: BURST_RATE_LIMIT_BURST +burst = 0 +# Env: BURST_RATE_LIMIT_BURST_WINDOW +burst_window = "5s" +# Env: BURST_RATE_LIMIT_SUSTAINED +sustained = 0 +# Env: BURST_RATE_LIMIT_SUSTAINED_WINDOW +sustained_window = "1m" diff --git a/internal/config/config.go b/internal/config/config.go index b790200..e06d028 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -11,12 +11,14 @@ import ( // Config is the top-level configuration for the gosearch service. type Config struct { - Server ServerConfig `toml:"server"` - Upstream UpstreamConfig `toml:"upstream"` - Engines EnginesConfig `toml:"engines"` - Cache CacheConfig `toml:"cache"` - CORS CORSConfig `toml:"cors"` - RateLimit RateLimitConfig `toml:"rate_limit"` + Server ServerConfig `toml:"server"` + Upstream UpstreamConfig `toml:"upstream"` + Engines EnginesConfig `toml:"engines"` + Cache CacheConfig `toml:"cache"` + CORS CORSConfig `toml:"cors"` + RateLimit RateLimitConfig `toml:"rate_limit"` + GlobalRateLimit GlobalRateLimitConfig `toml:"global_rate_limit"` + BurstRateLimit BurstRateLimitConfig `toml:"burst_rate_limit"` } type ServerConfig struct { @@ -59,6 +61,20 @@ type RateLimitConfig struct { CleanupInterval string `toml:"cleanup_interval"` // Stale entry cleanup interval (default: "5m") } +// GlobalRateLimitConfig holds server-wide rate limiting settings. +type GlobalRateLimitConfig struct { + Requests int `toml:"requests"` // Max total requests per window across all IPs (0 = disabled) + Window string `toml:"window"` // Time window (e.g. "1m", default: "1m") +} + +// BurstRateLimitConfig holds per-IP burst rate limiting settings. +type BurstRateLimitConfig struct { + Burst int `toml:"burst"` // Max requests in burst window (0 = disabled) + BurstWindow string `toml:"burst_window"` // Burst window (default: "5s") + Sustained int `toml:"sustained"` // Max requests in sustained window + SustainedWindow string `toml:"sustained_window"` // Sustained window (default: "1m") +} + type BraveConfig struct { APIKey string `toml:"api_key"` AccessToken string `toml:"access_token"` @@ -159,6 +175,24 @@ func applyEnvOverrides(cfg *Config) { if v := os.Getenv("RATE_LIMIT_CLEANUP_INTERVAL"); v != "" { cfg.RateLimit.CleanupInterval = v } + if v := os.Getenv("GLOBAL_RATE_LIMIT_REQUESTS"); v != "" { + fmt.Sscanf(v, "%d", &cfg.GlobalRateLimit.Requests) + } + if v := os.Getenv("GLOBAL_RATE_LIMIT_WINDOW"); v != "" { + cfg.GlobalRateLimit.Window = v + } + if v := os.Getenv("BURST_RATE_LIMIT_BURST"); v != "" { + fmt.Sscanf(v, "%d", &cfg.BurstRateLimit.Burst) + } + if v := os.Getenv("BURST_RATE_LIMIT_BURST_WINDOW"); v != "" { + cfg.BurstRateLimit.BurstWindow = v + } + if v := os.Getenv("BURST_RATE_LIMIT_SUSTAINED"); v != "" { + fmt.Sscanf(v, "%d", &cfg.BurstRateLimit.Sustained) + } + if v := os.Getenv("BURST_RATE_LIMIT_SUSTAINED_WINDOW"); v != "" { + cfg.BurstRateLimit.SustainedWindow = v + } if v := os.Getenv("BASE_URL"); v != "" { cfg.Server.BaseURL = v } @@ -201,6 +235,30 @@ func (c *Config) RateLimitCleanupInterval() time.Duration { return 5 * time.Minute } +// GlobalRateLimitWindow parses the global rate limit window into a time.Duration. +func (c *Config) GlobalRateLimitWindow() time.Duration { + if d, err := time.ParseDuration(c.GlobalRateLimit.Window); err == nil && d > 0 { + return d + } + return time.Minute +} + +// BurstWindow parses the burst window into a time.Duration. +func (c *Config) BurstWindow() time.Duration { + if d, err := time.ParseDuration(c.BurstRateLimit.BurstWindow); err == nil && d > 0 { + return d + } + return 5 * time.Second +} + +// SustainedWindow parses the sustained window into a time.Duration. +func (c *Config) SustainedWindow() time.Duration { + if d, err := time.ParseDuration(c.BurstRateLimit.SustainedWindow); err == nil && d > 0 { + return d + } + return time.Minute +} + func splitCSV(s string) []string { if s == "" { return nil diff --git a/internal/middleware/ratelimit.go b/internal/middleware/ratelimit.go index f1a181d..5102f9a 100644 --- a/internal/middleware/ratelimit.go +++ b/internal/middleware/ratelimit.go @@ -45,13 +45,12 @@ func RateLimit(cfg RateLimitConfig, logger *slog.Logger) func(http.Handler) http } limiter := &ipLimiter{ - requests: requests, - window: window, - clients: make(map[string]*bucket), - logger: logger, + requests: requests, + window: window, + clients: make(map[string]*bucket), + logger: logger, } - // Background cleanup of stale buckets. go limiter.cleanup(cleanup) return func(next http.Handler) http.Handler { @@ -122,13 +121,9 @@ func (l *ipLimiter) cleanup(interval time.Duration) { } func extractIP(r *http.Request) string { - // Trust X-Forwarded-For / X-Real-IP if behind a proxy. if xff := r.Header.Get("X-Forwarded-For"); xff != "" { - // First IP in the chain is the client. - if idx := len(xff); idx > 0 { - parts := strings.SplitN(xff, ",", 2) - return strings.TrimSpace(parts[0]) - } + parts := strings.SplitN(xff, ",", 2) + return strings.TrimSpace(parts[0]) } if rip := r.Header.Get("X-Real-IP"); rip != "" { return strings.TrimSpace(rip) diff --git a/internal/middleware/ratelimit_burst_test.go b/internal/middleware/ratelimit_burst_test.go new file mode 100644 index 0000000..a840db6 --- /dev/null +++ b/internal/middleware/ratelimit_burst_test.go @@ -0,0 +1,170 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestBurstRateLimit_AllowsUnderBurst(t *testing.T) { + h := BurstRateLimit(BurstRateLimitConfig{ + Burst: 5, + BurstWindow: 10 * time.Second, + }, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + for i := 0; i < 5; i++ { + req := httptest.NewRequest("GET", "/search", nil) + req.RemoteAddr = "1.1.1.1:1234" + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("request %d: expected 200, got %d", i+1, rec.Code) + } + } +} + +func TestBurstRateLimit_BlocksBurst(t *testing.T) { + h := BurstRateLimit(BurstRateLimitConfig{ + Burst: 3, + BurstWindow: 10 * time.Second, + }, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + for i := 0; i < 3; i++ { + req := httptest.NewRequest("GET", "/search", nil) + req.RemoteAddr = "1.1.1.1:1234" + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("request %d: expected 200, got %d", i+1, rec.Code) + } + } + + req := httptest.NewRequest("GET", "/search", nil) + req.RemoteAddr = "1.1.1.1:1234" + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + if rec.Code != http.StatusTooManyRequests { + t.Errorf("expected 429, got %d", rec.Code) + } + reason := rec.Header().Get("X-RateLimit-Reason") + if reason != "burst limit exceeded" { + t.Errorf("expected burst reason, got %q", reason) + } +} + +func TestBurstRateLimit_Sustained(t *testing.T) { + h := BurstRateLimit(BurstRateLimitConfig{ + Burst: 100, + BurstWindow: 10 * time.Second, + Sustained: 5, + SustainedWindow: 10 * time.Second, + }, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + for i := 0; i < 5; i++ { + req := httptest.NewRequest("GET", "/search", nil) + req.RemoteAddr = "1.1.1.1:1234" + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("request %d: expected 200, got %d", i+1, rec.Code) + } + } + + req := httptest.NewRequest("GET", "/search", nil) + req.RemoteAddr = "1.1.1.1:1234" + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + if rec.Code != http.StatusTooManyRequests { + t.Errorf("expected 429, got %d", rec.Code) + } + reason := rec.Header().Get("X-RateLimit-Reason") + if reason != "sustained limit exceeded" { + t.Errorf("expected sustained reason, got %q", reason) + } +} + +func TestBurstRateLimit_BurstWindowExpires(t *testing.T) { + limiter := &burstLimiter{ + burst: 2, + burstWindow: 50 * time.Millisecond, + sustained: 100, + sustainedWindow: 10 * time.Second, + clients: make(map[string]*burstBucket), + } + + // Use 2 requests (burst limit). + if reason := limiter.allow("1.1.1.1"); reason != "" { + t.Fatalf("request 1 blocked: %s", reason) + } + if reason := limiter.allow("1.1.1.1"); reason != "" { + t.Fatalf("request 2 blocked: %s", reason) + } + if reason := limiter.allow("1.1.1.1"); reason != "burst limit exceeded" { + t.Errorf("expected burst block, got %q", reason) + } + + // Wait for burst window to expire. + time.Sleep(60 * time.Millisecond) + + if reason := limiter.allow("1.1.1.1"); reason != "" { + t.Errorf("after burst expiry, should be allowed, got %q", reason) + } +} + +func TestBurstRateLimit_Disabled(t *testing.T) { + h := BurstRateLimit(BurstRateLimitConfig{Burst: 0}, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + for i := 0; i < 200; i++ { + req := httptest.NewRequest("GET", "/search", nil) + req.RemoteAddr = "1.1.1.1:1234" + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("request %d: expected 200 with disabled limiter, got %d", i+1, rec.Code) + } + } +} + +func TestBurstRateLimit_DifferentIPs(t *testing.T) { + h := BurstRateLimit(BurstRateLimitConfig{ + Burst: 1, + BurstWindow: 10 * time.Second, + }, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // IP A blocked after 1. + req := httptest.NewRequest("GET", "/search", nil) + req.RemoteAddr = "1.1.1.1:1234" + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Error("IP A first: expected 200") + } + + req = httptest.NewRequest("GET", "/search", nil) + req.RemoteAddr = "1.1.1.1:1234" + rec = httptest.NewRecorder() + h.ServeHTTP(rec, req) + if rec.Code != http.StatusTooManyRequests { + t.Error("IP A second: expected 429") + } + + // IP B still allowed. + req = httptest.NewRequest("GET", "/search", nil) + req.RemoteAddr = "2.2.2.2:1234" + rec = httptest.NewRecorder() + h.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Error("IP B first: expected 200") + } +} diff --git a/internal/middleware/ratelimit_global.go b/internal/middleware/ratelimit_global.go new file mode 100644 index 0000000..1f94dc9 --- /dev/null +++ b/internal/middleware/ratelimit_global.go @@ -0,0 +1,241 @@ +package middleware + +import ( + "net/http" + "strconv" + "sync" + "sync/atomic" + "time" + + "log/slog" +) + +// GlobalRateLimitConfig controls server-wide rate limiting. +// Applied on top of per-IP rate limiting to prevent overall abuse. +type GlobalRateLimitConfig struct { + // Requests is the max total requests across all IPs per window. + Requests int + // Window is the time window duration (e.g. "1m"). + Window time.Duration +} + +// GlobalRateLimit returns a middleware that limits total server-wide requests. +// Uses a lock-free atomic counter for minimal overhead. +// Set requests to 0 to disable. +func GlobalRateLimit(cfg GlobalRateLimitConfig, logger *slog.Logger) func(http.Handler) http.Handler { + requests := cfg.Requests + if requests <= 0 { + return func(next http.Handler) http.Handler { + return next + } + } + + window := cfg.Window + if window <= 0 { + window = time.Minute + } + + if logger == nil { + logger = slog.Default() + } + + limiter := &globalLimiter{ + requests: int64(requests), + window: window, + logger: logger, + } + + go limiter.resetLoop() + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !limiter.allow() { + retryAfter := int(limiter.window.Seconds()) + w.Header().Set("Retry-After", strconv.Itoa(retryAfter)) + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusServiceUnavailable) + _, _ = w.Write([]byte("503 Service Unavailable — global rate limit exceeded\n")) + logger.Warn("global rate limit exceeded", "ip", extractIP(r)) + return + } + + next.ServeHTTP(w, r) + }) + } +} + +type globalLimiter struct { + requests int64 + count atomic.Int64 + window time.Duration + logger *slog.Logger +} + +func (l *globalLimiter) allow() bool { + for { + current := l.count.Load() + if current >= l.requests { + return false + } + if l.count.CompareAndSwap(current, current+1) { + return true + } + } +} + +func (l *globalLimiter) resetLoop() { + ticker := time.NewTicker(l.window) + defer ticker.Stop() + + for range ticker.C { + old := l.count.Swap(0) + if old > 0 { + l.logger.Debug("global rate limit window reset", "previous_count", old) + } + } +} + +// BurstRateLimitConfig controls per-IP burst + sustained rate limiting. +// More aggressive than the standard per-IP limiter — designed to stop rapid-fire abuse. +type BurstRateLimitConfig struct { + // Burst is the max requests allowed in a short burst window. + Burst int + // BurstWindow is the burst window duration (e.g. "5s"). + BurstWindow time.Duration + // Sustained is the max requests allowed in the sustained window. + Sustained int + // SustainedWindow is the sustained window duration (e.g. "1m"). + SustainedWindow time.Duration +} + +// BurstRateLimit returns a middleware that enforces both burst and sustained limits per IP. +// Set burst to 0 to disable entirely. +func BurstRateLimit(cfg BurstRateLimitConfig, logger *slog.Logger) func(http.Handler) http.Handler { + if cfg.Burst <= 0 && cfg.Sustained <= 0 { + return func(next http.Handler) http.Handler { + return next + } + } + + if cfg.Burst <= 0 { + cfg.Burst = cfg.Sustained + } + if cfg.BurstWindow <= 0 { + cfg.BurstWindow = 5 * time.Second + } + if cfg.Sustained <= 0 { + cfg.Sustained = cfg.Burst * 6 + } + if cfg.SustainedWindow <= 0 { + cfg.SustainedWindow = time.Minute + } + + if logger == nil { + logger = slog.Default() + } + + limiter := &burstLimiter{ + burst: cfg.Burst, + burstWindow: cfg.BurstWindow, + sustained: cfg.Sustained, + sustainedWindow: cfg.SustainedWindow, + clients: make(map[string]*burstBucket), + logger: logger, + } + + go limiter.cleanup() + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ip := extractIP(r) + reason := limiter.allow(ip) + + if reason != "" { + retryAfter := int(cfg.BurstWindow.Seconds()) + w.Header().Set("Retry-After", strconv.Itoa(retryAfter)) + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-RateLimit-Reason", reason) + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte("429 Too Many Requests — " + reason + "\n")) + logger.Debug("burst rate limited", "ip", ip, "reason", reason) + return + } + + next.ServeHTTP(w, r) + }) + } +} + +type burstBucket struct { + burstCount int + sustainedCount int + burstExpireAt time.Time + sustainedExpireAt time.Time +} + +type burstLimiter struct { + burst int + burstWindow time.Duration + sustained int + sustainedWindow time.Duration + clients map[string]*burstBucket + mu sync.Mutex + logger *slog.Logger +} + +// allow returns empty string if allowed, or a reason string if blocked. +func (l *burstLimiter) allow(ip string) string { + l.mu.Lock() + defer l.mu.Unlock() + + now := time.Now() + b, ok := l.clients[ip] + + if ok && now.After(b.sustainedExpireAt) { + delete(l.clients, ip) + ok = false + } + + if !ok { + l.clients[ip] = &burstBucket{ + burstCount: 1, + sustainedCount: 1, + burstExpireAt: now.Add(l.burstWindow), + sustainedExpireAt: now.Add(l.sustainedWindow), + } + return "" + } + + if now.After(b.burstExpireAt) { + b.burstCount = 0 + b.burstExpireAt = now.Add(l.burstWindow) + } + + b.burstCount++ + b.sustainedCount++ + + if b.burstCount > l.burst { + return "burst limit exceeded" + } + if b.sustainedCount > l.sustained { + return "sustained limit exceeded" + } + + return "" +} + +func (l *burstLimiter) cleanup() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for range ticker.C { + l.mu.Lock() + now := time.Now() + for ip, b := range l.clients { + if now.After(b.sustainedExpireAt) { + delete(l.clients, ip) + } + } + l.mu.Unlock() + } +} diff --git a/internal/middleware/ratelimit_global_test.go b/internal/middleware/ratelimit_global_test.go new file mode 100644 index 0000000..0d21997 --- /dev/null +++ b/internal/middleware/ratelimit_global_test.go @@ -0,0 +1,140 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" +) + +func TestGlobalRateLimit_AllowsUnderLimit(t *testing.T) { + h := GlobalRateLimit(GlobalRateLimitConfig{ + Requests: 100, + Window: 10 * time.Second, + }, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + for i := 0; i < 100; i++ { + req := httptest.NewRequest("GET", "/search", nil) + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("request %d: expected 200, got %d", i+1, rec.Code) + } + } +} + +func TestGlobalRateLimit_BlocksOverLimit(t *testing.T) { + h := GlobalRateLimit(GlobalRateLimitConfig{ + Requests: 5, + Window: 10 * time.Second, + }, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + for i := 0; i < 5; i++ { + req := httptest.NewRequest("GET", "/search", nil) + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("request %d: expected 200, got %d", i+1, rec.Code) + } + } + + req := httptest.NewRequest("GET", "/search", nil) + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + if rec.Code != http.StatusServiceUnavailable { + t.Errorf("expected 503, got %d", rec.Code) + } +} + +func TestGlobalRateLimit_WindowResets(t *testing.T) { + limiter := &globalLimiter{ + requests: 3, + window: 50 * time.Millisecond, + } + + if !limiter.allow() { + t.Error("first request should be allowed") + } + if !limiter.allow() { + t.Error("second request should be allowed") + } + if !limiter.allow() { + t.Error("third request should be allowed") + } + if limiter.allow() { + t.Error("fourth request should be blocked") + } + + // Simulate window reset. + old := limiter.count.Swap(0) + if old != 3 { + t.Errorf("expected count 3, got %d", old) + } + + if !limiter.allow() { + t.Error("after reset, request should be allowed") + } +} + +func TestGlobalRateLimit_Disabled(t *testing.T) { + h := GlobalRateLimit(GlobalRateLimitConfig{Requests: 0}, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Should allow unlimited requests. + for i := 0; i < 200; i++ { + req := httptest.NewRequest("GET", "/search", nil) + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("request %d: expected 200 with disabled limiter, got %d", i+1, rec.Code) + } + } +} + +func TestGlobalRateLimit_LockFree(t *testing.T) { + // Verify the atomic counter doesn't panic under concurrent access. + h := GlobalRateLimit(GlobalRateLimitConfig{ + Requests: 10000, + Window: 10 * time.Second, + }, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + var allowed, blocked atomic.Int64 + done := make(chan struct{}) + + for i := 0; i < 100; i++ { + go func() { + for j := 0; j < 200; j++ { + req := httptest.NewRequest("GET", "/", nil) + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + if rec.Code == http.StatusOK { + allowed.Add(1) + } else { + blocked.Add(1) + } + } + done <- struct{}{} + }() + } + + for i := 0; i < 100; i++ { + <-done + } + + total := allowed.Load() + blocked.Load() + if total != 20000 { + t.Errorf("expected 20000 total requests, got %d", total) + } + if allowed.Load() != 10000 { + t.Errorf("expected exactly 10000 allowed, got %d", allowed.Load()) + } + t.Logf("concurrent test: %d allowed, %d blocked", allowed.Load(), blocked.Load()) +}