package middleware

import (
	"net"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"
)

// newRateLimitTestHandler wraps a handler that always replies 200 OK in the
// RateLimit middleware built from cfg. The second RateLimit argument is nil,
// as it is for every handler-level test in this file.
func newRateLimitTestHandler(cfg RateLimitConfig) http.Handler {
	ok := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	return RateLimit(cfg, nil)(ok)
}

// rateLimitTestStatus sends one GET request for target through h from
// remoteAddr and returns the recorded status code. X-Forwarded-For is set to
// xff only when xff is non-empty.
func rateLimitTestStatus(h http.Handler, target, remoteAddr, xff string) int {
	req := httptest.NewRequest(http.MethodGet, target, nil)
	req.RemoteAddr = remoteAddr
	if xff != "" {
		req.Header.Set("X-Forwarded-For", xff)
	}
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	return rec.Code
}

// TestRateLimit_AllowsUnderLimit checks that one client can make Requests
// calls within a single window without being rejected.
func TestRateLimit_AllowsUnderLimit(t *testing.T) {
	h := newRateLimitTestHandler(RateLimitConfig{Requests: 5, Window: 10 * time.Second})

	for n := 1; n <= 5; n++ {
		if code := rateLimitTestStatus(h, "/search?q=test", "1.2.3.4:1234", ""); code != http.StatusOK {
			t.Fatalf("request %d: expected 200, got %d", n, code)
		}
	}
}

// TestRateLimit_BlocksOverLimit checks that the request immediately after the
// limit is exhausted receives 429 Too Many Requests.
func TestRateLimit_BlocksOverLimit(t *testing.T) {
	const limit = 3
	h := newRateLimitTestHandler(RateLimitConfig{Requests: limit, Window: 10 * time.Second})

	for n := 1; n <= limit; n++ {
		if code := rateLimitTestStatus(h, "/search?q=test", "1.2.3.4:1234", ""); code != http.StatusOK {
			t.Fatalf("request %d: expected 200, got %d", n, code)
		}
	}

	// One request past the limit must be rejected.
	if code := rateLimitTestStatus(h, "/search?q=test", "1.2.3.4:1234", ""); code != http.StatusTooManyRequests {
		t.Errorf("expected 429, got %d", code)
	}
}

// TestRateLimit_DifferentIPs checks that exhausting one client's allowance
// does not affect a client with a different remote address.
func TestRateLimit_DifferentIPs(t *testing.T) {
	h := newRateLimitTestHandler(RateLimitConfig{Requests: 1, Window: 10 * time.Second})

	// Steps run in order against the same handler; later steps depend on the
	// state left by earlier ones.
	steps := []struct {
		label  string
		remote string
		want   int
	}{
		{"IP A first request", "1.1.1.1:1234", http.StatusOK},
		{"IP A second request", "1.1.1.1:1234", http.StatusTooManyRequests},
		// A different address gets its own bucket.
		{"IP B first request", "2.2.2.2:1234", http.StatusOK},
	}
	for _, s := range steps {
		if got := rateLimitTestStatus(h, "/search", s.remote, ""); got != s.want {
			t.Errorf("%s: expected %d, got %d", s.label, s.want, got)
		}
	}
}

// TestRateLimit_XForwardedFor checks that, behind a trusted proxy range, the
// limiter keys on the client named in X-Forwarded-For rather than on the
// proxy's own address.
func TestRateLimit_XForwardedFor(t *testing.T) {
	h := newRateLimitTestHandler(RateLimitConfig{
		Requests:       1,
		Window:         10 * time.Second,
		TrustedProxies: []string{"10.0.0.0/8"},
	})

	steps := []struct {
		label  string
		remote string
		xff    string
		want   int
	}{
		// Arrives via a trusted proxy, so X-Forwarded-For should be consulted.
		{"first XFF request", "10.0.0.1:1234", "203.0.113.50, 10.0.0.1", http.StatusOK},
		// Another proxy relaying the same client must hit that client's bucket.
		{"same XFF client", "10.0.0.2:1234", "203.0.113.50, 10.0.0.2", http.StatusTooManyRequests},
	}
	for _, s := range steps {
		if got := rateLimitTestStatus(h, "/search", s.remote, s.xff); got != s.want {
			t.Errorf("%s: expected %d, got %d", s.label, s.want, got)
		}
	}
}

// TestRateLimit_WindowExpires drives ipLimiter directly with a 50ms window so
// that expiry can be observed with a short real sleep.
//
// NOTE(review): this is wall-clock based. The second allow call is expected to
// land inside the 50ms window, which a heavily loaded runner could miss.
func TestRateLimit_WindowExpires(t *testing.T) {
	const client = "1.1.1.1"
	limiter := &ipLimiter{
		requests: 1,
		window:   50 * time.Millisecond,
		clients:  make(map[string]*bucket),
	}

	if !limiter.allow(client) {
		t.Error("first request should be allowed")
	}
	if limiter.allow(client) {
		t.Error("second request should be blocked")
	}

	// Sleep past the window so the earlier request no longer counts.
	time.Sleep(60 * time.Millisecond)
	if !limiter.allow(client) {
		t.Error("request after window expiry should be allowed")
	}
}

// TestExtractIP covers how extractIP chooses between RemoteAddr,
// X-Forwarded-For and X-Real-IP depending on the trusted proxy networks.
func TestExtractIP(t *testing.T) {
	loopback := mustParseCIDR("127.0.0.0/8")
	privateNet := mustParseCIDR("10.0.0.0/8")

	cases := []struct {
		name    string
		xff     string
		realIP  string
		remote  string
		trusted []*net.IPNet
		want    string
	}{
		// No trusted proxies: the host part of RemoteAddr is always used.
		{name: "no_trusted_xff", xff: "203.0.113.50, 10.0.0.1", remote: "10.0.0.1:1234", want: "10.0.0.1"},
		{name: "no_trusted_real", realIP: "203.0.113.50", remote: "10.0.0.1:1234", want: "10.0.0.1"},
		{name: "no_trusted_remote", remote: "1.2.3.4:5678", want: "1.2.3.4"},

		// RemoteAddr inside a trusted network: forwarding headers are honored.
		{name: "trusted_xff", xff: "203.0.113.50, 10.0.0.1", remote: "10.0.0.1:1234", trusted: []*net.IPNet{privateNet}, want: "203.0.113.50"},
		{name: "trusted_real_ip", realIP: "203.0.113.50", remote: "10.0.0.1:1234", trusted: []*net.IPNet{privateNet}, want: "203.0.113.50"},

		// RemoteAddr outside every trusted network: X-Forwarded-For is ignored.
		{name: "untrusted_xff", xff: "203.0.113.50, 10.0.0.1", remote: "1.2.3.4:5678", trusted: []*net.IPNet{loopback}, want: "1.2.3.4"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			req.RemoteAddr = tc.remote
			if tc.xff != "" {
				req.Header.Set("X-Forwarded-For", tc.xff)
			}
			if tc.realIP != "" {
				req.Header.Set("X-Real-IP", tc.realIP)
			}
			if got := extractIP(req, tc.trusted...); got != tc.want {
				t.Errorf("extractIP() = %q, want %q", got, tc.want)
			}
		})
	}
}

// mustParseCIDR parses s in CIDR notation and returns its network. It panics
// on malformed input, so it is only suitable for fixed test fixtures.
func mustParseCIDR(s string) *net.IPNet {
	_, ipNet, err := net.ParseCIDR(s)
	if err != nil {
		panic(err)
	}
	return ipNet
}