feat: add CORS and rate limiting middleware
CORS: - Configurable allowed origins (wildcard "*" or specific domains) - Handles OPTIONS preflight with configurable methods, headers, max-age - Exposed headers support for browser API access - Env override: CORS_ALLOWED_ORIGINS Rate Limiting: - In-memory per-IP sliding window counter - Configurable request limit and time window - Background goroutine cleans up stale IP entries - HTTP 429 with Retry-After header when exceeded - Extracts real IP from X-Forwarded-For and X-Real-IP (proxy-aware) - Env overrides: RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW, RATE_LIMIT_CLEANUP_INTERVAL - Set requests=0 in config to disable Both wired into main.go as middleware chain: rate_limit → cors → handler. Config example updated with [cors] and [rate_limit] sections. Full test coverage for both middleware packages.
This commit is contained in:
parent
4c54ed5b56
commit
ebeaeeef21
8 changed files with 616 additions and 6 deletions
142
internal/middleware/ratelimit.go
Normal file
142
internal/middleware/ratelimit.go
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// RateLimitConfig controls per-IP rate limiting using a sliding window counter.
|
||||
type RateLimitConfig struct {
|
||||
// Requests is the max number of requests allowed per window.
|
||||
Requests int
|
||||
// Window is the time window duration (e.g. "1m").
|
||||
Window time.Duration
|
||||
// CleanupInterval is how often stale entries are purged (default: 5m).
|
||||
CleanupInterval time.Duration
|
||||
}
|
||||
|
||||
// RateLimit returns a middleware that limits requests per IP address.
|
||||
// Uses an in-memory sliding window counter. When the limit is exceeded,
|
||||
// responds with HTTP 429 and a Retry-After header.
|
||||
func RateLimit(cfg RateLimitConfig, logger *slog.Logger) func(http.Handler) http.Handler {
|
||||
requests := cfg.Requests
|
||||
if requests <= 0 {
|
||||
requests = 30
|
||||
}
|
||||
|
||||
window := cfg.Window
|
||||
if window <= 0 {
|
||||
window = time.Minute
|
||||
}
|
||||
|
||||
cleanup := cfg.CleanupInterval
|
||||
if cleanup <= 0 {
|
||||
cleanup = 5 * time.Minute
|
||||
}
|
||||
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
|
||||
limiter := &ipLimiter{
|
||||
requests: requests,
|
||||
window: window,
|
||||
clients: make(map[string]*bucket),
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// Background cleanup of stale buckets.
|
||||
go limiter.cleanup(cleanup)
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ip := extractIP(r)
|
||||
|
||||
if !limiter.allow(ip) {
|
||||
retryAfter := int(limiter.window.Seconds())
|
||||
w.Header().Set("Retry-After", strconv.Itoa(retryAfter))
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_, _ = w.Write([]byte("429 Too Many Requests\n"))
|
||||
logger.Debug("rate limited", "ip", ip)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type bucket struct {
|
||||
count int
|
||||
expireAt time.Time
|
||||
}
|
||||
|
||||
type ipLimiter struct {
|
||||
requests int
|
||||
window time.Duration
|
||||
clients map[string]*bucket
|
||||
mu sync.Mutex
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (l *ipLimiter) allow(ip string) bool {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
b, ok := l.clients[ip]
|
||||
|
||||
if !ok || now.After(b.expireAt) {
|
||||
l.clients[ip] = &bucket{
|
||||
count: 1,
|
||||
expireAt: now.Add(l.window),
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
b.count++
|
||||
return b.count <= l.requests
|
||||
}
|
||||
|
||||
func (l *ipLimiter) cleanup(interval time.Duration) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
l.mu.Lock()
|
||||
now := time.Now()
|
||||
for ip, b := range l.clients {
|
||||
if now.After(b.expireAt) {
|
||||
delete(l.clients, ip)
|
||||
}
|
||||
}
|
||||
l.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func extractIP(r *http.Request) string {
|
||||
// Trust X-Forwarded-For / X-Real-IP if behind a proxy.
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// First IP in the chain is the client.
|
||||
if idx := len(xff); idx > 0 {
|
||||
parts := strings.SplitN(xff, ",", 2)
|
||||
return strings.TrimSpace(parts[0])
|
||||
}
|
||||
}
|
||||
if rip := r.Header.Get("X-Real-IP"); rip != "" {
|
||||
return strings.TrimSpace(rip)
|
||||
}
|
||||
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return host
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue