diff --git a/cmd/kafka/main.go b/cmd/kafka/main.go index 3a0a80e..f691665 100644 --- a/cmd/kafka/main.go +++ b/cmd/kafka/main.go @@ -95,8 +95,9 @@ func main() { var subFS fs.FS = staticFS mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.FS(subFS)))) - // Apply middleware: global rate limit → burst rate limit → per-IP rate limit → CORS → handler. + // Apply middleware: global rate limit → burst rate limit → per-IP rate limit → CORS → security headers → handler. var handler http.Handler = mux + handler = middleware.SecurityHeaders(middleware.SecurityHeadersConfig{})(handler) handler = middleware.CORS(middleware.CORSConfig{ AllowedOrigins: cfg.CORS.AllowedOrigins, AllowedMethods: cfg.CORS.AllowedMethods, @@ -108,6 +109,7 @@ func main() { Requests: cfg.RateLimit.Requests, Window: cfg.RateLimitWindow(), CleanupInterval: cfg.RateLimitCleanupInterval(), + TrustedProxies: cfg.RateLimit.TrustedProxies, }, logger)(handler) handler = middleware.GlobalRateLimit(middleware.GlobalRateLimitConfig{ Requests: cfg.GlobalRateLimit.Requests, diff --git a/internal/config/config.go b/internal/config/config.go index e5d1fbb..f5a8b9a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -23,6 +23,7 @@ import ( "time" "github.com/BurntSushi/toml" + "github.com/metamorphosis-dev/kafka/internal/util" ) // Config is the top-level configuration for the kafka service. @@ -77,6 +78,7 @@ type RateLimitConfig struct { Requests int `toml:"requests"` // Max requests per window (default: 30) Window string `toml:"window"` // Time window (e.g. "1m", default: "1m") CleanupInterval string `toml:"cleanup_interval"` // Stale entry cleanup interval (default: "5m") + TrustedProxies []string `toml:"trusted_proxies"` // CIDRs allowed to set X-Forwarded-For } // GlobalRateLimitConfig holds server-wide rate limiting settings. 
@@ -120,9 +122,36 @@ func Load(path string) (*Config, error) { } applyEnvOverrides(cfg) + + if err := validateConfig(cfg); err != nil { + return nil, fmt.Errorf("invalid configuration: %w", err) + } + return cfg, nil } +// validateConfig checks security-critical config values at startup. +func validateConfig(cfg *Config) error { + if cfg.Server.BaseURL != "" { + if err := util.ValidatePublicURL(cfg.Server.BaseURL); err != nil { + return fmt.Errorf("server.base_url: %w", err) + } + } + if cfg.Server.SourceURL != "" { + if err := util.ValidatePublicURL(cfg.Server.SourceURL); err != nil { + return fmt.Errorf("server.source_url: %w", err) + } + } + if cfg.Upstream.URL != "" { + // Validate scheme and well-formedness, but allow private IPs + // since self-hosted deployments commonly use localhost/internal addresses. + if _, err := util.SafeURLScheme(cfg.Upstream.URL); err != nil { + return fmt.Errorf("upstream.url: %w", err) + } + } + return nil +} + func defaultConfig() *Config { return &Config{ Server: ServerConfig{ diff --git a/internal/engines/arxiv.go b/internal/engines/arxiv.go index 1347562..2f9cca0 100644 --- a/internal/engines/arxiv.go +++ b/internal/engines/arxiv.go @@ -75,8 +75,8 @@ func (e *ArxivEngine) Search(ctx context.Context, req contracts.SearchRequest) ( defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024)) - return contracts.SearchResponse{}, fmt.Errorf("arxiv upstream error: status=%d body=%q", resp.StatusCode, string(body)) + io.Copy(io.Discard, io.LimitReader(resp.Body, 16*1024)) + return contracts.SearchResponse{}, fmt.Errorf("arxiv upstream error: status %d", resp.StatusCode) } raw, err := io.ReadAll(resp.Body) diff --git a/internal/engines/bing.go b/internal/engines/bing.go index 85c3f65..3b18f7b 100644 --- a/internal/engines/bing.go +++ b/internal/engines/bing.go @@ -68,8 +68,8 @@ func (e *BingEngine) Search(ctx context.Context, req contracts.SearchRequest) (c 
defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) - return contracts.SearchResponse{}, fmt.Errorf("bing upstream error: status=%d body=%q", resp.StatusCode, string(body)) + io.Copy(io.Discard, io.LimitReader(resp.Body, 4096)) + return contracts.SearchResponse{}, fmt.Errorf("bing upstream error: status %d", resp.StatusCode) } contentType := resp.Header.Get("Content-Type") diff --git a/internal/engines/brave.go b/internal/engines/brave.go index cb9313d..da25630 100644 --- a/internal/engines/brave.go +++ b/internal/engines/brave.go @@ -45,8 +45,8 @@ func (e *BraveEngine) Search(ctx context.Context, req contracts.SearchRequest) ( defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) - return contracts.SearchResponse{}, fmt.Errorf("brave error: status=%d body=%q", resp.StatusCode, string(body)) + io.Copy(io.Discard, io.LimitReader(resp.Body, 4096)) + return contracts.SearchResponse{}, fmt.Errorf("brave error: status %d", resp.StatusCode) } body, err := io.ReadAll(io.LimitReader(resp.Body, 128*1024)) diff --git a/internal/engines/braveapi.go b/internal/engines/braveapi.go index 1ae6220..830b010 100644 --- a/internal/engines/braveapi.go +++ b/internal/engines/braveapi.go @@ -127,8 +127,8 @@ func (e *BraveAPIEngine) Search(ctx context.Context, req contracts.SearchRequest defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024)) - return contracts.SearchResponse{}, fmt.Errorf("brave upstream error: status=%d body=%q", resp.StatusCode, string(body)) + io.Copy(io.Discard, io.LimitReader(resp.Body, 16*1024)) + return contracts.SearchResponse{}, fmt.Errorf("brave upstream error: status %d", resp.StatusCode) } var api struct { diff --git a/internal/engines/braveapi_test.go b/internal/engines/braveapi_test.go index 13c7420..ed710ff 100644 --- 
a/internal/engines/braveapi_test.go +++ b/internal/engines/braveapi_test.go @@ -39,7 +39,7 @@ func TestBraveEngine_GatingAndHeader(t *testing.T) { }) client := &http.Client{Transport: transport} - engine := &BraveEngine{ + engine := &BraveAPIEngine{ client: client, apiKey: wantAPIKey, accessGateToken: wantToken, diff --git a/internal/engines/crossref.go b/internal/engines/crossref.go index cc33759..79e6ab5 100644 --- a/internal/engines/crossref.go +++ b/internal/engines/crossref.go @@ -63,8 +63,8 @@ func (e *CrossrefEngine) Search(ctx context.Context, req contracts.SearchRequest defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024)) - return contracts.SearchResponse{}, fmt.Errorf("crossref upstream error: status=%d body=%q", resp.StatusCode, string(body)) + io.Copy(io.Discard, io.LimitReader(resp.Body, 16*1024)) + return contracts.SearchResponse{}, fmt.Errorf("crossref upstream error: status %d", resp.StatusCode) } var api struct { diff --git a/internal/engines/duckduckgo.go b/internal/engines/duckduckgo.go index 158d483..7a71ef4 100644 --- a/internal/engines/duckduckgo.go +++ b/internal/engines/duckduckgo.go @@ -63,8 +63,8 @@ func (e *DuckDuckGoEngine) Search(ctx context.Context, req contracts.SearchReque defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) - return contracts.SearchResponse{}, fmt.Errorf("duckduckgo upstream error: status=%d body=%q", resp.StatusCode, string(body)) + io.Copy(io.Discard, io.LimitReader(resp.Body, 4096)) + return contracts.SearchResponse{}, fmt.Errorf("duckduckgo upstream error: status %d", resp.StatusCode) } results, err := parseDuckDuckGoHTML(resp.Body) diff --git a/internal/engines/github.go b/internal/engines/github.go index f37cddc..d0c9fcc 100644 --- a/internal/engines/github.go +++ b/internal/engines/github.go @@ -66,8 +66,8 @@ func (e *GitHubEngine) Search(ctx context.Context, 
req contracts.SearchRequest) defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) - return contracts.SearchResponse{}, fmt.Errorf("github api error: status=%d body=%q", resp.StatusCode, string(body)) + io.Copy(io.Discard, io.LimitReader(resp.Body, 4096)) + return contracts.SearchResponse{}, fmt.Errorf("github api error: status %d", resp.StatusCode) } var data struct { diff --git a/internal/engines/google.go b/internal/engines/google.go index 77d8549..cea4bd5 100644 --- a/internal/engines/google.go +++ b/internal/engines/google.go @@ -28,20 +28,10 @@ import ( "github.com/metamorphosis-dev/kafka/internal/contracts" ) -// GSA User-Agent pool — these are Google Search Appliance identifiers -// that Google trusts for enterprise search appliance traffic. -var gsaUserAgents = []string{ - "Mozilla/5.0 (iPhone; CPU iPhone OS 17_5_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) GSA/399.2.845414227 Mobile/15E148 Safari/604.1", - "Mozilla/5.0 (iPhone; CPU iPhone OS 17_6_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) GSA/406.0.862495628 Mobile/15E148 Safari/604.1", - "Mozilla/5.0 (iPhone; CPU iPhone OS 17_7_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) GSA/406.0.862495628 Mobile/15E148 Safari/604.1", - "Mozilla/5.0 (iPhone; CPU iPhone OS 18_0_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) GSA/406.0.862495628 Mobile/15E148 Safari/604.1", - "Mozilla/5.0 (iPhone; CPU iPhone OS 18_1_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) GSA/399.2.845414227 Mobile/15E148 Safari/604.1", - "Mozilla/5.0 (iPhone; CPU iPhone OS 18_5_0 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) GSA/406.0.862495628 Mobile/15E148 Safari/604.1", -} - -func gsaUA() string { - return gsaUserAgents[0] // deterministic for now; could rotate -} +// googleUserAgent is an honest User-Agent identifying the metasearch engine. 
+// Using a spoofed GSA User-Agent violates Google's Terms of Service and +// risks permanent IP blocking. +var googleUserAgent = "Kafka/0.1 (compatible; +https://github.com/metamorphosis-dev/kafka)" type GoogleEngine struct { client *http.Client @@ -70,7 +60,7 @@ func (e *GoogleEngine) Search(ctx context.Context, req contracts.SearchRequest) if err != nil { return contracts.SearchResponse{}, err } - httpReq.Header.Set("User-Agent", gsaUA()) + httpReq.Header.Set("User-Agent", googleUserAgent) httpReq.Header.Set("Accept", "*/*") httpReq.AddCookie(&http.Cookie{Name: "CONSENT", Value: "YES+"}) @@ -95,8 +85,8 @@ func (e *GoogleEngine) Search(ctx context.Context, req contracts.SearchRequest) } if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) - return contracts.SearchResponse{}, fmt.Errorf("google error: status=%d body=%q", resp.StatusCode, string(body)) + io.Copy(io.Discard, io.LimitReader(resp.Body, 4096)) + return contracts.SearchResponse{}, fmt.Errorf("google error: status %d", resp.StatusCode) } body, err := io.ReadAll(io.LimitReader(resp.Body, 128*1024)) diff --git a/internal/engines/qwant.go b/internal/engines/qwant.go index e15d4f2..7fa963b 100644 --- a/internal/engines/qwant.go +++ b/internal/engines/qwant.go @@ -124,8 +124,8 @@ func (e *QwantEngine) searchWebAPI(ctx context.Context, req contracts.SearchRequ } if resp.StatusCode < 200 || resp.StatusCode >= 300 { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024)) - return contracts.SearchResponse{}, fmt.Errorf("qwant upstream error: status=%d body=%q", resp.StatusCode, string(body)) + io.Copy(io.Discard, io.LimitReader(resp.Body, 16*1024)) + return contracts.SearchResponse{}, fmt.Errorf("qwant upstream error: status %d", resp.StatusCode) } body, err := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024)) @@ -253,8 +253,8 @@ func (e *QwantEngine) searchWebLite(ctx context.Context, req contracts.SearchReq defer resp.Body.Close() if resp.StatusCode < 200 || 
resp.StatusCode >= 300 { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024)) - return contracts.SearchResponse{}, fmt.Errorf("qwant lite upstream error: status=%d body=%q", resp.StatusCode, string(body)) + io.Copy(io.Discard, io.LimitReader(resp.Body, 16*1024)) + return contracts.SearchResponse{}, fmt.Errorf("qwant lite upstream error: status %d", resp.StatusCode) } doc, err := goquery.NewDocumentFromReader(resp.Body) diff --git a/internal/engines/reddit.go b/internal/engines/reddit.go index 788f52a..699e7b2 100644 --- a/internal/engines/reddit.go +++ b/internal/engines/reddit.go @@ -62,8 +62,8 @@ func (e *RedditEngine) Search(ctx context.Context, req contracts.SearchRequest) defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) - return contracts.SearchResponse{}, fmt.Errorf("reddit api error: status=%d body=%q", resp.StatusCode, string(body)) + io.Copy(io.Discard, io.LimitReader(resp.Body, 4096)) + return contracts.SearchResponse{}, fmt.Errorf("reddit api error: status %d", resp.StatusCode) } var data struct { diff --git a/internal/engines/wikipedia.go b/internal/engines/wikipedia.go index f29ff74..518d994 100644 --- a/internal/engines/wikipedia.go +++ b/internal/engines/wikipedia.go @@ -134,8 +134,8 @@ func (e *WikipediaEngine) Search(ctx context.Context, req contracts.SearchReques }, nil } if resp.StatusCode < 200 || resp.StatusCode >= 300 { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024)) - return contracts.SearchResponse{}, fmt.Errorf("wikipedia upstream error: status=%d body=%q", resp.StatusCode, string(body)) + io.Copy(io.Discard, io.LimitReader(resp.Body, 16*1024)) + return contracts.SearchResponse{}, fmt.Errorf("wikipedia upstream error: status %d", resp.StatusCode) } var api struct { diff --git a/internal/engines/youtube.go b/internal/engines/youtube.go index 5946aa4..0c5ff9e 100644 --- a/internal/engines/youtube.go +++ b/internal/engines/youtube.go @@ -77,8 +77,8 @@ func 
(e *YouTubeEngine) Search(ctx context.Context, req contracts.SearchRequest) defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) - return contracts.SearchResponse{}, fmt.Errorf("youtube api error: status=%d body=%q", resp.StatusCode, string(body)) + io.Copy(io.Discard, io.LimitReader(resp.Body, 4096)) + return contracts.SearchResponse{}, fmt.Errorf("youtube api error: status %d", resp.StatusCode) } var apiResp youtubeSearchResponse @@ -87,7 +87,7 @@ func (e *YouTubeEngine) Search(ctx context.Context, req contracts.SearchRequest) } if apiResp.Error != nil { - return contracts.SearchResponse{}, fmt.Errorf("youtube api error: %s", apiResp.Error.Message) + return contracts.SearchResponse{}, fmt.Errorf("youtube api error: code %d", apiResp.Error.Code) } results := make([]contracts.MainResult, 0, len(apiResp.Items)) diff --git a/internal/middleware/cors.go b/internal/middleware/cors.go index ee90ab0..d4ecf2a 100644 --- a/internal/middleware/cors.go +++ b/internal/middleware/cors.go @@ -42,7 +42,8 @@ type CORSConfig struct { func CORS(cfg CORSConfig) func(http.Handler) http.Handler { origins := cfg.AllowedOrigins if len(origins) == 0 { - origins = []string{"*"} + // Default: no CORS headers. Explicitly configure origins to enable. + origins = nil } methods := cfg.AllowedMethods @@ -70,6 +71,7 @@ func CORS(cfg CORSConfig) func(http.Handler) http.Handler { origin := r.Header.Get("Origin") // Determine the allowed origin for this request. + // If no origins are configured, CORS is disabled entirely — no headers are set. allowedOrigin := "" for _, o := range origins { if o == "*" { diff --git a/internal/middleware/ratelimit.go b/internal/middleware/ratelimit.go index 78774f2..6f662fd 100644 --- a/internal/middleware/ratelimit.go +++ b/internal/middleware/ratelimit.go @@ -27,10 +27,14 @@ import ( "log/slog" ) +// RateLimitConfig controls per-IP rate limiting. 
type RateLimitConfig struct { Requests int Window time.Duration CleanupInterval time.Duration + // TrustedProxies is a list of CIDR ranges that are allowed to set + // X-Forwarded-For / X-Real-IP. If empty, only r.RemoteAddr is used. + TrustedProxies []string } func RateLimit(cfg RateLimitConfig, logger *slog.Logger) func(http.Handler) http.Handler { @@ -53,18 +57,30 @@ func RateLimit(cfg RateLimitConfig, logger *slog.Logger) func(http.Handler) http logger = slog.Default() } + // Parse trusted proxy CIDRs. + var trustedNets []*net.IPNet + for _, cidr := range cfg.TrustedProxies { + _, network, err := net.ParseCIDR(cidr) + if err != nil { + logger.Warn("invalid trusted proxy CIDR, skipping", "cidr", cidr, "error", err) + continue + } + trustedNets = append(trustedNets, network) + } + limiter := &ipLimiter{ requests: requests, window: window, clients: make(map[string]*bucket), logger: logger, + trusted: trustedNets, } go limiter.cleanup(cleanup) return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ip := extractIP(r) + ip := limiter.extractIP(r) if !limiter.allow(ip) { retryAfter := int(limiter.window.Seconds()) @@ -92,6 +108,7 @@ type ipLimiter struct { clients map[string]*bucket mu sync.Mutex logger *slog.Logger + trusted []*net.IPNet } func (l *ipLimiter) allow(ip string) bool { @@ -129,18 +146,48 @@ func (l *ipLimiter) cleanup(interval time.Duration) { } } -func extractIP(r *http.Request) string { - if xff := r.Header.Get("X-Forwarded-For"); xff != "" { - parts := strings.SplitN(xff, ",", 2) - return strings.TrimSpace(parts[0]) - } - if rip := r.Header.Get("X-Real-IP"); rip != "" { - return strings.TrimSpace(rip) +// extractIP extracts the client IP from the request. +// If trusted proxy CIDRs are configured, X-Forwarded-For is only used when +// the direct connection comes from a trusted proxy. Otherwise, only RemoteAddr is used. 
+func (l *ipLimiter) extractIP(r *http.Request) string { + return extractIP(r, l.trusted...) +} + +func extractIP(r *http.Request, trusted ...*net.IPNet) string { + remoteIP, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + remoteIP = r.RemoteAddr } - host, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - return r.RemoteAddr + // Check if the direct connection is from a trusted proxy. + isTrusted := false + if len(trusted) > 0 { + ip := net.ParseIP(remoteIP) + if ip != nil { + for _, network := range trusted { + if network.Contains(ip) { + isTrusted = true + break + } + } + } } - return host + + if isTrusted { + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + parts := strings.SplitN(xff, ",", 2) + candidate := strings.TrimSpace(parts[0]) + if net.ParseIP(candidate) != nil { + return candidate + } + } + if rip := r.Header.Get("X-Real-IP"); rip != "" { + candidate := strings.TrimSpace(rip) + if net.ParseIP(candidate) != nil { + return candidate + } + } + } + + return remoteIP } diff --git a/internal/middleware/ratelimit_global.go b/internal/middleware/ratelimit_global.go index 538c435..0bd34c5 100644 --- a/internal/middleware/ratelimit_global.go +++ b/internal/middleware/ratelimit_global.go @@ -71,7 +71,7 @@ func GlobalRateLimit(cfg GlobalRateLimitConfig, logger *slog.Logger) func(http.H w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.WriteHeader(http.StatusServiceUnavailable) _, _ = w.Write([]byte("503 Service Unavailable — global rate limit exceeded\n")) - logger.Warn("global rate limit exceeded", "ip", extractIP(r)) + logger.Warn("global rate limit exceeded", "remote", r.RemoteAddr) return } diff --git a/internal/middleware/ratelimit_test.go b/internal/middleware/ratelimit_test.go index 987d014..8366e57 100644 --- a/internal/middleware/ratelimit_test.go +++ b/internal/middleware/ratelimit_test.go @@ -1,6 +1,7 @@ package middleware import ( + "net" "net/http" "net/http/httptest" "testing" @@ -93,8 +94,9 @@ func 
TestRateLimit_DifferentIPs(t *testing.T) { func TestRateLimit_XForwardedFor(t *testing.T) { h := RateLimit(RateLimitConfig{ - Requests: 1, - Window: 10 * time.Second, + Requests: 1, + Window: 10 * time.Second, + TrustedProxies: []string{"10.0.0.0/8"}, }, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) @@ -143,17 +145,27 @@ func TestRateLimit_WindowExpires(t *testing.T) { } func TestExtractIP(t *testing.T) { + // Trusted proxy: loopback + loopback := mustParseCIDR("127.0.0.0/8") + privateNet := mustParseCIDR("10.0.0.0/8") + tests := []struct { name string xff string realIP string remote string + trusted []*net.IPNet expected string }{ - {"xff", "203.0.113.50, 10.0.0.1", "", "10.0.0.1:1234", "203.0.113.50"}, - {"real_ip", "", "203.0.113.50", "10.0.0.1:1234", "203.0.113.50"}, - {"remote", "", "", "1.2.3.4:5678", "1.2.3.4"}, - {"xff_over_real", "203.0.113.50", "10.0.0.1", "10.0.0.1:1234", "203.0.113.50"}, + // No trusted proxies → always use RemoteAddr. + {"no_trusted_xff", "203.0.113.50, 10.0.0.1", "", "10.0.0.1:1234", nil, "10.0.0.1"}, + {"no_trusted_real", "", "203.0.113.50", "10.0.0.1:1234", nil, "10.0.0.1"}, + {"no_trusted_remote", "", "", "1.2.3.4:5678", nil, "1.2.3.4"}, + // Trusted proxy → XFF is respected. + {"trusted_xff", "203.0.113.50, 10.0.0.1", "", "10.0.0.1:1234", []*net.IPNet{privateNet}, "203.0.113.50"}, + {"trusted_real_ip", "", "203.0.113.50", "10.0.0.1:1234", []*net.IPNet{privateNet}, "203.0.113.50"}, + // Untrusted remote → XFF ignored even if present. 
+ {"untrusted_xff", "203.0.113.50, 10.0.0.1", "", "1.2.3.4:5678", []*net.IPNet{loopback}, "1.2.3.4"}, } for _, tt := range tests { @@ -167,9 +179,17 @@ func TestExtractIP(t *testing.T) { } req.RemoteAddr = tt.remote - if got := extractIP(req); got != tt.expected { + if got := extractIP(req, tt.trusted...); got != tt.expected { t.Errorf("extractIP() = %q, want %q", got, tt.expected) } }) } } + +func mustParseCIDR(s string) *net.IPNet { + _, network, err := net.ParseCIDR(s) + if err != nil { + panic(err) + } + return network +} diff --git a/internal/middleware/security.go b/internal/middleware/security.go new file mode 100644 index 0000000..09f3878 --- /dev/null +++ b/internal/middleware/security.go @@ -0,0 +1,92 @@ +// kafka — a privacy-respecting metasearch engine +// Copyright (C) 2026-present metamorphosis-dev +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. + +package middleware + +import ( + "net/http" + "strconv" + "strings" +) + +// SecurityHeadersConfig controls which security headers are set. +type SecurityHeadersConfig struct { + // FrameOptions controls X-Frame-Options. Default: "DENY". + FrameOptions string + // HSTSMaxAge controls the max-age for Strict-Transport-Security. + // Set to a negative value to disable HSTS (useful for local dev); zero applies the default of 31536000 (1 year). + HSTSMaxAge int + // HSTSPreloadDomains adds "includeSubDomains; preload" to HSTS. + HSTSPreloadDomains bool + // ReferrerPolicy controls the Referrer-Policy header. Default: "no-referrer". + ReferrerPolicy string + // CSP controls Content-Security-Policy. If empty, a restrictive + // default policy is applied. + CSP string +} + +// SecurityHeaders returns middleware that sets standard HTTP security headers +// on every response.
+func SecurityHeaders(cfg SecurityHeadersConfig) func(http.Handler) http.Handler { + frameOpts := cfg.FrameOptions + if frameOpts == "" { + frameOpts = "DENY" + } + + hstsAge := cfg.HSTSMaxAge + if hstsAge == 0 { + hstsAge = 31536000 // 1 year + } + + refPol := cfg.ReferrerPolicy + if refPol == "" { + refPol = "no-referrer" + } + + csp := cfg.CSP + if csp == "" { + csp = defaultCSP() + } + + hstsValue := "max-age=" + strconv.Itoa(hstsAge) + if cfg.HSTSPreloadDomains { + hstsValue += "; includeSubDomains; preload" + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("X-Frame-Options", frameOpts) + w.Header().Set("Referrer-Policy", refPol) + w.Header().Set("Permissions-Policy", "camera=(), microphone=(), geolocation=()") + w.Header().Set("Content-Security-Policy", csp) + + if hstsAge > 0 { + w.Header().Set("Strict-Transport-Security", hstsValue) + } + + next.ServeHTTP(w, r) + }) + } +} + +// defaultCSP returns a restrictive Content-Security-Policy for the +// metasearch engine. +func defaultCSP() string { + return strings.Join([]string{ + "default-src 'self'", + "script-src 'self'", + "style-src 'self' 'unsafe-inline'", + "img-src 'self' https: data:", + "connect-src 'self'", + "font-src 'self'", + "frame-ancestors 'none'", + "base-uri 'self'", + "form-action 'self'", + }, "; ") +} diff --git a/internal/search/request_params.go b/internal/search/request_params.go index baad193..2e477fb 100644 --- a/internal/search/request_params.go +++ b/internal/search/request_params.go @@ -26,6 +26,28 @@ import ( var languageCodeRe = regexp.MustCompile(`^[a-z]{2,3}(-[a-zA-Z]{2})?$`) +// maxQueryLength is the maximum allowed length for the search query. +const maxQueryLength = 1024 + +// knownEngineNames is the allowlist of valid engine identifiers. 
+var knownEngineNames = map[string]bool{ + "wikipedia": true, "arxiv": true, "crossref": true, + "braveapi": true, "brave": true, "qwant": true, + "duckduckgo": true, "github": true, "reddit": true, + "bing": true, "google": true, "youtube": true, +} + +// validateEngines filters engine names against the known registry. +func validateEngines(engines []string) []string { + out := make([]string, 0, len(engines)) + for _, e := range engines { + if knownEngineNames[strings.ToLower(e)] { + out = append(out, strings.ToLower(e)) + } + } + return out +} + func ParseSearchRequest(r *http.Request) (SearchRequest, error) { // Supports both GET and POST and relies on form values for routing. if err := r.ParseForm(); err != nil { @@ -50,6 +72,9 @@ func ParseSearchRequest(r *http.Request) (SearchRequest, error) { if strings.TrimSpace(q) == "" { return SearchRequest{}, errors.New("missing required parameter: q") } + if len(q) > maxQueryLength { + return SearchRequest{}, errors.New("query exceeds maximum length") + } pageno := 1 if s := strings.TrimSpace(r.FormValue("pageno")); s != "" { @@ -105,6 +130,8 @@ func ParseSearchRequest(r *http.Request) (SearchRequest, error) { // engines is an explicit list of engine names. engines := splitCSV(strings.TrimSpace(r.FormValue("engines"))) + // Validate engine names against known registry to prevent injection. + engines = validateEngines(engines) // categories and category_ params mirror the webadapter parsing. // We don't validate against a registry here; we just preserve the requested values. 
diff --git a/internal/upstream/client.go b/internal/upstream/client.go index 2bff509..27d74b6 100644 --- a/internal/upstream/client.go +++ b/internal/upstream/client.go @@ -44,6 +44,9 @@ func NewClient(baseURL string, timeout time.Duration) (*Client, error) { if err != nil { return nil, fmt.Errorf("invalid upstream base URL: %w", err) } + if u.Scheme != "http" && u.Scheme != "https" { + return nil, fmt.Errorf("upstream URL must use http or https, got %q", u.Scheme) + } // Normalize: trim trailing slash to make URL concatenation predictable. base := strings.TrimRight(u.String(), "/") @@ -108,7 +111,7 @@ func (c *Client) SearchJSON(ctx context.Context, req contracts.SearchRequest, en } if resp.StatusCode != http.StatusOK { - return contracts.SearchResponse{}, fmt.Errorf("upstream search failed: status=%d body=%q", resp.StatusCode, string(body)) + return contracts.SearchResponse{}, fmt.Errorf("upstream search failed with status %d", resp.StatusCode) } // Decode upstream JSON into our contract types. diff --git a/internal/util/validate.go b/internal/util/validate.go new file mode 100644 index 0000000..eb7d55e --- /dev/null +++ b/internal/util/validate.go @@ -0,0 +1,127 @@ +// kafka — a privacy-respecting metasearch engine +// Copyright (C) 2026-present metamorphosis-dev +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. + +package util + +import ( + "fmt" + "net" + "net/url" + "strings" +) + +// SafeURLScheme validates that a URL is well-formed and uses an acceptable scheme. +// Returns the parsed URL on success, or an error.
+func SafeURLScheme(raw string) (*url.URL, error) { + u, err := url.Parse(raw) + if err != nil { + return nil, err + } + if u.Scheme != "http" && u.Scheme != "https" { + return nil, fmt.Errorf("URL must use http or https, got %q", u.Scheme) + } + return u, nil +} + +// IsPrivateIP returns true if the IP address is in a private, loopback, +// link-local, or otherwise non-routable range. +func IsPrivateIP(host string) bool { + // Strip port if present. + h, _, err := net.SplitHostPort(host) + if err != nil { + h = host + } + + // Resolve hostname to IPs. + ips, err := net.LookupIP(h) + if err != nil || len(ips) == 0 { + // If we can't resolve, reject to be safe. + return true + } + + for _, ip := range ips { + if isPrivateIPAddr(ip) { + return true + } + } + return false +} + +func isPrivateIPAddr(ip net.IP) bool { + privateRanges := []struct { + network *net.IPNet + }{ + // Loopback + {mustParseCIDR("127.0.0.0/8")}, + {mustParseCIDR("::1/128")}, + // RFC 1918 + {mustParseCIDR("10.0.0.0/8")}, + {mustParseCIDR("172.16.0.0/12")}, + {mustParseCIDR("192.168.0.0/16")}, + // RFC 6598 (Carrier-grade NAT) + {mustParseCIDR("100.64.0.0/10")}, + // Link-local + {mustParseCIDR("169.254.0.0/16")}, + {mustParseCIDR("fe80::/10")}, + // IPv6 unique local + {mustParseCIDR("fc00::/7")}, + // IPv4-mapped IPv6 loopback + {mustParseCIDR("::ffff:127.0.0.0/104")}, + } + + for _, r := range privateRanges { + if r.network.Contains(ip) { + return true + } + } + return false +} + +func mustParseCIDR(s string) *net.IPNet { + _, network, err := net.ParseCIDR(s) + if err != nil { + panic(fmt.Sprintf("validate: invalid CIDR %q: %v", s, err)) + } + return network +} + +// ValidatePublicURL checks that a URL is well-formed, uses http or https, +// and does not point to a private/reserved IP range. 
+func ValidatePublicURL(raw string) error { + u, err := url.Parse(raw) + if err != nil { + return fmt.Errorf("invalid URL: %w", err) + } + if u.Scheme != "http" && u.Scheme != "https" { + return fmt.Errorf("URL must use http or https, got %q", u.Scheme) + } + if u.Host == "" { + return fmt.Errorf("URL must have a host") + } + if IsPrivateIP(u.Host) { + return fmt.Errorf("URL points to a private or reserved address: %s", u.Host) + } + return nil +} + +// SanitizeResultURL ensures a URL is safe for rendering in an href attribute. +// It rejects javascript:, data:, vbscript: and other dangerous schemes. +func SanitizeResultURL(raw string) string { + if raw == "" { + return "" + } + u, err := url.Parse(raw) + if err != nil { + return "" + } + switch strings.ToLower(u.Scheme) { + case "http", "https", "": + return raw + default: + return "" + } +} diff --git a/internal/views/views.go b/internal/views/views.go index 6f23937..c176f81 100644 --- a/internal/views/views.go +++ b/internal/views/views.go @@ -18,6 +18,7 @@ package views import ( "embed" + "encoding/xml" "html/template" "io/fs" "net/http" @@ -25,6 +26,7 @@ import ( "strings" "github.com/metamorphosis-dev/kafka/internal/contracts" + "github.com/metamorphosis-dev/kafka/internal/util" ) //go:embed all:templates @@ -122,15 +124,20 @@ func StaticFS() (fs.FS, error) { return fs.Sub(staticFS, "static") } -// OpenSearchXML returns the OpenSearch description XML with {baseUrl} -// replaced by the provided base URL. +// OpenSearchXML returns the OpenSearch description XML with the base URL +// safely embedded via xml.EscapeText (no raw string interpolation). 
func OpenSearchXML(baseURL string) ([]byte, error) { tmplFS, _ := fs.Sub(templatesFS, "templates") data, err := fs.ReadFile(tmplFS, "opensearch.xml") if err != nil { return nil, err } - result := strings.ReplaceAll(string(data), "{baseUrl}", baseURL) + + var buf strings.Builder + xml.Escape(&buf, []byte(baseURL)) + escapedBaseURL := buf.String() + + result := strings.ReplaceAll(string(data), "{baseUrl}", escapedBaseURL) return []byte(result), nil } @@ -190,6 +197,12 @@ func FromResponse(resp contracts.SearchResponse, query string, pageno int, activ if r.Template == "videos" { tmplName = "video_item" } + // Sanitize URLs to prevent javascript:/data: scheme injection. + if r.URL != nil { + safe := util.SanitizeResultURL(*r.URL) + r.URL = &safe + } + r.Thumbnail = util.SanitizeResultURL(r.Thumbnail) pd.Results[i] = ResultView{MainResult: r, TemplateName: tmplName} } @@ -213,7 +226,7 @@ func FromResponse(resp contracts.SearchResponse, query string, pageno int, activ iv.Title = v } if v, ok := ib["img_src"].(string); ok { - iv.ImgSrc = v + iv.ImgSrc = util.SanitizeResultURL(v) } if iv.Title != "" || iv.Content != "" { pd.Infoboxes = append(pd.Infoboxes, iv)