package middleware import ( "net/http" "net/http/httptest" "sync/atomic" "testing" "time" ) func TestGlobalRateLimit_AllowsUnderLimit(t *testing.T) { h := GlobalRateLimit(GlobalRateLimitConfig{ Requests: 100, Window: 10 * time.Second, }, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) for i := 0; i < 100; i++ { req := httptest.NewRequest("GET", "/search", nil) rec := httptest.NewRecorder() h.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("request %d: expected 200, got %d", i+1, rec.Code) } } } func TestGlobalRateLimit_BlocksOverLimit(t *testing.T) { h := GlobalRateLimit(GlobalRateLimitConfig{ Requests: 5, Window: 10 * time.Second, }, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) for i := 0; i < 5; i++ { req := httptest.NewRequest("GET", "/search", nil) rec := httptest.NewRecorder() h.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("request %d: expected 200, got %d", i+1, rec.Code) } } req := httptest.NewRequest("GET", "/search", nil) rec := httptest.NewRecorder() h.ServeHTTP(rec, req) if rec.Code != http.StatusServiceUnavailable { t.Errorf("expected 503, got %d", rec.Code) } } func TestGlobalRateLimit_WindowResets(t *testing.T) { limiter := &globalLimiter{ requests: 3, window: 50 * time.Millisecond, } if !limiter.allow() { t.Error("first request should be allowed") } if !limiter.allow() { t.Error("second request should be allowed") } if !limiter.allow() { t.Error("third request should be allowed") } if limiter.allow() { t.Error("fourth request should be blocked") } // Simulate window reset. old := limiter.count.Swap(0) if old != 3 { t.Errorf("expected count 3, got %d", old) } if !limiter.allow() { t.Error("after reset, request should be allowed") } } func TestGlobalRateLimit_Disabled(t *testing.T) { h := GlobalRateLimit(GlobalRateLimitConfig{Requests: 0}, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) // Should allow unlimited requests. for i := 0; i < 200; i++ { req := httptest.NewRequest("GET", "/search", nil) rec := httptest.NewRecorder() h.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("request %d: expected 200 with disabled limiter, got %d", i+1, rec.Code) } } } func TestGlobalRateLimit_LockFree(t *testing.T) { // Verify the atomic counter doesn't panic under concurrent access. h := GlobalRateLimit(GlobalRateLimitConfig{ Requests: 10000, Window: 10 * time.Second, }, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) var allowed, blocked atomic.Int64 done := make(chan struct{}) for i := 0; i < 100; i++ { go func() { for j := 0; j < 200; j++ { req := httptest.NewRequest("GET", "/", nil) rec := httptest.NewRecorder() h.ServeHTTP(rec, req) if rec.Code == http.StatusOK { allowed.Add(1) } else { blocked.Add(1) } } done <- struct{}{} }() } for i := 0; i < 100; i++ { <-done } total := allowed.Load() + blocked.Load() if total != 20000 { t.Errorf("expected 20000 total requests, got %d", total) } if allowed.Load() != 10000 { t.Errorf("expected exactly 10000 allowed, got %d", allowed.Load()) } t.Logf("concurrent test: %d allowed, %d blocked", allowed.Load(), blocked.Load()) }