package middleware

import (
	"net/http"
	"net/http/httptest"
	"testing"
	"time"
)

// newBurstTestHandler wraps a downstream handler that always answers 200 OK
// in the BurstRateLimit middleware configured by cfg, so any non-200 status a
// test observes was produced by the limiter itself.
func newBurstTestHandler(cfg BurstRateLimitConfig) http.Handler {
	return BurstRateLimit(cfg, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
}

// sendBurstTestRequest issues a single GET /search request that appears to
// come from remoteAddr, passes it through h, and returns the recorded response.
func sendBurstTestRequest(h http.Handler, remoteAddr string) *httptest.ResponseRecorder {
	req := httptest.NewRequest(http.MethodGet, "/search", nil)
	req.RemoteAddr = remoteAddr
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	return rec
}

// expectBurstTestOK sends n requests from remoteAddr through h and stops the
// test at the first one that is not answered with 200 OK.
func expectBurstTestOK(t *testing.T, h http.Handler, remoteAddr string, n int) {
	t.Helper()
	for i := 0; i < n; i++ {
		if rec := sendBurstTestRequest(h, remoteAddr); rec.Code != http.StatusOK {
			t.Fatalf("request %d: expected 200, got %d", i+1, rec.Code)
		}
	}
}

// expectBurstTestRejected checks that rec is a 429 response whose
// X-RateLimit-Reason header equals wantReason.
func expectBurstTestRejected(t *testing.T, rec *httptest.ResponseRecorder, wantReason string) {
	t.Helper()
	if rec.Code != http.StatusTooManyRequests {
		t.Errorf("expected 429, got %d", rec.Code)
	}
	if reason := rec.Header().Get("X-RateLimit-Reason"); reason != wantReason {
		t.Errorf("expected reason %q, got %q", wantReason, reason)
	}
}

func TestBurstRateLimit_AllowsUnderBurst(t *testing.T) {
	h := newBurstTestHandler(BurstRateLimitConfig{
		Burst:       5,
		BurstWindow: 10 * time.Second,
	})
	expectBurstTestOK(t, h, "1.1.1.1:1234", 5)
}

func TestBurstRateLimit_BlocksBurst(t *testing.T) {
	h := newBurstTestHandler(BurstRateLimitConfig{
		Burst:       3,
		BurstWindow: 10 * time.Second,
	})
	expectBurstTestOK(t, h, "1.1.1.1:1234", 3)

	// The first request beyond the burst must be rejected with the burst reason.
	expectBurstTestRejected(t, sendBurstTestRequest(h, "1.1.1.1:1234"), "burst limit exceeded")
}

func TestBurstRateLimit_Sustained(t *testing.T) {
	// Burst is set far above Sustained so only the sustained limit can trip.
	h := newBurstTestHandler(BurstRateLimitConfig{
		Burst:           100,
		BurstWindow:     10 * time.Second,
		Sustained:       5,
		SustainedWindow: 10 * time.Second,
	})
	expectBurstTestOK(t, h, "1.1.1.1:1234", 5)
	expectBurstTestRejected(t, sendBurstTestRequest(h, "1.1.1.1:1234"), "sustained limit exceeded")
}

func TestBurstRateLimit_BurstWindowExpires(t *testing.T) {
	// Drive the limiter directly so the burst window can be short enough to
	// elapse inside the test.
	limiter := &burstLimiter{
		burst:           2,
		burstWindow:     50 * time.Millisecond,
		sustained:       100,
		sustainedWindow: 10 * time.Second,
		clients:         make(map[string]*burstBucket),
	}
	const ip = "1.1.1.1"

	// Use 2 requests (burst limit).
	for i := 1; i <= 2; i++ {
		if reason := limiter.allow(ip); reason != "" {
			t.Fatalf("request %d blocked: %s", i, reason)
		}
	}
	if reason := limiter.allow(ip); reason != "burst limit exceeded" {
		t.Errorf("expected burst block, got %q", reason)
	}

	// Wait for burst window to expire.
	time.Sleep(60 * time.Millisecond)

	if reason := limiter.allow(ip); reason != "" {
		t.Errorf("after burst expiry, should be allowed, got %q", reason)
	}
}

func TestBurstRateLimit_Disabled(t *testing.T) {
	// Burst: 0 means the limiter is disabled, so a long run of requests from
	// one client must all pass.
	h := newBurstTestHandler(BurstRateLimitConfig{Burst: 0})
	for i := 0; i < 200; i++ {
		if rec := sendBurstTestRequest(h, "1.1.1.1:1234"); rec.Code != http.StatusOK {
			t.Fatalf("request %d: expected 200 with disabled limiter, got %d", i+1, rec.Code)
		}
	}
}

func TestBurstRateLimit_DifferentIPs(t *testing.T) {
	h := newBurstTestHandler(BurstRateLimitConfig{
		Burst:       1,
		BurstWindow: 10 * time.Second,
	})

	// IP A blocked after 1.
	if rec := sendBurstTestRequest(h, "1.1.1.1:1234"); rec.Code != http.StatusOK {
		t.Errorf("IP A first: expected 200, got %d", rec.Code)
	}
	if rec := sendBurstTestRequest(h, "1.1.1.1:1234"); rec.Code != http.StatusTooManyRequests {
		t.Errorf("IP A second: expected 429, got %d", rec.Code)
	}

	// IP B still allowed: exhausting A's budget must not affect another client.
	if rec := sendBurstTestRequest(h, "2.2.2.2:1234"); rec.Code != http.StatusOK {
		t.Errorf("IP B first: expected 200, got %d", rec.Code)
	}
}