kafka/internal/middleware/ratelimit_global_test.go
Franz Kafka 13040268d6 feat: add global and burst rate limiters
Three layers of rate limiting, all disabled by default, opt-in via config:

1. Per-IP (existing): 30 req/min per IP
2. Global: server-wide limit across all IPs
   - Lock-free atomic counter for minimal overhead
   - Returns 503 when exceeded
   - Prevents pool exhaustion from distributed attacks
3. Burst: per-IP burst + sustained windows
   - Blocks rapid-fire abuse within seconds
   - Returns 429 with X-RateLimit-Reason header
   - Example: 5 req/5s burst, 60 req/min sustained

Config:
[global_rate_limit]
requests = 0  # disabled by default
window = "1m"

[burst_rate_limit]
burst = 0  # disabled by default
burst_window = "5s"
sustained = 0
sustained_window = "1m"

Env overrides: GLOBAL_RATE_LIMIT_REQUESTS, GLOBAL_RATE_LIMIT_WINDOW,
BURST_RATE_LIMIT_BURST, BURST_RATE_LIMIT_BURST_WINDOW,
BURST_RATE_LIMIT_SUSTAINED, BURST_RATE_LIMIT_SUSTAINED_WINDOW

Full test coverage: concurrent lock-free test, window expiry, disabled states,
IP isolation, burst vs sustained distinction.
2026-03-21 18:35:31 +00:00

140 lines
3.4 KiB
Go

package middleware
import (
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
)
func TestGlobalRateLimit_AllowsUnderLimit(t *testing.T) {
h := GlobalRateLimit(GlobalRateLimitConfig{
Requests: 100,
Window: 10 * time.Second,
}, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
for i := 0; i < 100; i++ {
req := httptest.NewRequest("GET", "/search", nil)
rec := httptest.NewRecorder()
h.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("request %d: expected 200, got %d", i+1, rec.Code)
}
}
}
// TestGlobalRateLimit_BlocksOverLimit exhausts a 5-request budget within one
// 10s window and checks that the next request is rejected with 503
// Service Unavailable instead of reaching the wrapped handler.
func TestGlobalRateLimit_BlocksOverLimit(t *testing.T) {
	const limit = 5
	handler := GlobalRateLimit(GlobalRateLimitConfig{
		Requests: limit,
		Window:   10 * time.Second,
	}, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	// serve issues one GET /search and reports the resulting status code.
	serve := func() int {
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, httptest.NewRequest("GET", "/search", nil))
		return rec.Code
	}

	// Consume the whole budget; each of these must be admitted.
	for i := 0; i < limit; i++ {
		if code := serve(); code != http.StatusOK {
			t.Fatalf("request %d: expected 200, got %d", i+1, code)
		}
	}

	// One request past the budget must be refused.
	if code := serve(); code != http.StatusServiceUnavailable {
		t.Errorf("expected 503, got %d", code)
	}
}
// TestGlobalRateLimit_WindowResets drives a globalLimiter directly. It fills
// a 3-request budget, confirms the 4th request is refused, then zeroes the
// counter to stand in for a window rollover and checks admission resumes.
//
// NOTE(review): the rollover is simulated with count.Swap(0) instead of
// waiting for the 50ms window to elapse, so this test does not exercise
// whatever code performs the reset in production.
func TestGlobalRateLimit_WindowResets(t *testing.T) {
	lim := &globalLimiter{
		requests: 3,
		window:   50 * time.Millisecond,
	}

	// The first three calls fit inside the budget, in this order.
	for _, msg := range []string{
		"first request should be allowed",
		"second request should be allowed",
		"third request should be allowed",
	} {
		if !lim.allow() {
			t.Error(msg)
		}
	}
	if lim.allow() {
		t.Error("fourth request should be blocked")
	}

	// Simulate window reset; Swap also reports the count it replaced.
	if prev := lim.count.Swap(0); prev != 3 {
		t.Errorf("expected count 3, got %d", prev)
	}
	if !lim.allow() {
		t.Error("after reset, request should be allowed")
	}
}
// TestGlobalRateLimit_Disabled checks that a config with Requests == 0 turns
// the global limiter off: 200 back-to-back requests must all get a 200.
func TestGlobalRateLimit_Disabled(t *testing.T) {
	next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	handler := GlobalRateLimit(GlobalRateLimitConfig{Requests: 0}, nil)(next)

	// Should allow unlimited requests.
	for sent := 1; sent <= 200; sent++ {
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, httptest.NewRequest("GET", "/search", nil))
		if rec.Code == http.StatusOK {
			continue
		}
		t.Fatalf("request %d: expected 200 with disabled limiter, got %d", sent, rec.Code)
	}
}
// TestGlobalRateLimit_LockFree hammers the global limiter from many goroutines
// at once and checks that the shared counter admits exactly the configured
// budget:
//   - every request is accounted for;
//   - exactly `limit` are admitted;
//   - every response is either 200 (admitted) or 503 (rejected).
//
// Run with -race to additionally check the limiter for data races.
func TestGlobalRateLimit_LockFree(t *testing.T) {
	const (
		limit     = 10000
		workers   = 100
		perWorker = 200
		wantTotal = workers * perWorker
	)
	// The window should comfortably outlast the test run; a rollover mid-run
	// could legitimately admit more than `limit` requests.
	h := GlobalRateLimit(GlobalRateLimitConfig{
		Requests: limit,
		Window:   10 * time.Second,
	}, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	var allowed, blocked, unexpected atomic.Int64
	done := make(chan struct{})
	for w := 0; w < workers; w++ {
		go func() {
			for j := 0; j < perWorker; j++ {
				req := httptest.NewRequest("GET", "/", nil)
				rec := httptest.NewRecorder()
				h.ServeHTTP(rec, req)
				// Only 503 counts as a rate-limit rejection; any other status
				// (e.g. a 500) is a bug and must not be hidden as "blocked".
				switch rec.Code {
				case http.StatusOK:
					allowed.Add(1)
				case http.StatusServiceUnavailable:
					blocked.Add(1)
				default:
					unexpected.Add(1)
				}
			}
			done <- struct{}{}
		}()
	}
	for w := 0; w < workers; w++ {
		<-done
	}

	// All workers have finished, so a single snapshot of each counter is final.
	nAllowed, nBlocked, nUnexpected := allowed.Load(), blocked.Load(), unexpected.Load()
	if total := nAllowed + nBlocked + nUnexpected; total != wantTotal {
		t.Errorf("expected %d total requests, got %d", wantTotal, total)
	}
	if nUnexpected != 0 {
		t.Errorf("expected only 200 or 503 responses, got %d with another status", nUnexpected)
	}
	if nAllowed != limit {
		t.Errorf("expected exactly %d allowed, got %d", limit, nAllowed)
	}
	t.Logf("concurrent test: %d allowed, %d blocked", nAllowed, nBlocked)
}