Three layers of rate limiting, all disabled by default and enabled (opt-in) via config:

1. Per-IP (existing): 30 req/min per IP.
2. Global: a server-wide limit across all IPs.
   - Lock-free atomic counter for minimal overhead.
   - Returns 503 when exceeded.
   - Prevents pool exhaustion from distributed attacks.
3. Burst: per-IP burst + sustained windows.
   - Blocks rapid-fire abuse within seconds.
   - Returns 429 with an X-RateLimit-Reason header.
   - Example: 5 req/5s burst, 60 req/min sustained.

Config:

    [global_rate_limit]
    requests = 0          # disabled by default
    window = "1m"

    [burst_rate_limit]
    burst = 0             # disabled by default
    burst_window = "5s"
    sustained = 0
    sustained_window = "1m"

Env overrides: GLOBAL_RATE_LIMIT_REQUESTS, GLOBAL_RATE_LIMIT_WINDOW, BURST_RATE_LIMIT_BURST, BURST_RATE_LIMIT_BURST_WINDOW, BURST_RATE_LIMIT_SUSTAINED, BURST_RATE_LIMIT_SUSTAINED_WINDOW.

Full test coverage: concurrent lock-free test, window expiry, disabled states, IP isolation, and the burst vs. sustained distinction.
170 lines
4.7 KiB
Go
170 lines
4.7 KiB
Go
package middleware
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestBurstRateLimit_AllowsUnderBurst(t *testing.T) {
|
|
h := BurstRateLimit(BurstRateLimitConfig{
|
|
Burst: 5,
|
|
BurstWindow: 10 * time.Second,
|
|
}, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
for i := 0; i < 5; i++ {
|
|
req := httptest.NewRequest("GET", "/search", nil)
|
|
req.RemoteAddr = "1.1.1.1:1234"
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, req)
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("request %d: expected 200, got %d", i+1, rec.Code)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestBurstRateLimit_BlocksBurst(t *testing.T) {
|
|
h := BurstRateLimit(BurstRateLimitConfig{
|
|
Burst: 3,
|
|
BurstWindow: 10 * time.Second,
|
|
}, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
for i := 0; i < 3; i++ {
|
|
req := httptest.NewRequest("GET", "/search", nil)
|
|
req.RemoteAddr = "1.1.1.1:1234"
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, req)
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("request %d: expected 200, got %d", i+1, rec.Code)
|
|
}
|
|
}
|
|
|
|
req := httptest.NewRequest("GET", "/search", nil)
|
|
req.RemoteAddr = "1.1.1.1:1234"
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, req)
|
|
if rec.Code != http.StatusTooManyRequests {
|
|
t.Errorf("expected 429, got %d", rec.Code)
|
|
}
|
|
reason := rec.Header().Get("X-RateLimit-Reason")
|
|
if reason != "burst limit exceeded" {
|
|
t.Errorf("expected burst reason, got %q", reason)
|
|
}
|
|
}
|
|
|
|
func TestBurstRateLimit_Sustained(t *testing.T) {
|
|
h := BurstRateLimit(BurstRateLimitConfig{
|
|
Burst: 100,
|
|
BurstWindow: 10 * time.Second,
|
|
Sustained: 5,
|
|
SustainedWindow: 10 * time.Second,
|
|
}, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
for i := 0; i < 5; i++ {
|
|
req := httptest.NewRequest("GET", "/search", nil)
|
|
req.RemoteAddr = "1.1.1.1:1234"
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, req)
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("request %d: expected 200, got %d", i+1, rec.Code)
|
|
}
|
|
}
|
|
|
|
req := httptest.NewRequest("GET", "/search", nil)
|
|
req.RemoteAddr = "1.1.1.1:1234"
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, req)
|
|
if rec.Code != http.StatusTooManyRequests {
|
|
t.Errorf("expected 429, got %d", rec.Code)
|
|
}
|
|
reason := rec.Header().Get("X-RateLimit-Reason")
|
|
if reason != "sustained limit exceeded" {
|
|
t.Errorf("expected sustained reason, got %q", reason)
|
|
}
|
|
}
|
|
|
|
func TestBurstRateLimit_BurstWindowExpires(t *testing.T) {
|
|
limiter := &burstLimiter{
|
|
burst: 2,
|
|
burstWindow: 50 * time.Millisecond,
|
|
sustained: 100,
|
|
sustainedWindow: 10 * time.Second,
|
|
clients: make(map[string]*burstBucket),
|
|
}
|
|
|
|
// Use 2 requests (burst limit).
|
|
if reason := limiter.allow("1.1.1.1"); reason != "" {
|
|
t.Fatalf("request 1 blocked: %s", reason)
|
|
}
|
|
if reason := limiter.allow("1.1.1.1"); reason != "" {
|
|
t.Fatalf("request 2 blocked: %s", reason)
|
|
}
|
|
if reason := limiter.allow("1.1.1.1"); reason != "burst limit exceeded" {
|
|
t.Errorf("expected burst block, got %q", reason)
|
|
}
|
|
|
|
// Wait for burst window to expire.
|
|
time.Sleep(60 * time.Millisecond)
|
|
|
|
if reason := limiter.allow("1.1.1.1"); reason != "" {
|
|
t.Errorf("after burst expiry, should be allowed, got %q", reason)
|
|
}
|
|
}
|
|
|
|
func TestBurstRateLimit_Disabled(t *testing.T) {
|
|
h := BurstRateLimit(BurstRateLimitConfig{Burst: 0}, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
for i := 0; i < 200; i++ {
|
|
req := httptest.NewRequest("GET", "/search", nil)
|
|
req.RemoteAddr = "1.1.1.1:1234"
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, req)
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("request %d: expected 200 with disabled limiter, got %d", i+1, rec.Code)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestBurstRateLimit_DifferentIPs(t *testing.T) {
|
|
h := BurstRateLimit(BurstRateLimitConfig{
|
|
Burst: 1,
|
|
BurstWindow: 10 * time.Second,
|
|
}, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
// IP A blocked after 1.
|
|
req := httptest.NewRequest("GET", "/search", nil)
|
|
req.RemoteAddr = "1.1.1.1:1234"
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, req)
|
|
if rec.Code != http.StatusOK {
|
|
t.Error("IP A first: expected 200")
|
|
}
|
|
|
|
req = httptest.NewRequest("GET", "/search", nil)
|
|
req.RemoteAddr = "1.1.1.1:1234"
|
|
rec = httptest.NewRecorder()
|
|
h.ServeHTTP(rec, req)
|
|
if rec.Code != http.StatusTooManyRequests {
|
|
t.Error("IP A second: expected 429")
|
|
}
|
|
|
|
// IP B still allowed.
|
|
req = httptest.NewRequest("GET", "/search", nil)
|
|
req.RemoteAddr = "2.2.2.2:1234"
|
|
rec = httptest.NewRecorder()
|
|
h.ServeHTTP(rec, req)
|
|
if rec.Code != http.StatusOK {
|
|
t.Error("IP B first: expected 200")
|
|
}
|
|
}
|