security: harden against SAST findings (criticals through mediums)

Critical:
- Validate baseURL/sourceURL/upstreamURL at config load time
  (prevents XML injection, XSS, SSRF via config/env manipulation)
- Use xml.Escape for OpenSearch XML template interpolation

High:
- Add security headers middleware (CSP, X-Frame-Options, HSTS, etc.)
- Sanitize result URLs to reject javascript:/data: schemes
- Sanitize infobox img_src against dangerous URL schemes
- Default CORS to deny-all (was wildcard *)

Medium:
- Rate limiter: X-Forwarded-For only trusted from configured proxies
- Validate engine names against known registry allowlist
- Add 1024-char max query length
- Sanitize upstream error messages (strip raw response bodies)
- Upstream client validates URL scheme (http/https only)

Test updates:
- Update extractIP tests for new trusted proxy behavior
This commit is contained in:
Franz Kafka 2026-03-22 16:22:27 +00:00
parent 4b0cde91ed
commit da367a1bfd
23 changed files with 399 additions and 41 deletions

View file

@ -95,8 +95,9 @@ func main() {
var subFS fs.FS = staticFS
mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.FS(subFS))))
// Apply middleware: global rate limit → burst rate limit → per-IP rate limit → CORS → handler.
// Apply middleware: global rate limit → burst rate limit → per-IP rate limit → CORS → security headers → handler.
var handler http.Handler = mux
handler = middleware.SecurityHeaders(middleware.SecurityHeadersConfig{})(handler)
handler = middleware.CORS(middleware.CORSConfig{
AllowedOrigins: cfg.CORS.AllowedOrigins,
AllowedMethods: cfg.CORS.AllowedMethods,
@ -108,6 +109,7 @@ func main() {
Requests: cfg.RateLimit.Requests,
Window: cfg.RateLimitWindow(),
CleanupInterval: cfg.RateLimitCleanupInterval(),
TrustedProxies: cfg.RateLimit.TrustedProxies,
}, logger)(handler)
handler = middleware.GlobalRateLimit(middleware.GlobalRateLimitConfig{
Requests: cfg.GlobalRateLimit.Requests,

View file

@ -18,11 +18,13 @@ package config
import (
"fmt"
"log"
"os"
"strings"
"time"
"github.com/BurntSushi/toml"
"github.com/metamorphosis-dev/kafka/internal/util"
)
// Config is the top-level configuration for the kafka service.
@ -77,6 +79,7 @@ type RateLimitConfig struct {
Requests int `toml:"requests"` // Max requests per window (default: 30)
Window string `toml:"window"` // Time window (e.g. "1m", default: "1m")
CleanupInterval string `toml:"cleanup_interval"` // Stale entry cleanup interval (default: "5m")
TrustedProxies []string `toml:"trusted_proxies"` // CIDRs allowed to set X-Forwarded-For
}
// GlobalRateLimitConfig holds server-wide rate limiting settings.
@ -120,9 +123,35 @@ func Load(path string) (*Config, error) {
}
applyEnvOverrides(cfg)
if err := validateConfig(cfg); err != nil {
return nil, fmt.Errorf("invalid configuration: %w", err)
}
return cfg, nil
}
// validateConfig checks security-critical config values at startup.
func validateConfig(cfg *Config) error {
if cfg.Server.BaseURL != "" {
if err := util.ValidatePublicURL(cfg.Server.BaseURL); err != nil {
return fmt.Errorf("server.base_url: %w", err)
}
}
if cfg.Server.SourceURL != "" {
if err := util.ValidatePublicURL(cfg.Server.SourceURL); err != nil {
return fmt.Errorf("server.source_url: %w", err)
}
}
if cfg.Upstream.URL != "" {
if err := util.ValidatePublicURL(cfg.Upstream.URL); err != nil {
return fmt.Errorf("upstream.url: %w", err)
}
log.Printf("WARNING: upstream.url SSRF protection is enabled; ensure the upstream host is not on a private network")
}
return nil
}
func defaultConfig() *Config {
return &Config{
Server: ServerConfig{

View file

@ -76,7 +76,7 @@ func (e *ArxivEngine) Search(ctx context.Context, req contracts.SearchRequest) (
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024))
return contracts.SearchResponse{}, fmt.Errorf("arxiv upstream error: status=%d body=%q", resp.StatusCode, string(body))
return contracts.SearchResponse{}, fmt.Errorf("arxiv upstream error: status %d", resp.StatusCode)
}
raw, err := io.ReadAll(resp.Body)

View file

@ -69,7 +69,7 @@ func (e *BingEngine) Search(ctx context.Context, req contracts.SearchRequest) (c
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return contracts.SearchResponse{}, fmt.Errorf("bing upstream error: status=%d body=%q", resp.StatusCode, string(body))
return contracts.SearchResponse{}, fmt.Errorf("bing upstream error: status %d", resp.StatusCode)
}
contentType := resp.Header.Get("Content-Type")

View file

@ -46,7 +46,7 @@ func (e *BraveEngine) Search(ctx context.Context, req contracts.SearchRequest) (
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return contracts.SearchResponse{}, fmt.Errorf("brave error: status=%d body=%q", resp.StatusCode, string(body))
return contracts.SearchResponse{}, fmt.Errorf("brave error: status %d", resp.StatusCode)
}
body, err := io.ReadAll(io.LimitReader(resp.Body, 128*1024))

View file

@ -128,7 +128,7 @@ func (e *BraveAPIEngine) Search(ctx context.Context, req contracts.SearchRequest
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024))
return contracts.SearchResponse{}, fmt.Errorf("brave upstream error: status=%d body=%q", resp.StatusCode, string(body))
return contracts.SearchResponse{}, fmt.Errorf("brave upstream error: status %d", resp.StatusCode)
}
var api struct {

View file

@ -64,7 +64,7 @@ func (e *CrossrefEngine) Search(ctx context.Context, req contracts.SearchRequest
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024))
return contracts.SearchResponse{}, fmt.Errorf("crossref upstream error: status=%d body=%q", resp.StatusCode, string(body))
return contracts.SearchResponse{}, fmt.Errorf("crossref upstream error: status %d", resp.StatusCode)
}
var api struct {

View file

@ -64,7 +64,7 @@ func (e *DuckDuckGoEngine) Search(ctx context.Context, req contracts.SearchReque
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return contracts.SearchResponse{}, fmt.Errorf("duckduckgo upstream error: status=%d body=%q", resp.StatusCode, string(body))
return contracts.SearchResponse{}, fmt.Errorf("duckduckgo upstream error: status %d", resp.StatusCode)
}
results, err := parseDuckDuckGoHTML(resp.Body)

View file

@ -67,7 +67,7 @@ func (e *GitHubEngine) Search(ctx context.Context, req contracts.SearchRequest)
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return contracts.SearchResponse{}, fmt.Errorf("github api error: status=%d body=%q", resp.StatusCode, string(body))
return contracts.SearchResponse{}, fmt.Errorf("github api error: status %d", resp.StatusCode)
}
var data struct {

View file

@ -96,7 +96,7 @@ func (e *GoogleEngine) Search(ctx context.Context, req contracts.SearchRequest)
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return contracts.SearchResponse{}, fmt.Errorf("google error: status=%d body=%q", resp.StatusCode, string(body))
return contracts.SearchResponse{}, fmt.Errorf("google error: status %d", resp.StatusCode)
}
body, err := io.ReadAll(io.LimitReader(resp.Body, 128*1024))

