feat: add global and burst rate limiters
Three layers of rate limiting, all disabled by default, opt-in via config: 1. Per-IP (existing): 30 req/min per IP 2. Global: server-wide limit across all IPs - Lock-free atomic counter for minimal overhead - Returns 503 when exceeded - Prevents pool exhaustion from distributed attacks 3. Burst: per-IP burst + sustained windows - Blocks rapid-fire abuse within seconds - Returns 429 with X-RateLimit-Reason header - Example: 5 req/5s burst, 60 req/min sustained Config: [global_rate_limit] requests = 0 # disabled by default window = "1m" [burst_rate_limit] burst = 0 # disabled by default burst_window = "5s" sustained = 0 sustained_window = "1m" Env overrides: GLOBAL_RATE_LIMIT_REQUESTS, GLOBAL_RATE_LIMIT_WINDOW, BURST_RATE_LIMIT_BURST, BURST_RATE_LIMIT_BURST_WINDOW, BURST_RATE_LIMIT_SUSTAINED, BURST_RATE_LIMIT_SUSTAINED_WINDOW Full test coverage: concurrent lock-free test, window expiry, disabled states, IP isolation, burst vs sustained distinction.
This commit is contained in:
parent
91ab76758c
commit
13040268d6
7 changed files with 657 additions and 18 deletions
|
|
@ -73,7 +73,7 @@ func main() {
|
||||||
var subFS fs.FS = staticFS
|
var subFS fs.FS = staticFS
|
||||||
mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.FS(subFS))))
|
mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.FS(subFS))))
|
||||||
|
|
||||||
// Apply middleware: rate limiter → CORS → handler.
|
// Apply middleware: global rate limit → burst rate limit → per-IP rate limit → CORS → handler.
|
||||||
var handler http.Handler = mux
|
var handler http.Handler = mux
|
||||||
handler = middleware.CORS(middleware.CORSConfig{
|
handler = middleware.CORS(middleware.CORSConfig{
|
||||||
AllowedOrigins: cfg.CORS.AllowedOrigins,
|
AllowedOrigins: cfg.CORS.AllowedOrigins,
|
||||||
|
|
@ -87,6 +87,16 @@ func main() {
|
||||||
Window: cfg.RateLimitWindow(),
|
Window: cfg.RateLimitWindow(),
|
||||||
CleanupInterval: cfg.RateLimitCleanupInterval(),
|
CleanupInterval: cfg.RateLimitCleanupInterval(),
|
||||||
}, logger)(handler)
|
}, logger)(handler)
|
||||||
|
handler = middleware.GlobalRateLimit(middleware.GlobalRateLimitConfig{
|
||||||
|
Requests: cfg.GlobalRateLimit.Requests,
|
||||||
|
Window: cfg.GlobalRateLimitWindow(),
|
||||||
|
}, logger)(handler)
|
||||||
|
handler = middleware.BurstRateLimit(middleware.BurstRateLimitConfig{
|
||||||
|
Burst: cfg.BurstRateLimit.Burst,
|
||||||
|
BurstWindow: cfg.BurstWindow(),
|
||||||
|
Sustained: cfg.BurstRateLimit.Sustained,
|
||||||
|
SustainedWindow: cfg.SustainedWindow(),
|
||||||
|
}, logger)(handler)
|
||||||
|
|
||||||
addr := fmt.Sprintf(":%d", cfg.Server.Port)
|
addr := fmt.Sprintf(":%d", cfg.Server.Port)
|
||||||
logger.Info("searxng-go starting",
|
logger.Info("searxng-go starting",
|
||||||
|
|
|
||||||
|
|
@ -66,3 +66,28 @@ requests = 30
|
||||||
window = "1m"
|
window = "1m"
|
||||||
# How often to clean up stale IP entries (env: RATE_LIMIT_CLEANUP_INTERVAL)
|
# How often to clean up stale IP entries (env: RATE_LIMIT_CLEANUP_INTERVAL)
|
||||||
cleanup_interval = "5m"
|
cleanup_interval = "5m"
|
||||||
|
|
||||||
|
[global_rate_limit]
|
||||||
|
# Server-wide rate limit across ALL IPs. Prevents pool exhaustion from
|
||||||
|
# distributed attacks even when per-IP limits are bypassed via VPNs.
|
||||||
|
# Returns 503 when exceeded. Set to 0 to disable.
|
||||||
|
# Env: GLOBAL_RATE_LIMIT_REQUESTS
|
||||||
|
requests = 0
|
||||||
|
# Env: GLOBAL_RATE_LIMIT_WINDOW
|
||||||
|
window = "1m"
|
||||||
|
|
||||||
|
[burst_rate_limit]
|
||||||
|
# Per-IP burst + sustained rate limiting. More aggressive than the standard
|
||||||
|
# per-IP limiter. Blocks rapid-fire abuse even if the per-minute limit isn't hit.
|
||||||
|
# Returns 429 with X-RateLimit-Reason header. Set burst to 0 to disable.
|
||||||
|
#
|
||||||
|
# Example: burst=5, burst_window="5s" means max 5 requests in any 5-second span.
|
||||||
|
# sustained=60, sustained_window="1m" means max 60 requests per minute.
|
||||||
|
# Env: BURST_RATE_LIMIT_BURST
|
||||||
|
burst = 0
|
||||||
|
# Env: BURST_RATE_LIMIT_BURST_WINDOW
|
||||||
|
burst_window = "5s"
|
||||||
|
# Env: BURST_RATE_LIMIT_SUSTAINED
|
||||||
|
sustained = 0
|
||||||
|
# Env: BURST_RATE_LIMIT_SUSTAINED_WINDOW
|
||||||
|
sustained_window = "1m"
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,8 @@ type Config struct {
|
||||||
Cache CacheConfig `toml:"cache"`
|
Cache CacheConfig `toml:"cache"`
|
||||||
CORS CORSConfig `toml:"cors"`
|
CORS CORSConfig `toml:"cors"`
|
||||||
RateLimit RateLimitConfig `toml:"rate_limit"`
|
RateLimit RateLimitConfig `toml:"rate_limit"`
|
||||||
|
GlobalRateLimit GlobalRateLimitConfig `toml:"global_rate_limit"`
|
||||||
|
BurstRateLimit BurstRateLimitConfig `toml:"burst_rate_limit"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ServerConfig struct {
|
type ServerConfig struct {
|
||||||
|
|
@ -59,6 +61,20 @@ type RateLimitConfig struct {
|
||||||
CleanupInterval string `toml:"cleanup_interval"` // Stale entry cleanup interval (default: "5m")
|
CleanupInterval string `toml:"cleanup_interval"` // Stale entry cleanup interval (default: "5m")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GlobalRateLimitConfig holds server-wide rate limiting settings.
|
||||||
|
type GlobalRateLimitConfig struct {
|
||||||
|
Requests int `toml:"requests"` // Max total requests per window across all IPs (0 = disabled)
|
||||||
|
Window string `toml:"window"` // Time window (e.g. "1m", default: "1m")
|
||||||
|
}
|
||||||
|
|
||||||
|
// BurstRateLimitConfig holds per-IP burst rate limiting settings.
|
||||||
|
type BurstRateLimitConfig struct {
|
||||||
|
Burst int `toml:"burst"` // Max requests in burst window (0 = disabled)
|
||||||
|
BurstWindow string `toml:"burst_window"` // Burst window (default: "5s")
|
||||||
|
Sustained int `toml:"sustained"` // Max requests in sustained window
|
||||||
|
SustainedWindow string `toml:"sustained_window"` // Sustained window (default: "1m")
|
||||||
|
}
|
||||||
|
|
||||||
type BraveConfig struct {
|
type BraveConfig struct {
|
||||||
APIKey string `toml:"api_key"`
|
APIKey string `toml:"api_key"`
|
||||||
AccessToken string `toml:"access_token"`
|
AccessToken string `toml:"access_token"`
|
||||||
|
|
@ -159,6 +175,24 @@ func applyEnvOverrides(cfg *Config) {
|
||||||
if v := os.Getenv("RATE_LIMIT_CLEANUP_INTERVAL"); v != "" {
|
if v := os.Getenv("RATE_LIMIT_CLEANUP_INTERVAL"); v != "" {
|
||||||
cfg.RateLimit.CleanupInterval = v
|
cfg.RateLimit.CleanupInterval = v
|
||||||
}
|
}
|
||||||
|
if v := os.Getenv("GLOBAL_RATE_LIMIT_REQUESTS"); v != "" {
|
||||||
|
fmt.Sscanf(v, "%d", &cfg.GlobalRateLimit.Requests)
|
||||||
|
}
|
||||||
|
if v := os.Getenv("GLOBAL_RATE_LIMIT_WINDOW"); v != "" {
|
||||||
|
cfg.GlobalRateLimit.Window = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("BURST_RATE_LIMIT_BURST"); v != "" {
|
||||||
|
fmt.Sscanf(v, "%d", &cfg.BurstRateLimit.Burst)
|
||||||
|
}
|
||||||
|
if v := os.Getenv("BURST_RATE_LIMIT_BURST_WINDOW"); v != "" {
|
||||||
|
cfg.BurstRateLimit.BurstWindow = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("BURST_RATE_LIMIT_SUSTAINED"); v != "" {
|
||||||
|
fmt.Sscanf(v, "%d", &cfg.BurstRateLimit.Sustained)
|
||||||
|
}
|
||||||
|
if v := os.Getenv("BURST_RATE_LIMIT_SUSTAINED_WINDOW"); v != "" {
|
||||||
|
cfg.BurstRateLimit.SustainedWindow = v
|
||||||
|
}
|
||||||
if v := os.Getenv("BASE_URL"); v != "" {
|
if v := os.Getenv("BASE_URL"); v != "" {
|
||||||
cfg.Server.BaseURL = v
|
cfg.Server.BaseURL = v
|
||||||
}
|
}
|
||||||
|
|
@ -201,6 +235,30 @@ func (c *Config) RateLimitCleanupInterval() time.Duration {
|
||||||
return 5 * time.Minute
|
return 5 * time.Minute
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GlobalRateLimitWindow parses the global rate limit window into a time.Duration.
|
||||||
|
func (c *Config) GlobalRateLimitWindow() time.Duration {
|
||||||
|
if d, err := time.ParseDuration(c.GlobalRateLimit.Window); err == nil && d > 0 {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
return time.Minute
|
||||||
|
}
|
||||||
|
|
||||||
|
// BurstWindow parses the burst window into a time.Duration.
|
||||||
|
func (c *Config) BurstWindow() time.Duration {
|
||||||
|
if d, err := time.ParseDuration(c.BurstRateLimit.BurstWindow); err == nil && d > 0 {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
return 5 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
// SustainedWindow parses the sustained window into a time.Duration.
|
||||||
|
func (c *Config) SustainedWindow() time.Duration {
|
||||||
|
if d, err := time.ParseDuration(c.BurstRateLimit.SustainedWindow); err == nil && d > 0 {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
return time.Minute
|
||||||
|
}
|
||||||
|
|
||||||
func splitCSV(s string) []string {
|
func splitCSV(s string) []string {
|
||||||
if s == "" {
|
if s == "" {
|
||||||
return nil
|
return nil
|
||||||
|
|
|
||||||
|
|
@ -51,7 +51,6 @@ func RateLimit(cfg RateLimitConfig, logger *slog.Logger) func(http.Handler) http
|
||||||
logger: logger,
|
logger: logger,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Background cleanup of stale buckets.
|
|
||||||
go limiter.cleanup(cleanup)
|
go limiter.cleanup(cleanup)
|
||||||
|
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
|
|
@ -122,14 +121,10 @@ func (l *ipLimiter) cleanup(interval time.Duration) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractIP(r *http.Request) string {
|
func extractIP(r *http.Request) string {
|
||||||
// Trust X-Forwarded-For / X-Real-IP if behind a proxy.
|
|
||||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||||
// First IP in the chain is the client.
|
|
||||||
if idx := len(xff); idx > 0 {
|
|
||||||
parts := strings.SplitN(xff, ",", 2)
|
parts := strings.SplitN(xff, ",", 2)
|
||||||
return strings.TrimSpace(parts[0])
|
return strings.TrimSpace(parts[0])
|
||||||
}
|
}
|
||||||
}
|
|
||||||
if rip := r.Header.Get("X-Real-IP"); rip != "" {
|
if rip := r.Header.Get("X-Real-IP"); rip != "" {
|
||||||
return strings.TrimSpace(rip)
|
return strings.TrimSpace(rip)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
170
internal/middleware/ratelimit_burst_test.go
Normal file
170
internal/middleware/ratelimit_burst_test.go
Normal file
|
|
@ -0,0 +1,170 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBurstRateLimit_AllowsUnderBurst(t *testing.T) {
|
||||||
|
h := BurstRateLimit(BurstRateLimitConfig{
|
||||||
|
Burst: 5,
|
||||||
|
BurstWindow: 10 * time.Second,
|
||||||
|
}, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
req := httptest.NewRequest("GET", "/search", nil)
|
||||||
|
req.RemoteAddr = "1.1.1.1:1234"
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("request %d: expected 200, got %d", i+1, rec.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBurstRateLimit_BlocksBurst(t *testing.T) {
|
||||||
|
h := BurstRateLimit(BurstRateLimitConfig{
|
||||||
|
Burst: 3,
|
||||||
|
BurstWindow: 10 * time.Second,
|
||||||
|
}, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
req := httptest.NewRequest("GET", "/search", nil)
|
||||||
|
req.RemoteAddr = "1.1.1.1:1234"
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("request %d: expected 200, got %d", i+1, rec.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/search", nil)
|
||||||
|
req.RemoteAddr = "1.1.1.1:1234"
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusTooManyRequests {
|
||||||
|
t.Errorf("expected 429, got %d", rec.Code)
|
||||||
|
}
|
||||||
|
reason := rec.Header().Get("X-RateLimit-Reason")
|
||||||
|
if reason != "burst limit exceeded" {
|
||||||
|
t.Errorf("expected burst reason, got %q", reason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBurstRateLimit_Sustained(t *testing.T) {
|
||||||
|
h := BurstRateLimit(BurstRateLimitConfig{
|
||||||
|
Burst: 100,
|
||||||
|
BurstWindow: 10 * time.Second,
|
||||||
|
Sustained: 5,
|
||||||
|
SustainedWindow: 10 * time.Second,
|
||||||
|
}, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
req := httptest.NewRequest("GET", "/search", nil)
|
||||||
|
req.RemoteAddr = "1.1.1.1:1234"
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("request %d: expected 200, got %d", i+1, rec.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/search", nil)
|
||||||
|
req.RemoteAddr = "1.1.1.1:1234"
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusTooManyRequests {
|
||||||
|
t.Errorf("expected 429, got %d", rec.Code)
|
||||||
|
}
|
||||||
|
reason := rec.Header().Get("X-RateLimit-Reason")
|
||||||
|
if reason != "sustained limit exceeded" {
|
||||||
|
t.Errorf("expected sustained reason, got %q", reason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBurstRateLimit_BurstWindowExpires(t *testing.T) {
|
||||||
|
limiter := &burstLimiter{
|
||||||
|
burst: 2,
|
||||||
|
burstWindow: 50 * time.Millisecond,
|
||||||
|
sustained: 100,
|
||||||
|
sustainedWindow: 10 * time.Second,
|
||||||
|
clients: make(map[string]*burstBucket),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use 2 requests (burst limit).
|
||||||
|
if reason := limiter.allow("1.1.1.1"); reason != "" {
|
||||||
|
t.Fatalf("request 1 blocked: %s", reason)
|
||||||
|
}
|
||||||
|
if reason := limiter.allow("1.1.1.1"); reason != "" {
|
||||||
|
t.Fatalf("request 2 blocked: %s", reason)
|
||||||
|
}
|
||||||
|
if reason := limiter.allow("1.1.1.1"); reason != "burst limit exceeded" {
|
||||||
|
t.Errorf("expected burst block, got %q", reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for burst window to expire.
|
||||||
|
time.Sleep(60 * time.Millisecond)
|
||||||
|
|
||||||
|
if reason := limiter.allow("1.1.1.1"); reason != "" {
|
||||||
|
t.Errorf("after burst expiry, should be allowed, got %q", reason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBurstRateLimit_Disabled(t *testing.T) {
|
||||||
|
h := BurstRateLimit(BurstRateLimitConfig{Burst: 0}, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
for i := 0; i < 200; i++ {
|
||||||
|
req := httptest.NewRequest("GET", "/search", nil)
|
||||||
|
req.RemoteAddr = "1.1.1.1:1234"
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("request %d: expected 200 with disabled limiter, got %d", i+1, rec.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBurstRateLimit_DifferentIPs(t *testing.T) {
|
||||||
|
h := BurstRateLimit(BurstRateLimitConfig{
|
||||||
|
Burst: 1,
|
||||||
|
BurstWindow: 10 * time.Second,
|
||||||
|
}, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// IP A blocked after 1.
|
||||||
|
req := httptest.NewRequest("GET", "/search", nil)
|
||||||
|
req.RemoteAddr = "1.1.1.1:1234"
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Error("IP A first: expected 200")
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest("GET", "/search", nil)
|
||||||
|
req.RemoteAddr = "1.1.1.1:1234"
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusTooManyRequests {
|
||||||
|
t.Error("IP A second: expected 429")
|
||||||
|
}
|
||||||
|
|
||||||
|
// IP B still allowed.
|
||||||
|
req = httptest.NewRequest("GET", "/search", nil)
|
||||||
|
req.RemoteAddr = "2.2.2.2:1234"
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Error("IP B first: expected 200")
|
||||||
|
}
|
||||||
|
}
|
||||||
241
internal/middleware/ratelimit_global.go
Normal file
241
internal/middleware/ratelimit_global.go
Normal file
|
|
@ -0,0 +1,241 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"log/slog"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GlobalRateLimitConfig controls server-wide rate limiting.
|
||||||
|
// Applied on top of per-IP rate limiting to prevent overall abuse.
|
||||||
|
type GlobalRateLimitConfig struct {
|
||||||
|
// Requests is the max total requests across all IPs per window.
|
||||||
|
Requests int
|
||||||
|
// Window is the time window duration (e.g. "1m").
|
||||||
|
Window time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// GlobalRateLimit returns a middleware that limits total server-wide requests.
|
||||||
|
// Uses a lock-free atomic counter for minimal overhead.
|
||||||
|
// Set requests to 0 to disable.
|
||||||
|
func GlobalRateLimit(cfg GlobalRateLimitConfig, logger *slog.Logger) func(http.Handler) http.Handler {
|
||||||
|
requests := cfg.Requests
|
||||||
|
if requests <= 0 {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return next
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
window := cfg.Window
|
||||||
|
if window <= 0 {
|
||||||
|
window = time.Minute
|
||||||
|
}
|
||||||
|
|
||||||
|
if logger == nil {
|
||||||
|
logger = slog.Default()
|
||||||
|
}
|
||||||
|
|
||||||
|
limiter := &globalLimiter{
|
||||||
|
requests: int64(requests),
|
||||||
|
window: window,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
|
||||||
|
go limiter.resetLoop()
|
||||||
|
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if !limiter.allow() {
|
||||||
|
retryAfter := int(limiter.window.Seconds())
|
||||||
|
w.Header().Set("Retry-After", strconv.Itoa(retryAfter))
|
||||||
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
_, _ = w.Write([]byte("503 Service Unavailable — global rate limit exceeded\n"))
|
||||||
|
logger.Warn("global rate limit exceeded", "ip", extractIP(r))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type globalLimiter struct {
|
||||||
|
requests int64
|
||||||
|
count atomic.Int64
|
||||||
|
window time.Duration
|
||||||
|
logger *slog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *globalLimiter) allow() bool {
|
||||||
|
for {
|
||||||
|
current := l.count.Load()
|
||||||
|
if current >= l.requests {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if l.count.CompareAndSwap(current, current+1) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *globalLimiter) resetLoop() {
|
||||||
|
ticker := time.NewTicker(l.window)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for range ticker.C {
|
||||||
|
old := l.count.Swap(0)
|
||||||
|
if old > 0 {
|
||||||
|
l.logger.Debug("global rate limit window reset", "previous_count", old)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BurstRateLimitConfig controls per-IP burst + sustained rate limiting.
|
||||||
|
// More aggressive than the standard per-IP limiter — designed to stop rapid-fire abuse.
|
||||||
|
type BurstRateLimitConfig struct {
|
||||||
|
// Burst is the max requests allowed in a short burst window.
|
||||||
|
Burst int
|
||||||
|
// BurstWindow is the burst window duration (e.g. "5s").
|
||||||
|
BurstWindow time.Duration
|
||||||
|
// Sustained is the max requests allowed in the sustained window.
|
||||||
|
Sustained int
|
||||||
|
// SustainedWindow is the sustained window duration (e.g. "1m").
|
||||||
|
SustainedWindow time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// BurstRateLimit returns a middleware that enforces both burst and sustained limits per IP.
|
||||||
|
// Set burst to 0 to disable entirely.
|
||||||
|
func BurstRateLimit(cfg BurstRateLimitConfig, logger *slog.Logger) func(http.Handler) http.Handler {
|
||||||
|
if cfg.Burst <= 0 && cfg.Sustained <= 0 {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return next
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Burst <= 0 {
|
||||||
|
cfg.Burst = cfg.Sustained
|
||||||
|
}
|
||||||
|
if cfg.BurstWindow <= 0 {
|
||||||
|
cfg.BurstWindow = 5 * time.Second
|
||||||
|
}
|
||||||
|
if cfg.Sustained <= 0 {
|
||||||
|
cfg.Sustained = cfg.Burst * 6
|
||||||
|
}
|
||||||
|
if cfg.SustainedWindow <= 0 {
|
||||||
|
cfg.SustainedWindow = time.Minute
|
||||||
|
}
|
||||||
|
|
||||||
|
if logger == nil {
|
||||||
|
logger = slog.Default()
|
||||||
|
}
|
||||||
|
|
||||||
|
limiter := &burstLimiter{
|
||||||
|
burst: cfg.Burst,
|
||||||
|
burstWindow: cfg.BurstWindow,
|
||||||
|
sustained: cfg.Sustained,
|
||||||
|
sustainedWindow: cfg.SustainedWindow,
|
||||||
|
clients: make(map[string]*burstBucket),
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
|
||||||
|
go limiter.cleanup()
|
||||||
|
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ip := extractIP(r)
|
||||||
|
reason := limiter.allow(ip)
|
||||||
|
|
||||||
|
if reason != "" {
|
||||||
|
retryAfter := int(cfg.BurstWindow.Seconds())
|
||||||
|
w.Header().Set("Retry-After", strconv.Itoa(retryAfter))
|
||||||
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
w.Header().Set("X-RateLimit-Reason", reason)
|
||||||
|
w.WriteHeader(http.StatusTooManyRequests)
|
||||||
|
_, _ = w.Write([]byte("429 Too Many Requests — " + reason + "\n"))
|
||||||
|
logger.Debug("burst rate limited", "ip", ip, "reason", reason)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type burstBucket struct {
|
||||||
|
burstCount int
|
||||||
|
sustainedCount int
|
||||||
|
burstExpireAt time.Time
|
||||||
|
sustainedExpireAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type burstLimiter struct {
|
||||||
|
burst int
|
||||||
|
burstWindow time.Duration
|
||||||
|
sustained int
|
||||||
|
sustainedWindow time.Duration
|
||||||
|
clients map[string]*burstBucket
|
||||||
|
mu sync.Mutex
|
||||||
|
logger *slog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// allow returns empty string if allowed, or a reason string if blocked.
|
||||||
|
func (l *burstLimiter) allow(ip string) string {
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
b, ok := l.clients[ip]
|
||||||
|
|
||||||
|
if ok && now.After(b.sustainedExpireAt) {
|
||||||
|
delete(l.clients, ip)
|
||||||
|
ok = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
l.clients[ip] = &burstBucket{
|
||||||
|
burstCount: 1,
|
||||||
|
sustainedCount: 1,
|
||||||
|
burstExpireAt: now.Add(l.burstWindow),
|
||||||
|
sustainedExpireAt: now.Add(l.sustainedWindow),
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if now.After(b.burstExpireAt) {
|
||||||
|
b.burstCount = 0
|
||||||
|
b.burstExpireAt = now.Add(l.burstWindow)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.burstCount++
|
||||||
|
b.sustainedCount++
|
||||||
|
|
||||||
|
if b.burstCount > l.burst {
|
||||||
|
return "burst limit exceeded"
|
||||||
|
}
|
||||||
|
if b.sustainedCount > l.sustained {
|
||||||
|
return "sustained limit exceeded"
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *burstLimiter) cleanup() {
|
||||||
|
ticker := time.NewTicker(30 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for range ticker.C {
|
||||||
|
l.mu.Lock()
|
||||||
|
now := time.Now()
|
||||||
|
for ip, b := range l.clients {
|
||||||
|
if now.After(b.sustainedExpireAt) {
|
||||||
|
delete(l.clients, ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
l.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
140
internal/middleware/ratelimit_global_test.go
Normal file
140
internal/middleware/ratelimit_global_test.go
Normal file
|
|
@ -0,0 +1,140 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGlobalRateLimit_AllowsUnderLimit(t *testing.T) {
|
||||||
|
h := GlobalRateLimit(GlobalRateLimitConfig{
|
||||||
|
Requests: 100,
|
||||||
|
Window: 10 * time.Second,
|
||||||
|
}, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
req := httptest.NewRequest("GET", "/search", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("request %d: expected 200, got %d", i+1, rec.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGlobalRateLimit_BlocksOverLimit(t *testing.T) {
|
||||||
|
h := GlobalRateLimit(GlobalRateLimitConfig{
|
||||||
|
Requests: 5,
|
||||||
|
Window: 10 * time.Second,
|
||||||
|
}, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
req := httptest.NewRequest("GET", "/search", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("request %d: expected 200, got %d", i+1, rec.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/search", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusServiceUnavailable {
|
||||||
|
t.Errorf("expected 503, got %d", rec.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGlobalRateLimit_WindowResets(t *testing.T) {
|
||||||
|
limiter := &globalLimiter{
|
||||||
|
requests: 3,
|
||||||
|
window: 50 * time.Millisecond,
|
||||||
|
}
|
||||||
|
|
||||||
|
if !limiter.allow() {
|
||||||
|
t.Error("first request should be allowed")
|
||||||
|
}
|
||||||
|
if !limiter.allow() {
|
||||||
|
t.Error("second request should be allowed")
|
||||||
|
}
|
||||||
|
if !limiter.allow() {
|
||||||
|
t.Error("third request should be allowed")
|
||||||
|
}
|
||||||
|
if limiter.allow() {
|
||||||
|
t.Error("fourth request should be blocked")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate window reset.
|
||||||
|
old := limiter.count.Swap(0)
|
||||||
|
if old != 3 {
|
||||||
|
t.Errorf("expected count 3, got %d", old)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !limiter.allow() {
|
||||||
|
t.Error("after reset, request should be allowed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGlobalRateLimit_Disabled(t *testing.T) {
|
||||||
|
h := GlobalRateLimit(GlobalRateLimitConfig{Requests: 0}, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Should allow unlimited requests.
|
||||||
|
for i := 0; i < 200; i++ {
|
||||||
|
req := httptest.NewRequest("GET", "/search", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("request %d: expected 200 with disabled limiter, got %d", i+1, rec.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGlobalRateLimit_LockFree(t *testing.T) {
|
||||||
|
// Verify the atomic counter doesn't panic under concurrent access.
|
||||||
|
h := GlobalRateLimit(GlobalRateLimitConfig{
|
||||||
|
Requests: 10000,
|
||||||
|
Window: 10 * time.Second,
|
||||||
|
}, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
var allowed, blocked atomic.Int64
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
go func() {
|
||||||
|
for j := 0; j < 200; j++ {
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rec, req)
|
||||||
|
if rec.Code == http.StatusOK {
|
||||||
|
allowed.Add(1)
|
||||||
|
} else {
|
||||||
|
blocked.Add(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
done <- struct{}{}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
|
||||||
|
total := allowed.Load() + blocked.Load()
|
||||||
|
if total != 20000 {
|
||||||
|
t.Errorf("expected 20000 total requests, got %d", total)
|
||||||
|
}
|
||||||
|
if allowed.Load() != 10000 {
|
||||||
|
t.Errorf("expected exactly 10000 allowed, got %d", allowed.Load())
|
||||||
|
}
|
||||||
|
t.Logf("concurrent test: %d allowed, %d blocked", allowed.Load(), blocked.Load())
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue