Merge branch 'security/hardening-sast-fixes'
Some checks failed
Build and Push Docker Image / build-and-push (push) Failing after 7s
Mirror to GitHub / mirror (push) Failing after 3s
Tests / test (push) Failing after 19s

This commit is contained in:
Franz Kafka 2026-03-22 16:31:57 +00:00
commit 5884c080fd
24 changed files with 422 additions and 70 deletions

View file

@ -95,8 +95,9 @@ func main() {
var subFS fs.FS = staticFS var subFS fs.FS = staticFS
mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.FS(subFS)))) mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.FS(subFS))))
// Apply middleware: global rate limit → burst rate limit → per-IP rate limit → CORS → handler. // Apply middleware: global rate limit → burst rate limit → per-IP rate limit → CORS → security headers → handler.
var handler http.Handler = mux var handler http.Handler = mux
handler = middleware.SecurityHeaders(middleware.SecurityHeadersConfig{})(handler)
handler = middleware.CORS(middleware.CORSConfig{ handler = middleware.CORS(middleware.CORSConfig{
AllowedOrigins: cfg.CORS.AllowedOrigins, AllowedOrigins: cfg.CORS.AllowedOrigins,
AllowedMethods: cfg.CORS.AllowedMethods, AllowedMethods: cfg.CORS.AllowedMethods,
@ -108,6 +109,7 @@ func main() {
Requests: cfg.RateLimit.Requests, Requests: cfg.RateLimit.Requests,
Window: cfg.RateLimitWindow(), Window: cfg.RateLimitWindow(),
CleanupInterval: cfg.RateLimitCleanupInterval(), CleanupInterval: cfg.RateLimitCleanupInterval(),
TrustedProxies: cfg.RateLimit.TrustedProxies,
}, logger)(handler) }, logger)(handler)
handler = middleware.GlobalRateLimit(middleware.GlobalRateLimitConfig{ handler = middleware.GlobalRateLimit(middleware.GlobalRateLimitConfig{
Requests: cfg.GlobalRateLimit.Requests, Requests: cfg.GlobalRateLimit.Requests,

View file

@ -23,6 +23,7 @@ import (
"time" "time"
"github.com/BurntSushi/toml" "github.com/BurntSushi/toml"
"github.com/metamorphosis-dev/kafka/internal/util"
) )
// Config is the top-level configuration for the kafka service. // Config is the top-level configuration for the kafka service.
@ -77,6 +78,7 @@ type RateLimitConfig struct {
Requests int `toml:"requests"` // Max requests per window (default: 30) Requests int `toml:"requests"` // Max requests per window (default: 30)
Window string `toml:"window"` // Time window (e.g. "1m", default: "1m") Window string `toml:"window"` // Time window (e.g. "1m", default: "1m")
CleanupInterval string `toml:"cleanup_interval"` // Stale entry cleanup interval (default: "5m") CleanupInterval string `toml:"cleanup_interval"` // Stale entry cleanup interval (default: "5m")
TrustedProxies []string `toml:"trusted_proxies"` // CIDRs allowed to set X-Forwarded-For
} }
// GlobalRateLimitConfig holds server-wide rate limiting settings. // GlobalRateLimitConfig holds server-wide rate limiting settings.
@ -120,9 +122,36 @@ func Load(path string) (*Config, error) {
} }
applyEnvOverrides(cfg) applyEnvOverrides(cfg)
if err := validateConfig(cfg); err != nil {
return nil, fmt.Errorf("invalid configuration: %w", err)
}
return cfg, nil return cfg, nil
} }
// validateConfig checks security-critical config values at startup.
func validateConfig(cfg *Config) error {
if cfg.Server.BaseURL != "" {
if err := util.ValidatePublicURL(cfg.Server.BaseURL); err != nil {
return fmt.Errorf("server.base_url: %w", err)
}
}
if cfg.Server.SourceURL != "" {
if err := util.ValidatePublicURL(cfg.Server.SourceURL); err != nil {
return fmt.Errorf("server.source_url: %w", err)
}
}
if cfg.Upstream.URL != "" {
// Validate scheme and well-formedness, but allow private IPs
// since self-hosted deployments commonly use localhost/internal addresses.
if _, err := util.SafeURLScheme(cfg.Upstream.URL); err != nil {
return fmt.Errorf("upstream.url: %w", err)
}
}
return nil
}
func defaultConfig() *Config { func defaultConfig() *Config {
return &Config{ return &Config{
Server: ServerConfig{ Server: ServerConfig{

View file

@ -75,8 +75,8 @@ func (e *ArxivEngine) Search(ctx context.Context, req contracts.SearchRequest) (
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 { if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024)) io.Copy(io.Discard, io.LimitReader(resp.Body, 16*1024))
return contracts.SearchResponse{}, fmt.Errorf("arxiv upstream error: status=%d body=%q", resp.StatusCode, string(body)) return contracts.SearchResponse{}, fmt.Errorf("arxiv upstream error: status %d", resp.StatusCode)
} }
raw, err := io.ReadAll(resp.Body) raw, err := io.ReadAll(resp.Body)

View file

@ -68,8 +68,8 @@ func (e *BingEngine) Search(ctx context.Context, req contracts.SearchRequest) (c
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) io.Copy(io.Discard, io.LimitReader(resp.Body, 4096))
return contracts.SearchResponse{}, fmt.Errorf("bing upstream error: status=%d body=%q", resp.StatusCode, string(body)) return contracts.SearchResponse{}, fmt.Errorf("bing upstream error: status %d", resp.StatusCode)
} }
contentType := resp.Header.Get("Content-Type") contentType := resp.Header.Get("Content-Type")

View file

@ -45,8 +45,8 @@ func (e *BraveEngine) Search(ctx context.Context, req contracts.SearchRequest) (
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) io.Copy(io.Discard, io.LimitReader(resp.Body, 4096))
return contracts.SearchResponse{}, fmt.Errorf("brave error: status=%d body=%q", resp.StatusCode, string(body)) return contracts.SearchResponse{}, fmt.Errorf("brave error: status %d", resp.StatusCode)
} }
body, err := io.ReadAll(io.LimitReader(resp.Body, 128*1024)) body, err := io.ReadAll(io.LimitReader(resp.Body, 128*1024))

View file

@ -127,8 +127,8 @@ func (e *BraveAPIEngine) Search(ctx context.Context, req contracts.SearchRequest
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 { if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024)) io.Copy(io.Discard, io.LimitReader(resp.Body, 16*1024))
return contracts.SearchResponse{}, fmt.Errorf("brave upstream error: status=%d body=%q", resp.StatusCode, string(body)) return contracts.SearchResponse{}, fmt.Errorf("brave upstream error: status %d", resp.StatusCode)
} }
var api struct { var api struct {

View file

@ -39,7 +39,7 @@ func TestBraveEngine_GatingAndHeader(t *testing.T) {
}) })
client := &http.Client{Transport: transport} client := &http.Client{Transport: transport}
engine := &BraveEngine{ engine := &BraveAPIEngine{
client: client, client: client,
apiKey: wantAPIKey, apiKey: wantAPIKey,
accessGateToken: wantToken, accessGateToken: wantToken,

View file

@ -63,8 +63,8 @@ func (e *CrossrefEngine) Search(ctx context.Context, req contracts.SearchRequest
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 { if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024)) io.Copy(io.Discard, io.LimitReader(resp.Body, 16*1024))
return contracts.SearchResponse{}, fmt.Errorf("crossref upstream error: status=%d body=%q", resp.StatusCode, string(body)) return contracts.SearchResponse{}, fmt.Errorf("crossref upstream error: status %d", resp.StatusCode)
} }
var api struct { var api struct {

View file

@ -63,8 +63,8 @@ func (e *DuckDuckGoEngine) Search(ctx context.Context, req contracts.SearchReque
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) io.Copy(io.Discard, io.LimitReader(resp.Body, 4096))
return contracts.SearchResponse{}, fmt.Errorf("duckduckgo upstream error: status=%d body=%q", resp.StatusCode, string(body)) return contracts.SearchResponse{}, fmt.Errorf("duckduckgo upstream error: status %d", resp.StatusCode)
} }
results, err := parseDuckDuckGoHTML(resp.Body) results, err := parseDuckDuckGoHTML(resp.Body)

View file

@ -66,8 +66,8 @@ func (e *GitHubEngine) Search(ctx context.Context, req contracts.SearchRequest)
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) io.Copy(io.Discard, io.LimitReader(resp.Body, 4096))
return contracts.SearchResponse{}, fmt.Errorf("github api error: status=%d body=%q", resp.StatusCode, string(body)) return contracts.SearchResponse{}, fmt.Errorf("github api error: status %d", resp.StatusCode)
} }
var data struct { var data struct {

View file

@ -28,20 +28,10 @@ import (
"github.com/metamorphosis-dev/kafka/internal/contracts" "github.com/metamorphosis-dev/kafka/internal/contracts"
) )
// GSA User-Agent pool — these are Google Search Appliance identifiers // googleUserAgent is an honest User-Agent identifying the metasearch engine.
// that Google trusts for enterprise search appliance traffic. // Using a spoofed GSA User-Agent violates Google's Terms of Service and
var gsaUserAgents = []string{ // risks permanent IP blocking.
"Mozilla/5.0 (iPhone; CPU iPhone OS 17_5_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) GSA/399.2.845414227 Mobile/15E148 Safari/604.1", var googleUserAgent = "Kafka/0.1 (compatible; +https://github.com/metamorphosis-dev/kafka)"
"Mozilla/5.0 (iPhone; CPU iPhone OS 17_6_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) GSA/406.0.862495628 Mobile/15E148 Safari/604.1",
"Mozilla/5.0 (iPhone; CPU iPhone OS 17_7_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) GSA/406.0.862495628 Mobile/15E148 Safari/604.1",
"Mozilla/5.0 (iPhone; CPU iPhone OS 18_0_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) GSA/406.0.862495628 Mobile/15E148 Safari/604.1",
"Mozilla/5.0 (iPhone; CPU iPhone OS 18_1_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) GSA/399.2.845414227 Mobile/15E148 Safari/604.1",
"Mozilla/5.0 (iPhone; CPU iPhone OS 18_5_0 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) GSA/406.0.862495628 Mobile/15E148 Safari/604.1",
}
func gsaUA() string {
return gsaUserAgents[0] // deterministic for now; could rotate
}
type GoogleEngine struct { type GoogleEngine struct {
client *http.Client client *http.Client
@ -70,7 +60,7 @@ func (e *GoogleEngine) Search(ctx context.Context, req contracts.SearchRequest)
if err != nil { if err != nil {
return contracts.SearchResponse{}, err return contracts.SearchResponse{}, err
} }
httpReq.Header.Set("User-Agent", gsaUA()) httpReq.Header.Set("User-Agent", googleUserAgent)
httpReq.Header.Set("Accept", "*/*") httpReq.Header.Set("Accept", "*/*")
httpReq.AddCookie(&http.Cookie{Name: "CONSENT", Value: "YES+"}) httpReq.AddCookie(&http.Cookie{Name: "CONSENT", Value: "YES+"})
@ -95,8 +85,8 @@ func (e *GoogleEngine) Search(ctx context.Context, req contracts.SearchRequest)
} }
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) io.Copy(io.Discard, io.LimitReader(resp.Body, 4096))
return contracts.SearchResponse{}, fmt.Errorf("google error: status=%d body=%q", resp.StatusCode, string(body)) return contracts.SearchResponse{}, fmt.Errorf("google error: status %d", resp.StatusCode)
} }
body, err := io.ReadAll(io.LimitReader(resp.Body, 128*1024)) body, err := io.ReadAll(io.LimitReader(resp.Body, 128*1024))

View file

@ -124,8 +124,8 @@ func (e *QwantEngine) searchWebAPI(ctx context.Context, req contracts.SearchRequ
} }
if resp.StatusCode < 200 || resp.StatusCode >= 300 { if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024)) io.Copy(io.Discard, io.LimitReader(resp.Body, 16*1024))
return contracts.SearchResponse{}, fmt.Errorf("qwant upstream error: status=%d body=%q", resp.StatusCode, string(body)) return contracts.SearchResponse{}, fmt.Errorf("qwant upstream error: status %d", resp.StatusCode)
} }
body, err := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024)) body, err := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024))
@ -253,8 +253,8 @@ func (e *QwantEngine) searchWebLite(ctx context.Context, req contracts.SearchReq
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 { if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024)) io.Copy(io.Discard, io.LimitReader(resp.Body, 16*1024))
return contracts.SearchResponse{}, fmt.Errorf("qwant lite upstream error: status=%d body=%q", resp.StatusCode, string(body)) return contracts.SearchResponse{}, fmt.Errorf("qwant lite upstream error: status %d", resp.StatusCode)
} }
doc, err := goquery.NewDocumentFromReader(resp.Body) doc, err := goquery.NewDocumentFromReader(resp.Body)

View file

@ -62,8 +62,8 @@ func (e *RedditEngine) Search(ctx context.Context, req contracts.SearchRequest)
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) io.Copy(io.Discard, io.LimitReader(resp.Body, 4096))
return contracts.SearchResponse{}, fmt.Errorf("reddit api error: status=%d body=%q", resp.StatusCode, string(body)) return contracts.SearchResponse{}, fmt.Errorf("reddit api error: status %d", resp.StatusCode)
} }
var data struct { var data struct {

View file

@ -134,8 +134,8 @@ func (e *WikipediaEngine) Search(ctx context.Context, req contracts.SearchReques
}, nil }, nil
} }
if resp.StatusCode < 200 || resp.StatusCode >= 300 { if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024)) io.Copy(io.Discard, io.LimitReader(resp.Body, 16*1024))
return contracts.SearchResponse{}, fmt.Errorf("wikipedia upstream error: status=%d body=%q", resp.StatusCode, string(body)) return contracts.SearchResponse{}, fmt.Errorf("wikipedia upstream error: status %d", resp.StatusCode)
} }
var api struct { var api struct {

View file

@ -77,8 +77,8 @@ func (e *YouTubeEngine) Search(ctx context.Context, req contracts.SearchRequest)
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) io.Copy(io.Discard, io.LimitReader(resp.Body, 4096))
return contracts.SearchResponse{}, fmt.Errorf("youtube api error: status=%d body=%q", resp.StatusCode, string(body)) return contracts.SearchResponse{}, fmt.Errorf("youtube api error: status %d", resp.StatusCode)
} }
var apiResp youtubeSearchResponse var apiResp youtubeSearchResponse
@ -87,7 +87,7 @@ func (e *YouTubeEngine) Search(ctx context.Context, req contracts.SearchRequest)
} }
if apiResp.Error != nil { if apiResp.Error != nil {
return contracts.SearchResponse{}, fmt.Errorf("youtube api error: %s", apiResp.Error.Message) return contracts.SearchResponse{}, fmt.Errorf("youtube api error: code %d", apiResp.Error.Code)
} }
results := make([]contracts.MainResult, 0, len(apiResp.Items)) results := make([]contracts.MainResult, 0, len(apiResp.Items))

View file

@ -42,7 +42,8 @@ type CORSConfig struct {
func CORS(cfg CORSConfig) func(http.Handler) http.Handler { func CORS(cfg CORSConfig) func(http.Handler) http.Handler {
origins := cfg.AllowedOrigins origins := cfg.AllowedOrigins
if len(origins) == 0 { if len(origins) == 0 {
origins = []string{"*"} // Default: no CORS headers. Explicitly configure origins to enable.
origins = nil
} }
methods := cfg.AllowedMethods methods := cfg.AllowedMethods
@ -70,6 +71,7 @@ func CORS(cfg CORSConfig) func(http.Handler) http.Handler {
origin := r.Header.Get("Origin") origin := r.Header.Get("Origin")
// Determine the allowed origin for this request. // Determine the allowed origin for this request.
// If no origins are configured, CORS is disabled entirely — no headers are set.
allowedOrigin := "" allowedOrigin := ""
for _, o := range origins { for _, o := range origins {
if o == "*" { if o == "*" {

View file

@ -27,10 +27,14 @@ import (
"log/slog" "log/slog"
) )
// RateLimitConfig controls per-IP rate limiting.
type RateLimitConfig struct { type RateLimitConfig struct {
Requests int Requests int
Window time.Duration Window time.Duration
CleanupInterval time.Duration CleanupInterval time.Duration
// TrustedProxies is a list of CIDR ranges that are allowed to set
// X-Forwarded-For / X-Real-IP. If empty, only r.RemoteAddr is used.
TrustedProxies []string
} }
func RateLimit(cfg RateLimitConfig, logger *slog.Logger) func(http.Handler) http.Handler { func RateLimit(cfg RateLimitConfig, logger *slog.Logger) func(http.Handler) http.Handler {
@ -53,18 +57,30 @@ func RateLimit(cfg RateLimitConfig, logger *slog.Logger) func(http.Handler) http
logger = slog.Default() logger = slog.Default()
} }
// Parse trusted proxy CIDRs.
var trustedNets []*net.IPNet
for _, cidr := range cfg.TrustedProxies {
_, network, err := net.ParseCIDR(cidr)
if err != nil {
logger.Warn("invalid trusted proxy CIDR, skipping", "cidr", cidr, "error", err)
continue
}
trustedNets = append(trustedNets, network)
}
limiter := &ipLimiter{ limiter := &ipLimiter{
requests: requests, requests: requests,
window: window, window: window,
clients: make(map[string]*bucket), clients: make(map[string]*bucket),
logger: logger, logger: logger,
trusted: trustedNets,
} }
go limiter.cleanup(cleanup) go limiter.cleanup(cleanup)
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ip := extractIP(r) ip := limiter.extractIP(r)
if !limiter.allow(ip) { if !limiter.allow(ip) {
retryAfter := int(limiter.window.Seconds()) retryAfter := int(limiter.window.Seconds())
@ -92,6 +108,7 @@ type ipLimiter struct {
clients map[string]*bucket clients map[string]*bucket
mu sync.Mutex mu sync.Mutex
logger *slog.Logger logger *slog.Logger
trusted []*net.IPNet
} }
func (l *ipLimiter) allow(ip string) bool { func (l *ipLimiter) allow(ip string) bool {
@ -129,18 +146,48 @@ func (l *ipLimiter) cleanup(interval time.Duration) {
} }
} }
func extractIP(r *http.Request) string { // extractIP extracts the client IP from the request.
if xff := r.Header.Get("X-Forwarded-For"); xff != "" { // If trusted proxy CIDRs are configured, X-Forwarded-For is only used when
parts := strings.SplitN(xff, ",", 2) // the direct connection comes from a trusted proxy. Otherwise, only RemoteAddr is used.
return strings.TrimSpace(parts[0]) func (l *ipLimiter) extractIP(r *http.Request) string {
} return extractIP(r, l.trusted...)
if rip := r.Header.Get("X-Real-IP"); rip != "" { }
return strings.TrimSpace(rip)
func extractIP(r *http.Request, trusted ...*net.IPNet) string {
remoteIP, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
remoteIP = r.RemoteAddr
} }
host, _, err := net.SplitHostPort(r.RemoteAddr) // Check if the direct connection is from a trusted proxy.
if err != nil { isTrusted := false
return r.RemoteAddr if len(trusted) > 0 {
ip := net.ParseIP(remoteIP)
if ip != nil {
for _, network := range trusted {
if network.Contains(ip) {
isTrusted = true
break
}
}
}
} }
return host
if isTrusted {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
parts := strings.SplitN(xff, ",", 2)
candidate := strings.TrimSpace(parts[0])
if net.ParseIP(candidate) != nil {
return candidate
}
}
if rip := r.Header.Get("X-Real-IP"); rip != "" {
candidate := strings.TrimSpace(rip)
if net.ParseIP(candidate) != nil {
return candidate
}
}
}
return remoteIP
} }

View file

@ -71,7 +71,7 @@ func GlobalRateLimit(cfg GlobalRateLimitConfig, logger *slog.Logger) func(http.H
w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusServiceUnavailable) w.WriteHeader(http.StatusServiceUnavailable)
_, _ = w.Write([]byte("503 Service Unavailable — global rate limit exceeded\n")) _, _ = w.Write([]byte("503 Service Unavailable — global rate limit exceeded\n"))
logger.Warn("global rate limit exceeded", "ip", extractIP(r)) logger.Warn("global rate limit exceeded", "remote", r.RemoteAddr)
return return
} }

View file

@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -93,8 +94,9 @@ func TestRateLimit_DifferentIPs(t *testing.T) {
func TestRateLimit_XForwardedFor(t *testing.T) { func TestRateLimit_XForwardedFor(t *testing.T) {
h := RateLimit(RateLimitConfig{ h := RateLimit(RateLimitConfig{
Requests: 1, Requests: 1,
Window: 10 * time.Second, Window: 10 * time.Second,
TrustedProxies: []string{"10.0.0.0/8"},
}, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { }, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
})) }))
@ -143,17 +145,27 @@ func TestRateLimit_WindowExpires(t *testing.T) {
} }
func TestExtractIP(t *testing.T) { func TestExtractIP(t *testing.T) {
// Trusted proxy: loopback
loopback := mustParseCIDR("127.0.0.0/8")
privateNet := mustParseCIDR("10.0.0.0/8")
tests := []struct { tests := []struct {
name string name string
xff string xff string
realIP string realIP string
remote string remote string
trusted []*net.IPNet
expected string expected string
}{ }{
{"xff", "203.0.113.50, 10.0.0.1", "", "10.0.0.1:1234", "203.0.113.50"}, // No trusted proxies → always use RemoteAddr.
{"real_ip", "", "203.0.113.50", "10.0.0.1:1234", "203.0.113.50"}, {"no_trusted_xff", "203.0.113.50, 10.0.0.1", "", "10.0.0.1:1234", nil, "10.0.0.1"},
{"remote", "", "", "1.2.3.4:5678", "1.2.3.4"}, {"no_trusted_real", "", "203.0.113.50", "10.0.0.1:1234", nil, "10.0.0.1"},
{"xff_over_real", "203.0.113.50", "10.0.0.1", "10.0.0.1:1234", "203.0.113.50"}, {"no_trusted_remote", "", "", "1.2.3.4:5678", nil, "1.2.3.4"},
// Trusted proxy → XFF is respected.
{"trusted_xff", "203.0.113.50, 10.0.0.1", "", "10.0.0.1:1234", []*net.IPNet{privateNet}, "203.0.113.50"},
{"trusted_real_ip", "", "203.0.113.50", "10.0.0.1:1234", []*net.IPNet{privateNet}, "203.0.113.50"},
// Untrusted remote → XFF ignored even if present.
{"untrusted_xff", "203.0.113.50, 10.0.0.1", "", "1.2.3.4:5678", []*net.IPNet{loopback}, "1.2.3.4"},
} }
for _, tt := range tests { for _, tt := range tests {
@ -167,9 +179,17 @@ func TestExtractIP(t *testing.T) {
} }
req.RemoteAddr = tt.remote req.RemoteAddr = tt.remote
if got := extractIP(req); got != tt.expected { if got := extractIP(req, tt.trusted...); got != tt.expected {
t.Errorf("extractIP() = %q, want %q", got, tt.expected) t.Errorf("extractIP() = %q, want %q", got, tt.expected)
} }
}) })
} }
} }
// mustParseCIDR parses s as a CIDR network, panicking on malformed input.
// Test-only helper: a bad literal is a programming error, not a runtime one.
func mustParseCIDR(s string) *net.IPNet {
	_, network, err := net.ParseCIDR(s)
	if err == nil {
		return network
	}
	panic(err)
}

View file

@ -0,0 +1,92 @@
// kafka — a privacy-respecting metasearch engine
// Copyright (C) 2026-present metamorphosis-dev
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
package middleware
import (
"net/http"
"strconv"
"strings"
)
// SecurityHeadersConfig controls which security headers are set.
//
// All fields have sensible defaults applied by SecurityHeaders; the zero
// value yields the hardened default policy.
type SecurityHeadersConfig struct {
	// FrameOptions controls X-Frame-Options. Default: "DENY".
	FrameOptions string
	// HSTSMaxAge controls the max-age for Strict-Transport-Security.
	// 0 selects the default of 31536000 (1 year); set a NEGATIVE value to
	// disable HSTS entirely (useful for local, non-TLS development).
	HSTSMaxAge int
	// HSTSPreloadDomains adds "includeSubDomains; preload" to HSTS.
	HSTSPreloadDomains bool
	// ReferrerPolicy controls the Referrer-Policy header. Default: "no-referrer".
	ReferrerPolicy string
	// CSP controls Content-Security-Policy. An empty value selects the
	// restrictive default policy from defaultCSP; CSP is always sent.
	CSP string
}

// SecurityHeaders returns middleware that sets standard HTTP security headers
// on every response. Header values are computed once at construction time,
// not per request.
func SecurityHeaders(cfg SecurityHeadersConfig) func(http.Handler) http.Handler {
	frameOpts := cfg.FrameOptions
	if frameOpts == "" {
		frameOpts = "DENY"
	}
	hstsAge := cfg.HSTSMaxAge
	if hstsAge == 0 {
		hstsAge = 31536000 // 1 year
	}
	refPol := cfg.ReferrerPolicy
	if refPol == "" {
		refPol = "no-referrer"
	}
	csp := cfg.CSP
	if csp == "" {
		csp = defaultCSP()
	}
	hstsValue := "max-age=" + strconv.Itoa(hstsAge)
	if cfg.HSTSPreloadDomains {
		hstsValue += "; includeSubDomains; preload"
	}
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("X-Content-Type-Options", "nosniff")
			w.Header().Set("X-Frame-Options", frameOpts)
			w.Header().Set("Referrer-Policy", refPol)
			w.Header().Set("Permissions-Policy", "camera=(), microphone=(), geolocation=()")
			w.Header().Set("Content-Security-Policy", csp)
			// A negative configured max-age leaves HSTS off entirely.
			if hstsAge > 0 {
				w.Header().Set("Strict-Transport-Security", hstsValue)
			}
			next.ServeHTTP(w, r)
		})
	}
}

// defaultCSP returns a restrictive Content-Security-Policy for the
// metasearch engine: same-origin everything, inline styles allowed for the
// templates, remote https images allowed for result thumbnails.
func defaultCSP() string {
	return strings.Join([]string{
		"default-src 'self'",
		"script-src 'self'",
		"style-src 'self' 'unsafe-inline'",
		"img-src 'self' https: data:",
		"connect-src 'self'",
		"font-src 'self'",
		"frame-ancestors 'none'",
		"base-uri 'self'",
		"form-action 'self'",
	}, "; ")
}

View file

@ -26,6 +26,28 @@ import (
var languageCodeRe = regexp.MustCompile(`^[a-z]{2,3}(-[a-zA-Z]{2})?$`) var languageCodeRe = regexp.MustCompile(`^[a-z]{2,3}(-[a-zA-Z]{2})?$`)
// maxQueryLength is the maximum allowed length, in bytes, of the search query.
const maxQueryLength = 1024

// knownEngineNames is the allowlist of valid engine identifiers.
var knownEngineNames = map[string]bool{
	"wikipedia": true, "arxiv": true, "crossref": true,
	"braveapi": true, "brave": true, "qwant": true,
	"duckduckgo": true, "github": true, "reddit": true,
	"bing": true, "google": true, "youtube": true,
}

// validateEngines filters engine names against the known registry,
// normalizing each name to lowercase. Unknown names are silently dropped
// rather than rejected, so a partially valid request still runs.
func validateEngines(engines []string) []string {
	out := make([]string, 0, len(engines))
	for _, e := range engines {
		// Lowercase once and reuse for both the lookup and the result.
		name := strings.ToLower(e)
		if knownEngineNames[name] {
			out = append(out, name)
		}
	}
	return out
}
func ParseSearchRequest(r *http.Request) (SearchRequest, error) { func ParseSearchRequest(r *http.Request) (SearchRequest, error) {
// Supports both GET and POST and relies on form values for routing. // Supports both GET and POST and relies on form values for routing.
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
@ -50,6 +72,9 @@ func ParseSearchRequest(r *http.Request) (SearchRequest, error) {
if strings.TrimSpace(q) == "" { if strings.TrimSpace(q) == "" {
return SearchRequest{}, errors.New("missing required parameter: q") return SearchRequest{}, errors.New("missing required parameter: q")
} }
if len(q) > maxQueryLength {
return SearchRequest{}, errors.New("query exceeds maximum length")
}
pageno := 1 pageno := 1
if s := strings.TrimSpace(r.FormValue("pageno")); s != "" { if s := strings.TrimSpace(r.FormValue("pageno")); s != "" {
@ -105,6 +130,8 @@ func ParseSearchRequest(r *http.Request) (SearchRequest, error) {
// engines is an explicit list of engine names. // engines is an explicit list of engine names.
engines := splitCSV(strings.TrimSpace(r.FormValue("engines"))) engines := splitCSV(strings.TrimSpace(r.FormValue("engines")))
// Validate engine names against known registry to prevent injection.
engines = validateEngines(engines)
// categories and category_<name> params mirror the webadapter parsing. // categories and category_<name> params mirror the webadapter parsing.
// We don't validate against a registry here; we just preserve the requested values. // We don't validate against a registry here; we just preserve the requested values.

View file

@ -44,6 +44,9 @@ func NewClient(baseURL string, timeout time.Duration) (*Client, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid upstream base URL: %w", err) return nil, fmt.Errorf("invalid upstream base URL: %w", err)
} }
if u.Scheme != "http" && u.Scheme != "https" {
return nil, fmt.Errorf("upstream URL must use http or https, got %q", u.Scheme)
}
// Normalize: trim trailing slash to make URL concatenation predictable. // Normalize: trim trailing slash to make URL concatenation predictable.
base := strings.TrimRight(u.String(), "/") base := strings.TrimRight(u.String(), "/")
@ -108,7 +111,7 @@ func (c *Client) SearchJSON(ctx context.Context, req contracts.SearchRequest, en
} }
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return contracts.SearchResponse{}, fmt.Errorf("upstream search failed: status=%d body=%q", resp.StatusCode, string(body)) return contracts.SearchResponse{}, fmt.Errorf("upstream search failed with status %d", resp.StatusCode)
} }
// Decode upstream JSON into our contract types. // Decode upstream JSON into our contract types.

127
internal/util/validate.go Normal file
View file

@ -0,0 +1,127 @@
// kafka — a privacy-respecting metasearch engine
// Copyright (C) 2026-present metamorphosis-dev
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
package util
import (
"fmt"
"net"
"net/url"
"strings"
)
// SafeURLScheme validates that a URL is well-formed and uses an acceptable
// scheme (http or https). It returns the parsed URL on success, or an error.
func SafeURLScheme(raw string) (*url.URL, error) {
	parsed, err := url.Parse(raw)
	if err != nil {
		return nil, err
	}
	switch parsed.Scheme {
	case "http", "https":
		return parsed, nil
	default:
		return nil, fmt.Errorf("URL must use http or https, got %q", parsed.Scheme)
	}
}
// privateNets holds the private, loopback, link-local, and otherwise
// non-routable ranges checked by isPrivateIPAddr. Built once at package
// init instead of being re-parsed on every call.
var privateNets = []*net.IPNet{
	// Loopback
	mustParseCIDR("127.0.0.0/8"),
	mustParseCIDR("::1/128"),
	// RFC 1918
	mustParseCIDR("10.0.0.0/8"),
	mustParseCIDR("172.16.0.0/12"),
	mustParseCIDR("192.168.0.0/16"),
	// RFC 6598 (Carrier-grade NAT)
	mustParseCIDR("100.64.0.0/10"),
	// Link-local
	mustParseCIDR("169.254.0.0/16"),
	mustParseCIDR("fe80::/10"),
	// IPv6 unique local
	mustParseCIDR("fc00::/7"),
	// IPv4-mapped IPv6 loopback
	mustParseCIDR("::ffff:127.0.0.0/104"),
}

// IsPrivateIP returns true if host — an IP literal or hostname, optionally
// with a port — refers to a private, loopback, link-local, or otherwise
// non-routable address. Hostnames are resolved; if resolution fails the
// host is treated as private (fail closed).
func IsPrivateIP(host string) bool {
	// Strip port if present.
	h, _, err := net.SplitHostPort(host)
	if err != nil {
		h = host
	}
	// Resolve hostname to IPs. IP literals short-circuit without DNS.
	ips, err := net.LookupIP(h)
	if err != nil || len(ips) == 0 {
		// If we can't resolve, reject to be safe.
		return true
	}
	for _, ip := range ips {
		if isPrivateIPAddr(ip) {
			return true
		}
	}
	return false
}

// isPrivateIPAddr reports whether ip falls within any known non-routable range.
func isPrivateIPAddr(ip net.IP) bool {
	for _, network := range privateNets {
		if network.Contains(ip) {
			return true
		}
	}
	return false
}

// mustParseCIDR parses a CIDR literal or panics; used only for the static
// privateNets table above, where a bad literal is a programming error.
func mustParseCIDR(s string) *net.IPNet {
	_, network, err := net.ParseCIDR(s)
	if err != nil {
		panic(fmt.Sprintf("validate: invalid CIDR %q: %v", s, err))
	}
	return network
}
// ValidatePublicURL checks that raw is a well-formed http or https URL with
// a non-empty host that does not point at a private/reserved IP range.
func ValidatePublicURL(raw string) error {
	parsed, err := url.Parse(raw)
	if err != nil {
		return fmt.Errorf("invalid URL: %w", err)
	}
	switch {
	case parsed.Scheme != "http" && parsed.Scheme != "https":
		return fmt.Errorf("URL must use http or https, got %q", parsed.Scheme)
	case parsed.Host == "":
		return fmt.Errorf("URL must have a host")
	case IsPrivateIP(parsed.Host):
		return fmt.Errorf("URL points to a private or reserved address: %s", parsed.Host)
	}
	return nil
}
// SanitizeResultURL ensures a URL is safe for rendering in an href attribute.
// It returns raw unchanged for http, https, and scheme-less (relative) URLs,
// and "" for javascript:, data:, vbscript: and every other scheme, as well
// as for unparseable input.
func SanitizeResultURL(raw string) string {
	if raw == "" {
		return ""
	}
	parsed, err := url.Parse(raw)
	if err != nil {
		return ""
	}
	// Scheme comparison is case-insensitive ("JavaScript:" must not slip by).
	if s := strings.ToLower(parsed.Scheme); s == "http" || s == "https" || s == "" {
		return raw
	}
	return ""
}

View file

@ -18,6 +18,7 @@ package views
import ( import (
"embed" "embed"
"encoding/xml"
"html/template" "html/template"
"io/fs" "io/fs"
"net/http" "net/http"
@ -25,6 +26,7 @@ import (
"strings" "strings"
"github.com/metamorphosis-dev/kafka/internal/contracts" "github.com/metamorphosis-dev/kafka/internal/contracts"
"github.com/metamorphosis-dev/kafka/internal/util"
) )
//go:embed all:templates //go:embed all:templates
@ -122,15 +124,20 @@ func StaticFS() (fs.FS, error) {
return fs.Sub(staticFS, "static") return fs.Sub(staticFS, "static")
} }
// OpenSearchXML returns the OpenSearch description XML with {baseUrl} // OpenSearchXML returns the OpenSearch description XML with the base URL
// replaced by the provided base URL. // safely embedded via xml.EscapeText (no raw string interpolation).
func OpenSearchXML(baseURL string) ([]byte, error) { func OpenSearchXML(baseURL string) ([]byte, error) {
tmplFS, _ := fs.Sub(templatesFS, "templates") tmplFS, _ := fs.Sub(templatesFS, "templates")
data, err := fs.ReadFile(tmplFS, "opensearch.xml") data, err := fs.ReadFile(tmplFS, "opensearch.xml")
if err != nil { if err != nil {
return nil, err return nil, err
} }
result := strings.ReplaceAll(string(data), "{baseUrl}", baseURL)
var buf strings.Builder
xml.Escape(&buf, []byte(baseURL))
escapedBaseURL := buf.String()
result := strings.ReplaceAll(string(data), "{baseUrl}", escapedBaseURL)
return []byte(result), nil return []byte(result), nil
} }
@ -190,6 +197,12 @@ func FromResponse(resp contracts.SearchResponse, query string, pageno int, activ
if r.Template == "videos" { if r.Template == "videos" {
tmplName = "video_item" tmplName = "video_item"
} }
// Sanitize URLs to prevent javascript:/data: scheme injection.
if r.URL != nil {
safe := util.SanitizeResultURL(*r.URL)
r.URL = &safe
}
r.Thumbnail = util.SanitizeResultURL(r.Thumbnail)
pd.Results[i] = ResultView{MainResult: r, TemplateName: tmplName} pd.Results[i] = ResultView{MainResult: r, TemplateName: tmplName}
} }
@ -213,7 +226,7 @@ func FromResponse(resp contracts.SearchResponse, query string, pageno int, activ
iv.Title = v iv.Title = v
} }
if v, ok := ib["img_src"].(string); ok { if v, ok := ib["img_src"].(string); ok {
iv.ImgSrc = v iv.ImgSrc = util.SanitizeResultURL(v)
} }
if iv.Title != "" || iv.Content != "" { if iv.Title != "" || iv.Content != "" {
pd.Infoboxes = append(pd.Infoboxes, iv) pd.Infoboxes = append(pd.Infoboxes, iv)