View file

@ -125,7 +125,7 @@ func (e *QwantEngine) searchWebAPI(ctx context.Context, req contracts.SearchRequ
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024))
return contracts.SearchResponse{}, fmt.Errorf("qwant upstream error: status=%d body=%q", resp.StatusCode, string(body))
return contracts.SearchResponse{}, fmt.Errorf("qwant upstream error: status %d", resp.StatusCode)
}
body, err := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024))
@ -254,7 +254,7 @@ func (e *QwantEngine) searchWebLite(ctx context.Context, req contracts.SearchReq
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024))
return contracts.SearchResponse{}, fmt.Errorf("qwant lite upstream error: status=%d body=%q", resp.StatusCode, string(body))
return contracts.SearchResponse{}, fmt.Errorf("qwant lite upstream error: status %d", resp.StatusCode)
}
doc, err := goquery.NewDocumentFromReader(resp.Body)

View file

@ -63,7 +63,7 @@ func (e *RedditEngine) Search(ctx context.Context, req contracts.SearchRequest)
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return contracts.SearchResponse{}, fmt.Errorf("reddit api error: status=%d body=%q", resp.StatusCode, string(body))
return contracts.SearchResponse{}, fmt.Errorf("reddit api error: status %d", resp.StatusCode)
}
var data struct {

View file

@ -135,7 +135,7 @@ func (e *WikipediaEngine) Search(ctx context.Context, req contracts.SearchReques
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024))
return contracts.SearchResponse{}, fmt.Errorf("wikipedia upstream error: status=%d body=%q", resp.StatusCode, string(body))
return contracts.SearchResponse{}, fmt.Errorf("wikipedia upstream error: status %d", resp.StatusCode)
}
var api struct {

View file

@ -78,7 +78,7 @@ func (e *YouTubeEngine) Search(ctx context.Context, req contracts.SearchRequest)
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return contracts.SearchResponse{}, fmt.Errorf("youtube api error: status=%d body=%q", resp.StatusCode, string(body))
return contracts.SearchResponse{}, fmt.Errorf("youtube api error: status %d", resp.StatusCode)
}
var apiResp youtubeSearchResponse
@ -87,7 +87,7 @@ func (e *YouTubeEngine) Search(ctx context.Context, req contracts.SearchRequest)
}
if apiResp.Error != nil {
return contracts.SearchResponse{}, fmt.Errorf("youtube api error: %s", apiResp.Error.Message)
return contracts.SearchResponse{}, fmt.Errorf("youtube api error: code %d", apiResp.Error.Code)
}
results := make([]contracts.MainResult, 0, len(apiResp.Items))

View file

@ -42,7 +42,8 @@ type CORSConfig struct {
func CORS(cfg CORSConfig) func(http.Handler) http.Handler {
origins := cfg.AllowedOrigins
if len(origins) == 0 {
origins = []string{"*"}
// Default: no CORS headers. Explicitly configure origins to enable.
origins = nil
}
methods := cfg.AllowedMethods
@ -70,6 +71,7 @@ func CORS(cfg CORSConfig) func(http.Handler) http.Handler {
origin := r.Header.Get("Origin")
// Determine the allowed origin for this request.
// If no origins are configured, CORS is disabled entirely — no headers are set.
allowedOrigin := ""
for _, o := range origins {
if o == "*" {

View file

@ -27,10 +27,14 @@ import (
"log/slog"
)
// RateLimitConfig controls per-IP rate limiting.
type RateLimitConfig struct {
Requests int
Window time.Duration
CleanupInterval time.Duration
// TrustedProxies is a list of CIDR ranges that are allowed to set
// X-Forwarded-For / X-Real-IP. If empty, only r.RemoteAddr is used.
TrustedProxies []string
}
func RateLimit(cfg RateLimitConfig, logger *slog.Logger) func(http.Handler) http.Handler {
@ -53,18 +57,30 @@ func RateLimit(cfg RateLimitConfig, logger *slog.Logger) func(http.Handler) http
logger = slog.Default()
}
// Parse trusted proxy CIDRs.
var trustedNets []*net.IPNet
for _, cidr := range cfg.TrustedProxies {
_, network, err := net.ParseCIDR(cidr)
if err != nil {
logger.Warn("invalid trusted proxy CIDR, skipping", "cidr", cidr, "error", err)
continue
}
trustedNets = append(trustedNets, network)
}
limiter := &ipLimiter{
requests: requests,
window: window,
clients: make(map[string]*bucket),
logger: logger,
trusted: trustedNets,
}
go limiter.cleanup(cleanup)
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ip := extractIP(r)
ip := l.extractIP(r)
if !limiter.allow(ip) {
retryAfter := int(limiter.window.Seconds())
@ -92,6 +108,7 @@ type ipLimiter struct {
clients map[string]*bucket
mu sync.Mutex
logger *slog.Logger
trusted []*net.IPNet
}
func (l *ipLimiter) allow(ip string) bool {
@ -129,18 +146,48 @@ func (l *ipLimiter) cleanup(interval time.Duration) {
}
}
func extractIP(r *http.Request) string {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
parts := strings.SplitN(xff, ",", 2)
return strings.TrimSpace(parts[0])
}
if rip := r.Header.Get("X-Real-IP"); rip != "" {
return strings.TrimSpace(rip)
// extractIP extracts the client IP from the request.
// If trusted proxy CIDRs are configured, X-Forwarded-For is only used when
// the direct connection comes from a trusted proxy. Otherwise, only RemoteAddr is used.
func (l *ipLimiter) extractIP(r *http.Request) string {
return extractIP(r, l.trusted...)
}
func extractIP(r *http.Request, trusted ...*net.IPNet) string {
remoteIP, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
remoteIP = r.RemoteAddr
}
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
// Check if the direct connection is from a trusted proxy.
isTrusted := false
if len(trusted) > 0 {
ip := net.ParseIP(remoteIP)
if ip != nil {
for _, network := range trusted {
if network.Contains(ip) {
isTrusted = true
break
}
}
}
}
return host
if isTrusted {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
parts := strings.SplitN(xff, ",", 2)
candidate := strings.TrimSpace(parts[0])
if net.ParseIP(candidate) != nil {
return candidate
}
}
if rip := r.Header.Get("X-Real-IP"); rip != "" {
candidate := strings.TrimSpace(rip)
if net.ParseIP(candidate) != nil {
return candidate
}
}
}
return remoteIP
}

View file

@ -71,7 +71,7 @@ func GlobalRateLimit(cfg GlobalRateLimitConfig, logger *slog.Logger) func(http.H
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusServiceUnavailable)
_, _ = w.Write([]byte("503 Service Unavailable — global rate limit exceeded\n"))
logger.Warn("global rate limit exceeded", "ip", extractIP(r))
logger.Warn("global rate limit exceeded", "remote", r.RemoteAddr)
return
}

View file

@ -92,9 +92,11 @@ func TestRateLimit_DifferentIPs(t *testing.T) {
}
func TestRateLimit_XForwardedFor(t *testing.T) {
privateNet := mustParseCIDR("10.0.0.0/8")
h := RateLimit(RateLimitConfig{
Requests: 1,
Window: 10 * time.Second,
Requests: 1,
Window: 10 * time.Second,
TrustedProxies: []string{"10.0.0.0/8"},
}, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
@ -143,17 +145,27 @@ func TestRateLimit_WindowExpires(t *testing.T) {
}
func TestExtractIP(t *testing.T) {
// Trusted proxy: loopback
loopback := mustParseCIDR("127.0.0.0/8")
privateNet := mustParseCIDR("10.0.0.0/8")
tests := []struct {
name string
xff string
realIP string
remote string
trusted []*net.IPNet
expected string
}{
{"xff", "203.0.113.50, 10.0.0.1", "", "10.0.0.1:1234", "203.0.113.50"},
{"real_ip", "", "203.0.113.50", "10.0.0.1:1234", "203.0.113.50"},
{"remote", "", "", "1.2.3.4:5678", "1.2.3.4"},
{"xff_over_real", "203.0.113.50", "10.0.0.1", "10.0.0.1:1234", "203.0.113.50"},
// No trusted proxies → always use RemoteAddr.
{"no_trusted_xff", "203.0.113.50, 10.0.0.1", "", "10.0.0.1:1234", nil, "10.0.0.1"},
{"no_trusted_real", "", "203.0.113.50", "10.0.0.1:1234", nil, "10.0.0.1"},
{"no_trusted_remote", "", "", "1.2.3.4:5678", nil, "1.2.3.4"},
// Trusted proxy → XFF is respected.
{"trusted_xff", "203.0.113.50, 10.0.0.1", "", "10.0.0.1:1234", []*net.IPNet{privateNet}, "203.0.113.50"},
{"trusted_real_ip", "", "203.0.113.50", "10.0.0.1:1234", []*net.IPNet{privateNet}, "203.0.113.50"},
// Untrusted remote → XFF ignored even if present.
{"untrusted_xff", "203.0.113.50, 10.0.0.1", "", "1.2.3.4:5678", []*net.IPNet{loopback}, "1.2.3.4"},
}
for _, tt := range tests {
@ -167,9 +179,17 @@ func TestExtractIP(t *testing.T) {
}
req.RemoteAddr = tt.remote
if got := extractIP(req); got != tt.expected {
if got := extractIP(req, tt.trusted...); got != tt.expected {
t.Errorf("extractIP() = %q, want %q", got, tt.expected)
}
})
}
}
func mustParseCIDR(s string) *net.IPNet {
_, network, err := net.ParseCIDR(s)
if err != nil {
panic(err)
}
return network
}

View file

@ -0,0 +1,92 @@
// kafka — a privacy-respecting metasearch engine
// Copyright (C) 2026-present metamorphosis-dev
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
package middleware
import (
"net/http"
"strconv"
"strings"
)
// SecurityHeadersConfig controls which security headers are set.
type SecurityHeadersConfig struct {
// FrameOptions controls X-Frame-Options. Default: "DENY".
FrameOptions string
// HSTSMaxAge controls the max-age for Strict-Transport-Security.
// Set to 0 to disable HSTS (useful for local dev). Default: 31536000 (1 year).
HSTSMaxAge int
// HSTSPreloadDomains adds "includeSubDomains; preload" to HSTS.
HSTSPreloadDomains bool
// ReferrerPolicy controls the Referrer-Policy header. Default: "no-referrer".
ReferrerPolicy string
// CSP controls Content-Security-Policy. Default: a restrictive policy.
// Set to "" to disable CSP entirely.
CSP string
}
// SecurityHeaders returns middleware that sets standard HTTP security headers
// on every response.
func SecurityHeaders(cfg SecurityHeadersConfig) func(http.Handler) http.Handler {
frameOpts := cfg.FrameOptions
if frameOpts == "" {
frameOpts = "DENY"
}
hstsAge := cfg.HSTSMaxAge
if hstsAge == 0 {
hstsAge = 31536000 // 1 year
}
refPol := cfg.ReferrerPolicy
if refPol == "" {
refPol = "no-referrer"
}
csp := cfg.CSP
if csp == "" {
csp = defaultCSP()
}
hstsValue := "max-age=" + strconv.Itoa(hstsAge)
if cfg.HSTSPreloadDomains {
hstsValue += "; includeSubDomains; preload"
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("X-Frame-Options", frameOpts)
w.Header().Set("Referrer-Policy", refPol)
w.Header().Set("Permissions-Policy", "camera=(), microphone=(), geolocation=()")
w.Header().Set("Content-Security-Policy", csp)
if hstsAge > 0 {
w.Header().Set("Strict-Transport-Security", hstsValue)
}
next.ServeHTTP(w, r)
})
}
}
// defaultCSP returns a restrictive Content-Security-Policy for the
// metasearch engine.
func defaultCSP() string {
return strings.Join([]string{
"default-src 'self'",
"script-src 'self'",
"style-src 'self' 'unsafe-inline'",
"img-src 'self' https: data:",
"connect-src 'self'",
"font-src 'self'",
"frame-ancestors 'none'",
"base-uri 'self'",
"form-action 'self'",
}, "; ")
}

View file

@ -26,6 +26,28 @@ import (
var languageCodeRe = regexp.MustCompile(`^[a-z]{2,3}(-[a-zA-Z]{2})?$`)
// maxQueryLength is the maximum allowed length for the search query.
const maxQueryLength = 1024
// knownEngineNames is the allowlist of valid engine identifiers.
var knownEngineNames = map[string]bool{
"wikipedia": true, "arxiv": true, "crossref": true,
"braveapi": true, "brave": true, "qwant": true,
"duckduckgo": true, "github": true, "reddit": true,
"bing": true, "google": true, "youtube": true,
}
// validateEngines filters engine names against the known registry.
func validateEngines(engines []string) []string {
out := make([]string, 0, len(engines))
for _, e := range engines {
if knownEngineNames[strings.ToLower(e)] {
out = append(out, strings.ToLower(e))
}
}
return out
}
func ParseSearchRequest(r *http.Request) (SearchRequest, error) {
// Supports both GET and POST and relies on form values for routing.
if err := r.ParseForm(); err != nil {
@ -50,6 +72,9 @@ func ParseSearchRequest(r *http.Request) (SearchRequest, error) {
if strings.TrimSpace(q) == "" {
return SearchRequest{}, errors.New("missing required parameter: q")
}
if len(q) > maxQueryLength {
return SearchRequest{}, errors.New("query exceeds maximum length")
}
pageno := 1
if s := strings.TrimSpace(r.FormValue("pageno")); s != "" {
@ -105,6 +130,8 @@ func ParseSearchRequest(r *http.Request) (SearchRequest, error) {
// engines is an explicit list of engine names.
engines := splitCSV(strings.TrimSpace(r.FormValue("engines")))
// Validate engine names against known registry to prevent injection.
engines = validateEngines(engines)
// categories and category_<name> params mirror the webadapter parsing.
// We don't validate against a registry here; we just preserve the requested values.

View file

@ -44,6 +44,9 @@ func NewClient(baseURL string, timeout time.Duration) (*Client, error) {
if err != nil {
return nil, fmt.Errorf("invalid upstream base URL: %w", err)
}
if u.Scheme != "http" && u.Scheme != "https" {
return nil, fmt.Errorf("upstream URL must use http or https, got %q", u.Scheme)
}
// Normalize: trim trailing slash to make URL concatenation predictable.
base := strings.TrimRight(u.String(), "/")
@ -108,7 +111,7 @@ func (c *Client) SearchJSON(ctx context.Context, req contracts.SearchRequest, en
}
if resp.StatusCode != http.StatusOK {
return contracts.SearchResponse{}, fmt.Errorf("upstream search failed: status=%d body=%q", resp.StatusCode, string(body))
return contracts.SearchResponse{}, fmt.Errorf("upstream search failed with status %d", resp.StatusCode)
}
// Decode upstream JSON into our contract types.

123
internal/util/validate.go Normal file
View file

@ -0,0 +1,123 @@
// kafka — a privacy-respecting metasearch engine
// Copyright (C) 2026-present metamorphosis-dev
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License,// or (at your option) any later version.
package util
import (
"fmt"
"net"
"net/url"
"strings"
)
// SafeURLScheme returns true if the URL uses an acceptable scheme (http or https).
func SafeURLScheme(raw string) bool {
u, err := url.Parse(raw)
if err != nil {
return false
}
return u.Scheme == "http" || u.Scheme == "https"
}
// IsPrivateIP returns true if the IP address is in a private, loopback,
// link-local, or otherwise non-routable range.
func IsPrivateIP(host string) bool {
// Strip port if present.
h, _, err := net.SplitHostPort(host)
if err != nil {
h = host
}
// Resolve hostname to IPs.
ips, err := net.LookupIP(h)
if err != nil || len(ips) == 0 {
// If we can't resolve, reject to be safe.
return true
}
for _, ip := range ips {
if isPrivateIPAddr(ip) {
return true
}
}
return false
}
func isPrivateIPAddr(ip net.IP) bool {
privateRanges := []struct {
network *net.IPNet
}{
// Loopback
{mustParseCIDR("127.0.0.0/8")},
{mustParseCIDR("::1/128")},
// RFC 1918
{mustParseCIDR("10.0.0.0/8")},
{mustParseCIDR("172.16.0.0/12")},
{mustParseCIDR("192.168.0.0/16")},
// RFC 6598 (Carrier-grade NAT)
{mustParseCIDR("100.64.0.0/10")},
// Link-local
{mustParseCIDR("169.254.0.0/16")},
{mustParseCIDR("fe80::/10")},
// IPv6 unique local
{mustParseCIDR("fc00::/7")},
// IPv4-mapped IPv6 loopback
{mustParseCIDR("::ffff:127.0.0.0/104")},
}
for _, r := range privateRanges {
if r.network.Contains(ip) {
return true
}
}
return false
}
func mustParseCIDR(s string) *net.IPNet {
_, network, err := net.ParseCIDR(s)
if err != nil {
panic(fmt.Sprintf("validate: invalid CIDR %q: %v", s, err))
}
return network
}
// ValidatePublicURL checks that a URL is well-formed, uses http or https,
// and does not point to a private/reserved IP range.
func ValidatePublicURL(raw string) error {
u, err := url.Parse(raw)
if err != nil {
return fmt.Errorf("invalid URL: %w", err)
}
if u.Scheme != "http" && u.Scheme != "https" {
return fmt.Errorf("URL must use http or https, got %q", u.Scheme)
}
if u.Host == "" {
return fmt.Errorf("URL must have a host")
}
if IsPrivateIP(u.Host) {
return fmt.Errorf("URL points to a private or reserved address: %s", u.Host)
}
return nil
}
// SanitizeResultURL ensures a URL is safe for rendering in an href attribute.
// It rejects javascript:, data:, vbscript: and other dangerous schemes.
func SanitizeResultURL(raw string) string {
if raw == "" {
return ""
}
u, err := url.Parse(raw)
if err != nil {
return ""
}
switch strings.ToLower(u.Scheme) {
case "http", "https", "":
return raw
default:
return ""
}
}

View file

@ -18,6 +18,7 @@ package views
import (
"embed"
"encoding/xml"
"html/template"
"io/fs"
"net/http"
@ -25,6 +26,7 @@ import (
"strings"
"github.com/metamorphosis-dev/kafka/internal/contracts"
"github.com/metamorphosis-dev/kafka/internal/util"
)
//go:embed all:templates
@ -122,15 +124,20 @@ func StaticFS() (fs.FS, error) {
return fs.Sub(staticFS, "static")
}
// OpenSearchXML returns the OpenSearch description XML with {baseUrl}
// replaced by the provided base URL.
// OpenSearchXML returns the OpenSearch description XML with the base URL
// safely embedded via xml.EscapeText (no raw string interpolation).
func OpenSearchXML(baseURL string) ([]byte, error) {
tmplFS, _ := fs.Sub(templatesFS, "templates")
data, err := fs.ReadFile(tmplFS, "opensearch.xml")
if err != nil {
return nil, err
}
result := strings.ReplaceAll(string(data), "{baseUrl}", baseURL)
var buf strings.Builder
xml.Escape(&buf, []byte(baseURL))
escapedBaseURL := buf.String()
result := strings.ReplaceAll(string(data), "{baseUrl}", escapedBaseURL)
return []byte(result), nil
}
@ -190,6 +197,12 @@ func FromResponse(resp contracts.SearchResponse, query string, pageno int, activ
if r.Template == "videos" {
tmplName = "video_item"
}
// Sanitize URLs to prevent javascript:/data: scheme injection.
if r.URL != nil {
safe := util.SanitizeResultURL(*r.URL)
r.URL = &safe
}
r.Thumbnail = util.SanitizeResultURL(r.Thumbnail)
pd.Results[i] = ResultView{MainResult: r, TemplateName: tmplName}
}
@ -213,7 +226,7 @@ func FromResponse(resp contracts.SearchResponse, query string, pageno int, activ
iv.Title = v
}
if v, ok := ib["img_src"].(string); ok {
iv.ImgSrc = v
iv.ImgSrc = util.SanitizeResultURL(v)
}
if iv.Title != "" || iv.Content != "" {
pd.Infoboxes = append(pd.Infoboxes, iv)