Update LICENSE file and add AGPL header to all source files. AGPLv3 ensures that if someone runs Kafka as a network service and modifies it, they must release their source code under the same license.
257 lines
6.3 KiB
Go
257 lines
6.3 KiB
Go
// kafka — a privacy-respecting metasearch engine
|
|
// Copyright (C) 2026-present metamorphosis-dev
|
|
//
|
|
// This program is free software: you can redistribute it and/or modify
|
|
// it under the terms of the GNU Affero General Public License as published by
|
|
// the Free Software Foundation, either version 3 of the License, or
|
|
// (at your option) any later version.
|
|
//
|
|
// This program is distributed in the hope that it will be useful,
|
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
// GNU Affero General Public License for more details.
|
|
//
|
|
// You should have received a copy of the GNU Affero General Public License
|
|
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
package middleware
|
|
|
|
import (
|
|
"net/http"
|
|
"strconv"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"log/slog"
|
|
)
|
|
|
|
// GlobalRateLimitConfig configures the server-wide request cap enforced by
// GlobalRateLimit. It is meant to sit on top of per-IP rate limiting so that
// the combined traffic of all clients stays bounded.
type GlobalRateLimitConfig struct {
	// Requests is the total number of requests, summed over every client IP,
	// admitted during one Window. Zero or a negative value disables the
	// global limiter.
	Requests int
	// Window is the length of one counting period (for example time.Minute).
	// GlobalRateLimit substitutes one minute when this is zero or negative.
	Window time.Duration
}
|
|
|
|
// GlobalRateLimit returns a middleware that limits total server-wide requests.
|
|
// Uses a lock-free atomic counter for minimal overhead.
|
|
// Set requests to 0 to disable.
|
|
func GlobalRateLimit(cfg GlobalRateLimitConfig, logger *slog.Logger) func(http.Handler) http.Handler {
|
|
requests := cfg.Requests
|
|
if requests <= 0 {
|
|
return func(next http.Handler) http.Handler {
|
|
return next
|
|
}
|
|
}
|
|
|
|
window := cfg.Window
|
|
if window <= 0 {
|
|
window = time.Minute
|
|
}
|
|
|
|
if logger == nil {
|
|
logger = slog.Default()
|
|
}
|
|
|
|
limiter := &globalLimiter{
|
|
requests: int64(requests),
|
|
window: window,
|
|
logger: logger,
|
|
}
|
|
|
|
go limiter.resetLoop()
|
|
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if !limiter.allow() {
|
|
retryAfter := int(limiter.window.Seconds())
|
|
w.Header().Set("Retry-After", strconv.Itoa(retryAfter))
|
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
w.WriteHeader(http.StatusServiceUnavailable)
|
|
_, _ = w.Write([]byte("503 Service Unavailable — global rate limit exceeded\n"))
|
|
logger.Warn("global rate limit exceeded", "ip", extractIP(r))
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// globalLimiter is a fixed-window request counter shared by every request
// that passes through the global rate-limit middleware.
type globalLimiter struct {
	// requests is the number of admissions permitted per window.
	requests int64
	// count is the number of admissions granted in the current window.
	count atomic.Int64
	// window is the interval at which resetLoop zeroes count.
	window time.Duration
	logger *slog.Logger
}

// allow claims one slot in the current window and reports whether the claim
// succeeded. It takes no lock: the counter is advanced with compare-and-swap
// and the attempt is retried whenever another goroutine changed the counter
// in between, so count never exceeds requests.
func (l *globalLimiter) allow() bool {
	for seen := l.count.Load(); seen < l.requests; seen = l.count.Load() {
		if l.count.CompareAndSwap(seen, seen+1) {
			return true
		}
	}
	return false
}

// resetLoop zeroes the counter once per window. It blocks for as long as the
// ticker keeps firing and is therefore meant to run in its own goroutine. A
// debug line is logged only for windows in which at least one request was
// counted.
func (l *globalLimiter) resetLoop() {
	tick := time.NewTicker(l.window)
	defer tick.Stop()

	for range tick.C {
		if prev := l.count.Swap(0); prev > 0 {
			l.logger.Debug("global rate limit window reset", "previous_count", prev)
		}
	}
}
|
|
|
|
// BurstRateLimitConfig configures the two-tier per-IP limiter built by
// BurstRateLimit: a short burst window layered under a longer sustained
// window. It is stricter than the standard per-IP limiter and is aimed at
// rapid-fire abuse.
type BurstRateLimitConfig struct {
	// Burst is the number of requests one IP may make within BurstWindow.
	// When zero or negative it takes the value of Sustained.
	Burst int
	// BurstWindow is the length of the burst period (for example
	// 5*time.Second). When zero or negative it defaults to five seconds.
	BurstWindow time.Duration
	// Sustained is the number of requests one IP may make within
	// SustainedWindow. When zero or negative it defaults to six times Burst.
	Sustained int
	// SustainedWindow is the length of the sustained period (for example
	// time.Minute). When zero or negative it defaults to one minute.
	SustainedWindow time.Duration
}
|
|
|
|
// BurstRateLimit returns a middleware that enforces both burst and sustained limits per IP.
|
|
// Set burst to 0 to disable entirely.
|
|
func BurstRateLimit(cfg BurstRateLimitConfig, logger *slog.Logger) func(http.Handler) http.Handler {
|
|
if cfg.Burst <= 0 && cfg.Sustained <= 0 {
|
|
return func(next http.Handler) http.Handler {
|
|
return next
|
|
}
|
|
}
|
|
|
|
if cfg.Burst <= 0 {
|
|
cfg.Burst = cfg.Sustained
|
|
}
|
|
if cfg.BurstWindow <= 0 {
|
|
cfg.BurstWindow = 5 * time.Second
|
|
}
|
|
if cfg.Sustained <= 0 {
|
|
cfg.Sustained = cfg.Burst * 6
|
|
}
|
|
if cfg.SustainedWindow <= 0 {
|
|
cfg.SustainedWindow = time.Minute
|
|
}
|
|
|
|
if logger == nil {
|
|
logger = slog.Default()
|
|
}
|
|
|
|
limiter := &burstLimiter{
|
|
burst: cfg.Burst,
|
|
burstWindow: cfg.BurstWindow,
|
|
sustained: cfg.Sustained,
|
|
sustainedWindow: cfg.SustainedWindow,
|
|
clients: make(map[string]*burstBucket),
|
|
logger: logger,
|
|
}
|
|
|
|
go limiter.cleanup()
|
|
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
ip := extractIP(r)
|
|
reason := limiter.allow(ip)
|
|
|
|
if reason != "" {
|
|
retryAfter := int(cfg.BurstWindow.Seconds())
|
|
w.Header().Set("Retry-After", strconv.Itoa(retryAfter))
|
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
w.Header().Set("X-RateLimit-Reason", reason)
|
|
w.WriteHeader(http.StatusTooManyRequests)
|
|
_, _ = w.Write([]byte("429 Too Many Requests — " + reason + "\n"))
|
|
logger.Debug("burst rate limited", "ip", ip, "reason", reason)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// burstBucket holds the request counters and window deadlines tracked for a
// single client IP.
type burstBucket struct {
	// burstCount is the number of requests seen in the current burst window.
	burstCount int
	// sustainedCount is the number of requests seen in the current sustained window.
	sustainedCount int
	// burstExpireAt is the instant after which burstCount starts over.
	burstExpireAt time.Time
	// sustainedExpireAt is the instant after which the whole bucket is discarded.
	sustainedExpireAt time.Time
}

// burstLimiter tracks one burstBucket per client IP and applies a burst limit
// and a sustained limit to each. The clients map is guarded by mu.
type burstLimiter struct {
	burst           int
	burstWindow     time.Duration
	sustained       int
	sustainedWindow time.Duration
	clients         map[string]*burstBucket
	mu              sync.Mutex
	logger          *slog.Logger
}

// allow records one request from ip and returns "" when it is admitted, or a
// human-readable reason ("burst limit exceeded" / "sustained limit exceeded")
// when it is rejected. Rejected requests still advance both counters.
func (l *burstLimiter) allow(ip string) string {
	l.mu.Lock()
	defer l.mu.Unlock()

	now := time.Now()

	bucket, known := l.clients[ip]
	if !known || now.After(bucket.sustainedExpireAt) {
		// Unknown client, or its sustained window has lapsed: begin a fresh
		// bucket that already counts this request, which is always admitted.
		l.clients[ip] = &burstBucket{
			burstCount:        1,
			sustainedCount:    1,
			burstExpireAt:     now.Add(l.burstWindow),
			sustainedExpireAt: now.Add(l.sustainedWindow),
		}
		return ""
	}

	// The burst window rolls over independently inside the sustained window.
	if now.After(bucket.burstExpireAt) {
		bucket.burstCount = 0
		bucket.burstExpireAt = now.Add(l.burstWindow)
	}

	bucket.burstCount++
	bucket.sustainedCount++

	// The burst check takes precedence when both limits are exceeded.
	switch {
	case bucket.burstCount > l.burst:
		return "burst limit exceeded"
	case bucket.sustainedCount > l.sustained:
		return "sustained limit exceeded"
	default:
		return ""
	}
}

// cleanup removes buckets whose sustained window has expired. It wakes every
// 30 seconds, blocks for as long as the ticker keeps firing, and is therefore
// meant to run in its own goroutine.
func (l *burstLimiter) cleanup() {
	tick := time.NewTicker(30 * time.Second)
	defer tick.Stop()

	for range tick.C {
		func() {
			l.mu.Lock()
			defer l.mu.Unlock()

			now := time.Now()
			for ip, bucket := range l.clients {
				if now.After(bucket.sustainedExpireAt) {
					delete(l.clients, ip)
				}
			}
		}()
	}
}
|