package middleware import ( "net/http" "net/http/httptest" "testing" "time" ) func TestRateLimit_AllowsUnderLimit(t *testing.T) { h := RateLimit(RateLimitConfig{ Requests: 5, Window: 10 * time.Second, }, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) for i := 0; i < 5; i++ { req := httptest.NewRequest("GET", "/search?q=test", nil) req.RemoteAddr = "1.2.3.4:1234" rec := httptest.NewRecorder() h.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("request %d: expected 200, got %d", i+1, rec.Code) } } } func TestRateLimit_BlocksOverLimit(t *testing.T) { h := RateLimit(RateLimitConfig{ Requests: 3, Window: 10 * time.Second, }, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) for i := 0; i < 3; i++ { req := httptest.NewRequest("GET", "/search?q=test", nil) req.RemoteAddr = "1.2.3.4:1234" rec := httptest.NewRecorder() h.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("request %d: expected 200, got %d", i+1, rec.Code) } } // 4th request should be blocked. req := httptest.NewRequest("GET", "/search?q=test", nil) req.RemoteAddr = "1.2.3.4:1234" rec := httptest.NewRecorder() h.ServeHTTP(rec, req) if rec.Code != http.StatusTooManyRequests { t.Errorf("expected 429, got %d", rec.Code) } } func TestRateLimit_DifferentIPs(t *testing.T) { h := RateLimit(RateLimitConfig{ Requests: 1, Window: 10 * time.Second, }, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) // IP A: allowed req := httptest.NewRequest("GET", "/search", nil) req.RemoteAddr = "1.1.1.1:1234" rec := httptest.NewRecorder() h.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Errorf("IP A first request: expected 200, got %d", rec.Code) } // IP A: blocked req = httptest.NewRequest("GET", "/search", nil) req.RemoteAddr = "1.1.1.1:1234" rec = httptest.NewRecorder() h.ServeHTTP(rec, req) if rec.Code != http.StatusTooManyRequests { t.Errorf("IP A second request: expected 429, got %d", rec.Code) } // IP B: allowed (separate bucket) req = httptest.NewRequest("GET", "/search", nil) req.RemoteAddr = "2.2.2.2:1234" rec = httptest.NewRecorder() h.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Errorf("IP B first request: expected 200, got %d", rec.Code) } } func TestRateLimit_XForwardedFor(t *testing.T) { h := RateLimit(RateLimitConfig{ Requests: 1, Window: 10 * time.Second, }, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) // Request via proxy — should use X-Forwarded-For. req := httptest.NewRequest("GET", "/search", nil) req.RemoteAddr = "10.0.0.1:1234" req.Header.Set("X-Forwarded-For", "203.0.113.50, 10.0.0.1") rec := httptest.NewRecorder() h.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Errorf("first XFF request: expected 200, got %d", rec.Code) } // Different proxy, same client IP — should be blocked. req = httptest.NewRequest("GET", "/search", nil) req.RemoteAddr = "10.0.0.2:1234" req.Header.Set("X-Forwarded-For", "203.0.113.50, 10.0.0.2") rec = httptest.NewRecorder() h.ServeHTTP(rec, req) if rec.Code != http.StatusTooManyRequests { t.Errorf("same XFF client: expected 429, got %d", rec.Code) } } func TestRateLimit_WindowExpires(t *testing.T) { limiter := &ipLimiter{ requests: 1, window: 50 * time.Millisecond, clients: make(map[string]*bucket), } if !limiter.allow("1.1.1.1") { t.Error("first request should be allowed") } if limiter.allow("1.1.1.1") { t.Error("second request should be blocked") } // Wait for window to expire. time.Sleep(60 * time.Millisecond) if !limiter.allow("1.1.1.1") { t.Error("request after window expiry should be allowed") } } func TestExtractIP(t *testing.T) { tests := []struct { name string xff string realIP string remote string expected string }{ {"xff", "203.0.113.50, 10.0.0.1", "", "10.0.0.1:1234", "203.0.113.50"}, {"real_ip", "", "203.0.113.50", "10.0.0.1:1234", "203.0.113.50"}, {"remote", "", "", "1.2.3.4:5678", "1.2.3.4"}, {"xff_over_real", "203.0.113.50", "10.0.0.1", "10.0.0.1:1234", "203.0.113.50"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) if tt.xff != "" { req.Header.Set("X-Forwarded-For", tt.xff) } if tt.realIP != "" { req.Header.Set("X-Real-IP", tt.realIP) } req.RemoteAddr = tt.remote if got := extractIP(req); got != tt.expected { t.Errorf("extractIP() = %q, want %q", got, tt.expected) } }) } }