Merge branch 'security/hardening-sast-fixes'
This commit is contained in:
commit
5884c080fd
24 changed files with 422 additions and 70 deletions
|
|
@ -95,8 +95,9 @@ func main() {
|
||||||
var subFS fs.FS = staticFS
|
var subFS fs.FS = staticFS
|
||||||
mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.FS(subFS))))
|
mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.FS(subFS))))
|
||||||
|
|
||||||
// Apply middleware: global rate limit → burst rate limit → per-IP rate limit → CORS → handler.
|
// Apply middleware: global rate limit → burst rate limit → per-IP rate limit → CORS → security headers → handler.
|
||||||
var handler http.Handler = mux
|
var handler http.Handler = mux
|
||||||
|
handler = middleware.SecurityHeaders(middleware.SecurityHeadersConfig{})(handler)
|
||||||
handler = middleware.CORS(middleware.CORSConfig{
|
handler = middleware.CORS(middleware.CORSConfig{
|
||||||
AllowedOrigins: cfg.CORS.AllowedOrigins,
|
AllowedOrigins: cfg.CORS.AllowedOrigins,
|
||||||
AllowedMethods: cfg.CORS.AllowedMethods,
|
AllowedMethods: cfg.CORS.AllowedMethods,
|
||||||
|
|
@ -108,6 +109,7 @@ func main() {
|
||||||
Requests: cfg.RateLimit.Requests,
|
Requests: cfg.RateLimit.Requests,
|
||||||
Window: cfg.RateLimitWindow(),
|
Window: cfg.RateLimitWindow(),
|
||||||
CleanupInterval: cfg.RateLimitCleanupInterval(),
|
CleanupInterval: cfg.RateLimitCleanupInterval(),
|
||||||
|
TrustedProxies: cfg.RateLimit.TrustedProxies,
|
||||||
}, logger)(handler)
|
}, logger)(handler)
|
||||||
handler = middleware.GlobalRateLimit(middleware.GlobalRateLimitConfig{
|
handler = middleware.GlobalRateLimit(middleware.GlobalRateLimitConfig{
|
||||||
Requests: cfg.GlobalRateLimit.Requests,
|
Requests: cfg.GlobalRateLimit.Requests,
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/BurntSushi/toml"
|
"github.com/BurntSushi/toml"
|
||||||
|
"github.com/metamorphosis-dev/kafka/internal/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config is the top-level configuration for the kafka service.
|
// Config is the top-level configuration for the kafka service.
|
||||||
|
|
@ -77,6 +78,7 @@ type RateLimitConfig struct {
|
||||||
Requests int `toml:"requests"` // Max requests per window (default: 30)
|
Requests int `toml:"requests"` // Max requests per window (default: 30)
|
||||||
Window string `toml:"window"` // Time window (e.g. "1m", default: "1m")
|
Window string `toml:"window"` // Time window (e.g. "1m", default: "1m")
|
||||||
CleanupInterval string `toml:"cleanup_interval"` // Stale entry cleanup interval (default: "5m")
|
CleanupInterval string `toml:"cleanup_interval"` // Stale entry cleanup interval (default: "5m")
|
||||||
|
TrustedProxies []string `toml:"trusted_proxies"` // CIDRs allowed to set X-Forwarded-For
|
||||||
}
|
}
|
||||||
|
|
||||||
// GlobalRateLimitConfig holds server-wide rate limiting settings.
|
// GlobalRateLimitConfig holds server-wide rate limiting settings.
|
||||||
|
|
@ -120,9 +122,36 @@ func Load(path string) (*Config, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
applyEnvOverrides(cfg)
|
applyEnvOverrides(cfg)
|
||||||
|
|
||||||
|
if err := validateConfig(cfg); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid configuration: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return cfg, nil
|
return cfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateConfig checks security-critical config values at startup.
|
||||||
|
func validateConfig(cfg *Config) error {
|
||||||
|
if cfg.Server.BaseURL != "" {
|
||||||
|
if err := util.ValidatePublicURL(cfg.Server.BaseURL); err != nil {
|
||||||
|
return fmt.Errorf("server.base_url: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cfg.Server.SourceURL != "" {
|
||||||
|
if err := util.ValidatePublicURL(cfg.Server.SourceURL); err != nil {
|
||||||
|
return fmt.Errorf("server.source_url: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cfg.Upstream.URL != "" {
|
||||||
|
// Validate scheme and well-formedness, but allow private IPs
|
||||||
|
// since self-hosted deployments commonly use localhost/internal addresses.
|
||||||
|
if _, err := util.SafeURLScheme(cfg.Upstream.URL); err != nil {
|
||||||
|
return fmt.Errorf("upstream.url: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func defaultConfig() *Config {
|
func defaultConfig() *Config {
|
||||||
return &Config{
|
return &Config{
|
||||||
Server: ServerConfig{
|
Server: ServerConfig{
|
||||||
|
|
|
||||||
|
|
@ -75,8 +75,8 @@ func (e *ArxivEngine) Search(ctx context.Context, req contracts.SearchRequest) (
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024))
|
io.Copy(io.Discard, io.LimitReader(resp.Body, 16*1024))
|
||||||
return contracts.SearchResponse{}, fmt.Errorf("arxiv upstream error: status=%d body=%q", resp.StatusCode, string(body))
|
return contracts.SearchResponse{}, fmt.Errorf("arxiv upstream error: status %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
raw, err := io.ReadAll(resp.Body)
|
raw, err := io.ReadAll(resp.Body)
|
||||||
|
|
|
||||||
|
|
@ -68,8 +68,8 @@ func (e *BingEngine) Search(ctx context.Context, req contracts.SearchRequest) (c
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
io.Copy(io.Discard, io.LimitReader(resp.Body, 4096))
|
||||||
return contracts.SearchResponse{}, fmt.Errorf("bing upstream error: status=%d body=%q", resp.StatusCode, string(body))
|
return contracts.SearchResponse{}, fmt.Errorf("bing upstream error: status %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
contentType := resp.Header.Get("Content-Type")
|
contentType := resp.Header.Get("Content-Type")
|
||||||
|
|
|
||||||
|
|
@ -45,8 +45,8 @@ func (e *BraveEngine) Search(ctx context.Context, req contracts.SearchRequest) (
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
io.Copy(io.Discard, io.LimitReader(resp.Body, 4096))
|
||||||
return contracts.SearchResponse{}, fmt.Errorf("brave error: status=%d body=%q", resp.StatusCode, string(body))
|
return contracts.SearchResponse{}, fmt.Errorf("brave error: status %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 128*1024))
|
body, err := io.ReadAll(io.LimitReader(resp.Body, 128*1024))
|
||||||
|
|
|
||||||
|
|
@ -127,8 +127,8 @@ func (e *BraveAPIEngine) Search(ctx context.Context, req contracts.SearchRequest
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024))
|
io.Copy(io.Discard, io.LimitReader(resp.Body, 16*1024))
|
||||||
return contracts.SearchResponse{}, fmt.Errorf("brave upstream error: status=%d body=%q", resp.StatusCode, string(body))
|
return contracts.SearchResponse{}, fmt.Errorf("brave upstream error: status %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
var api struct {
|
var api struct {
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ func TestBraveEngine_GatingAndHeader(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
client := &http.Client{Transport: transport}
|
client := &http.Client{Transport: transport}
|
||||||
engine := &BraveEngine{
|
engine := &BraveAPIEngine{
|
||||||
client: client,
|
client: client,
|
||||||
apiKey: wantAPIKey,
|
apiKey: wantAPIKey,
|
||||||
accessGateToken: wantToken,
|
accessGateToken: wantToken,
|
||||||
|
|
|
||||||
|
|
@ -63,8 +63,8 @@ func (e *CrossrefEngine) Search(ctx context.Context, req contracts.SearchRequest
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024))
|
io.Copy(io.Discard, io.LimitReader(resp.Body, 16*1024))
|
||||||
return contracts.SearchResponse{}, fmt.Errorf("crossref upstream error: status=%d body=%q", resp.StatusCode, string(body))
|
return contracts.SearchResponse{}, fmt.Errorf("crossref upstream error: status %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
var api struct {
|
var api struct {
|
||||||
|
|
|
||||||
|
|
@ -63,8 +63,8 @@ func (e *DuckDuckGoEngine) Search(ctx context.Context, req contracts.SearchReque
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
io.Copy(io.Discard, io.LimitReader(resp.Body, 4096))
|
||||||
return contracts.SearchResponse{}, fmt.Errorf("duckduckgo upstream error: status=%d body=%q", resp.StatusCode, string(body))
|
return contracts.SearchResponse{}, fmt.Errorf("duckduckgo upstream error: status %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
results, err := parseDuckDuckGoHTML(resp.Body)
|
results, err := parseDuckDuckGoHTML(resp.Body)
|
||||||
|
|
|
||||||
|
|
@ -66,8 +66,8 @@ func (e *GitHubEngine) Search(ctx context.Context, req contracts.SearchRequest)
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
io.Copy(io.Discard, io.LimitReader(resp.Body, 4096))
|
||||||
return contracts.SearchResponse{}, fmt.Errorf("github api error: status=%d body=%q", resp.StatusCode, string(body))
|
return contracts.SearchResponse{}, fmt.Errorf("github api error: status %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
var data struct {
|
var data struct {
|
||||||
|
|
|
||||||
|
|
@ -28,20 +28,10 @@ import (
|
||||||
"github.com/metamorphosis-dev/kafka/internal/contracts"
|
"github.com/metamorphosis-dev/kafka/internal/contracts"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GSA User-Agent pool — these are Google Search Appliance identifiers
|
// googleUserAgent is an honest User-Agent identifying the metasearch engine.
|
||||||
// that Google trusts for enterprise search appliance traffic.
|
// Using a spoofed GSA User-Agent violates Google's Terms of Service and
|
||||||
var gsaUserAgents = []string{
|
// risks permanent IP blocking.
|
||||||
"Mozilla/5.0 (iPhone; CPU iPhone OS 17_5_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) GSA/399.2.845414227 Mobile/15E148 Safari/604.1",
|
var googleUserAgent = "Kafka/0.1 (compatible; +https://github.com/metamorphosis-dev/kafka)"
|
||||||
"Mozilla/5.0 (iPhone; CPU iPhone OS 17_6_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) GSA/406.0.862495628 Mobile/15E148 Safari/604.1",
|
|
||||||
"Mozilla/5.0 (iPhone; CPU iPhone OS 17_7_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) GSA/406.0.862495628 Mobile/15E148 Safari/604.1",
|
|
||||||
"Mozilla/5.0 (iPhone; CPU iPhone OS 18_0_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) GSA/406.0.862495628 Mobile/15E148 Safari/604.1",
|
|
||||||
"Mozilla/5.0 (iPhone; CPU iPhone OS 18_1_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) GSA/399.2.845414227 Mobile/15E148 Safari/604.1",
|
|
||||||
"Mozilla/5.0 (iPhone; CPU iPhone OS 18_5_0 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) GSA/406.0.862495628 Mobile/15E148 Safari/604.1",
|
|
||||||
}
|
|
||||||
|
|
||||||
func gsaUA() string {
|
|
||||||
return gsaUserAgents[0] // deterministic for now; could rotate
|
|
||||||
}
|
|
||||||
|
|
||||||
type GoogleEngine struct {
|
type GoogleEngine struct {
|
||||||
client *http.Client
|
client *http.Client
|
||||||
|
|
@ -70,7 +60,7 @@ func (e *GoogleEngine) Search(ctx context.Context, req contracts.SearchRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return contracts.SearchResponse{}, err
|
return contracts.SearchResponse{}, err
|
||||||
}
|
}
|
||||||
httpReq.Header.Set("User-Agent", gsaUA())
|
httpReq.Header.Set("User-Agent", googleUserAgent)
|
||||||
httpReq.Header.Set("Accept", "*/*")
|
httpReq.Header.Set("Accept", "*/*")
|
||||||
httpReq.AddCookie(&http.Cookie{Name: "CONSENT", Value: "YES+"})
|
httpReq.AddCookie(&http.Cookie{Name: "CONSENT", Value: "YES+"})
|
||||||
|
|
||||||
|
|
@ -95,8 +85,8 @@ func (e *GoogleEngine) Search(ctx context.Context, req contracts.SearchRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
io.Copy(io.Discard, io.LimitReader(resp.Body, 4096))
|
||||||
return contracts.SearchResponse{}, fmt.Errorf("google error: status=%d body=%q", resp.StatusCode, string(body))
|
return contracts.SearchResponse{}, fmt.Errorf("google error: status %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 128*1024))
|
body, err := io.ReadAll(io.LimitReader(resp.Body, 128*1024))
|
||||||
|
|
|
||||||
|
|
@ -124,8 +124,8 @@ func (e *QwantEngine) searchWebAPI(ctx context.Context, req contracts.SearchRequ
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024))
|
io.Copy(io.Discard, io.LimitReader(resp.Body, 16*1024))
|
||||||
return contracts.SearchResponse{}, fmt.Errorf("qwant upstream error: status=%d body=%q", resp.StatusCode, string(body))
|
return contracts.SearchResponse{}, fmt.Errorf("qwant upstream error: status %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024))
|
body, err := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024))
|
||||||
|
|
@ -253,8 +253,8 @@ func (e *QwantEngine) searchWebLite(ctx context.Context, req contracts.SearchReq
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024))
|
io.Copy(io.Discard, io.LimitReader(resp.Body, 16*1024))
|
||||||
return contracts.SearchResponse{}, fmt.Errorf("qwant lite upstream error: status=%d body=%q", resp.StatusCode, string(body))
|
return contracts.SearchResponse{}, fmt.Errorf("qwant lite upstream error: status %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
doc, err := goquery.NewDocumentFromReader(resp.Body)
|
doc, err := goquery.NewDocumentFromReader(resp.Body)
|
||||||
|
|
|
||||||
|
|
@ -62,8 +62,8 @@ func (e *RedditEngine) Search(ctx context.Context, req contracts.SearchRequest)
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
io.Copy(io.Discard, io.LimitReader(resp.Body, 4096))
|
||||||
return contracts.SearchResponse{}, fmt.Errorf("reddit api error: status=%d body=%q", resp.StatusCode, string(body))
|
return contracts.SearchResponse{}, fmt.Errorf("reddit api error: status %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
var data struct {
|
var data struct {
|
||||||
|
|
|
||||||
|
|
@ -134,8 +134,8 @@ func (e *WikipediaEngine) Search(ctx context.Context, req contracts.SearchReques
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024))
|
io.Copy(io.Discard, io.LimitReader(resp.Body, 16*1024))
|
||||||
return contracts.SearchResponse{}, fmt.Errorf("wikipedia upstream error: status=%d body=%q", resp.StatusCode, string(body))
|
return contracts.SearchResponse{}, fmt.Errorf("wikipedia upstream error: status %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
var api struct {
|
var api struct {
|
||||||
|
|
|
||||||
|
|
@ -77,8 +77,8 @@ func (e *YouTubeEngine) Search(ctx context.Context, req contracts.SearchRequest)
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
io.Copy(io.Discard, io.LimitReader(resp.Body, 4096))
|
||||||
return contracts.SearchResponse{}, fmt.Errorf("youtube api error: status=%d body=%q", resp.StatusCode, string(body))
|
return contracts.SearchResponse{}, fmt.Errorf("youtube api error: status %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
var apiResp youtubeSearchResponse
|
var apiResp youtubeSearchResponse
|
||||||
|
|
@ -87,7 +87,7 @@ func (e *YouTubeEngine) Search(ctx context.Context, req contracts.SearchRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
if apiResp.Error != nil {
|
if apiResp.Error != nil {
|
||||||
return contracts.SearchResponse{}, fmt.Errorf("youtube api error: %s", apiResp.Error.Message)
|
return contracts.SearchResponse{}, fmt.Errorf("youtube api error: code %d", apiResp.Error.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
results := make([]contracts.MainResult, 0, len(apiResp.Items))
|
results := make([]contracts.MainResult, 0, len(apiResp.Items))
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,8 @@ type CORSConfig struct {
|
||||||
func CORS(cfg CORSConfig) func(http.Handler) http.Handler {
|
func CORS(cfg CORSConfig) func(http.Handler) http.Handler {
|
||||||
origins := cfg.AllowedOrigins
|
origins := cfg.AllowedOrigins
|
||||||
if len(origins) == 0 {
|
if len(origins) == 0 {
|
||||||
origins = []string{"*"}
|
// Default: no CORS headers. Explicitly configure origins to enable.
|
||||||
|
origins = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
methods := cfg.AllowedMethods
|
methods := cfg.AllowedMethods
|
||||||
|
|
@ -70,6 +71,7 @@ func CORS(cfg CORSConfig) func(http.Handler) http.Handler {
|
||||||
origin := r.Header.Get("Origin")
|
origin := r.Header.Get("Origin")
|
||||||
|
|
||||||
// Determine the allowed origin for this request.
|
// Determine the allowed origin for this request.
|
||||||
|
// If no origins are configured, CORS is disabled entirely — no headers are set.
|
||||||
allowedOrigin := ""
|
allowedOrigin := ""
|
||||||
for _, o := range origins {
|
for _, o := range origins {
|
||||||
if o == "*" {
|
if o == "*" {
|
||||||
|
|
|
||||||
|
|
@ -27,10 +27,14 @@ import (
|
||||||
"log/slog"
|
"log/slog"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// RateLimitConfig controls per-IP rate limiting.
|
||||||
type RateLimitConfig struct {
|
type RateLimitConfig struct {
|
||||||
Requests int
|
Requests int
|
||||||
Window time.Duration
|
Window time.Duration
|
||||||
CleanupInterval time.Duration
|
CleanupInterval time.Duration
|
||||||
|
// TrustedProxies is a list of CIDR ranges that are allowed to set
|
||||||
|
// X-Forwarded-For / X-Real-IP. If empty, only r.RemoteAddr is used.
|
||||||
|
TrustedProxies []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func RateLimit(cfg RateLimitConfig, logger *slog.Logger) func(http.Handler) http.Handler {
|
func RateLimit(cfg RateLimitConfig, logger *slog.Logger) func(http.Handler) http.Handler {
|
||||||
|
|
@ -53,18 +57,30 @@ func RateLimit(cfg RateLimitConfig, logger *slog.Logger) func(http.Handler) http
|
||||||
logger = slog.Default()
|
logger = slog.Default()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Parse trusted proxy CIDRs.
|
||||||
|
var trustedNets []*net.IPNet
|
||||||
|
for _, cidr := range cfg.TrustedProxies {
|
||||||
|
_, network, err := net.ParseCIDR(cidr)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("invalid trusted proxy CIDR, skipping", "cidr", cidr, "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
trustedNets = append(trustedNets, network)
|
||||||
|
}
|
||||||
|
|
||||||
limiter := &ipLimiter{
|
limiter := &ipLimiter{
|
||||||
requests: requests,
|
requests: requests,
|
||||||
window: window,
|
window: window,
|
||||||
clients: make(map[string]*bucket),
|
clients: make(map[string]*bucket),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
trusted: trustedNets,
|
||||||
}
|
}
|
||||||
|
|
||||||
go limiter.cleanup(cleanup)
|
go limiter.cleanup(cleanup)
|
||||||
|
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ip := extractIP(r)
|
ip := limiter.extractIP(r)
|
||||||
|
|
||||||
if !limiter.allow(ip) {
|
if !limiter.allow(ip) {
|
||||||
retryAfter := int(limiter.window.Seconds())
|
retryAfter := int(limiter.window.Seconds())
|
||||||
|
|
@ -92,6 +108,7 @@ type ipLimiter struct {
|
||||||
clients map[string]*bucket
|
clients map[string]*bucket
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
|
trusted []*net.IPNet
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *ipLimiter) allow(ip string) bool {
|
func (l *ipLimiter) allow(ip string) bool {
|
||||||
|
|
@ -129,18 +146,48 @@ func (l *ipLimiter) cleanup(interval time.Duration) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractIP(r *http.Request) string {
|
// extractIP extracts the client IP from the request.
|
||||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
// If trusted proxy CIDRs are configured, X-Forwarded-For is only used when
|
||||||
parts := strings.SplitN(xff, ",", 2)
|
// the direct connection comes from a trusted proxy. Otherwise, only RemoteAddr is used.
|
||||||
return strings.TrimSpace(parts[0])
|
func (l *ipLimiter) extractIP(r *http.Request) string {
|
||||||
}
|
return extractIP(r, l.trusted...)
|
||||||
if rip := r.Header.Get("X-Real-IP"); rip != "" {
|
}
|
||||||
return strings.TrimSpace(rip)
|
|
||||||
|
func extractIP(r *http.Request, trusted ...*net.IPNet) string {
|
||||||
|
remoteIP, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
remoteIP = r.RemoteAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
// Check if the direct connection is from a trusted proxy.
|
||||||
if err != nil {
|
isTrusted := false
|
||||||
return r.RemoteAddr
|
if len(trusted) > 0 {
|
||||||
|
ip := net.ParseIP(remoteIP)
|
||||||
|
if ip != nil {
|
||||||
|
for _, network := range trusted {
|
||||||
|
if network.Contains(ip) {
|
||||||
|
isTrusted = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return host
|
|
||||||
|
if isTrusted {
|
||||||
|
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||||
|
parts := strings.SplitN(xff, ",", 2)
|
||||||
|
candidate := strings.TrimSpace(parts[0])
|
||||||
|
if net.ParseIP(candidate) != nil {
|
||||||
|
return candidate
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if rip := r.Header.Get("X-Real-IP"); rip != "" {
|
||||||
|
candidate := strings.TrimSpace(rip)
|
||||||
|
if net.ParseIP(candidate) != nil {
|
||||||
|
return candidate
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return remoteIP
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -71,7 +71,7 @@ func GlobalRateLimit(cfg GlobalRateLimitConfig, logger *slog.Logger) func(http.H
|
||||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
w.WriteHeader(http.StatusServiceUnavailable)
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
_, _ = w.Write([]byte("503 Service Unavailable — global rate limit exceeded\n"))
|
_, _ = w.Write([]byte("503 Service Unavailable — global rate limit exceeded\n"))
|
||||||
logger.Warn("global rate limit exceeded", "ip", extractIP(r))
|
logger.Warn("global rate limit exceeded", "remote", r.RemoteAddr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
@ -93,8 +94,9 @@ func TestRateLimit_DifferentIPs(t *testing.T) {
|
||||||
|
|
||||||
func TestRateLimit_XForwardedFor(t *testing.T) {
|
func TestRateLimit_XForwardedFor(t *testing.T) {
|
||||||
h := RateLimit(RateLimitConfig{
|
h := RateLimit(RateLimitConfig{
|
||||||
Requests: 1,
|
Requests: 1,
|
||||||
Window: 10 * time.Second,
|
Window: 10 * time.Second,
|
||||||
|
TrustedProxies: []string{"10.0.0.0/8"},
|
||||||
}, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
}, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
}))
|
}))
|
||||||
|
|
@ -143,17 +145,27 @@ func TestRateLimit_WindowExpires(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExtractIP(t *testing.T) {
|
func TestExtractIP(t *testing.T) {
|
||||||
|
// Trusted proxy: loopback
|
||||||
|
loopback := mustParseCIDR("127.0.0.0/8")
|
||||||
|
privateNet := mustParseCIDR("10.0.0.0/8")
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
xff string
|
xff string
|
||||||
realIP string
|
realIP string
|
||||||
remote string
|
remote string
|
||||||
|
trusted []*net.IPNet
|
||||||
expected string
|
expected string
|
||||||
}{
|
}{
|
||||||
{"xff", "203.0.113.50, 10.0.0.1", "", "10.0.0.1:1234", "203.0.113.50"},
|
// No trusted proxies → always use RemoteAddr.
|
||||||
{"real_ip", "", "203.0.113.50", "10.0.0.1:1234", "203.0.113.50"},
|
{"no_trusted_xff", "203.0.113.50, 10.0.0.1", "", "10.0.0.1:1234", nil, "10.0.0.1"},
|
||||||
{"remote", "", "", "1.2.3.4:5678", "1.2.3.4"},
|
{"no_trusted_real", "", "203.0.113.50", "10.0.0.1:1234", nil, "10.0.0.1"},
|
||||||
{"xff_over_real", "203.0.113.50", "10.0.0.1", "10.0.0.1:1234", "203.0.113.50"},
|
{"no_trusted_remote", "", "", "1.2.3.4:5678", nil, "1.2.3.4"},
|
||||||
|
// Trusted proxy → XFF is respected.
|
||||||
|
{"trusted_xff", "203.0.113.50, 10.0.0.1", "", "10.0.0.1:1234", []*net.IPNet{privateNet}, "203.0.113.50"},
|
||||||
|
{"trusted_real_ip", "", "203.0.113.50", "10.0.0.1:1234", []*net.IPNet{privateNet}, "203.0.113.50"},
|
||||||
|
// Untrusted remote → XFF ignored even if present.
|
||||||
|
{"untrusted_xff", "203.0.113.50, 10.0.0.1", "", "1.2.3.4:5678", []*net.IPNet{loopback}, "1.2.3.4"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
|
@ -167,9 +179,17 @@ func TestExtractIP(t *testing.T) {
|
||||||
}
|
}
|
||||||
req.RemoteAddr = tt.remote
|
req.RemoteAddr = tt.remote
|
||||||
|
|
||||||
if got := extractIP(req); got != tt.expected {
|
if got := extractIP(req, tt.trusted...); got != tt.expected {
|
||||||
t.Errorf("extractIP() = %q, want %q", got, tt.expected)
|
t.Errorf("extractIP() = %q, want %q", got, tt.expected)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mustParseCIDR(s string) *net.IPNet {
|
||||||
|
_, network, err := net.ParseCIDR(s)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return network
|
||||||
|
}
|
||||||
|
|
|
||||||
92
internal/middleware/security.go
Normal file
92
internal/middleware/security.go
Normal file
|
|
@ -0,0 +1,92 @@
|
||||||
|
// kafka — a privacy-respecting metasearch engine
|
||||||
|
// Copyright (C) 2026-present metamorphosis-dev
|
||||||
|
//
|
||||||
|
// This program is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU Affero General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SecurityHeadersConfig controls which security headers are set.
|
||||||
|
type SecurityHeadersConfig struct {
|
||||||
|
// FrameOptions controls X-Frame-Options. Default: "DENY".
|
||||||
|
FrameOptions string
|
||||||
|
// HSTSMaxAge controls the max-age for Strict-Transport-Security.
|
||||||
|
// Set to 0 to disable HSTS (useful for local dev). Default: 31536000 (1 year).
|
||||||
|
HSTSMaxAge int
|
||||||
|
// HSTSPreloadDomains adds "includeSubDomains; preload" to HSTS.
|
||||||
|
HSTSPreloadDomains bool
|
||||||
|
// ReferrerPolicy controls the Referrer-Policy header. Default: "no-referrer".
|
||||||
|
ReferrerPolicy string
|
||||||
|
// CSP controls Content-Security-Policy. Default: a restrictive policy.
|
||||||
|
// Set to "" to disable CSP entirely.
|
||||||
|
CSP string
|
||||||
|
}
|
||||||
|
|
||||||
|
// SecurityHeaders returns middleware that sets standard HTTP security headers
|
||||||
|
// on every response.
|
||||||
|
func SecurityHeaders(cfg SecurityHeadersConfig) func(http.Handler) http.Handler {
|
||||||
|
frameOpts := cfg.FrameOptions
|
||||||
|
if frameOpts == "" {
|
||||||
|
frameOpts = "DENY"
|
||||||
|
}
|
||||||
|
|
||||||
|
hstsAge := cfg.HSTSMaxAge
|
||||||
|
if hstsAge == 0 {
|
||||||
|
hstsAge = 31536000 // 1 year
|
||||||
|
}
|
||||||
|
|
||||||
|
refPol := cfg.ReferrerPolicy
|
||||||
|
if refPol == "" {
|
||||||
|
refPol = "no-referrer"
|
||||||
|
}
|
||||||
|
|
||||||
|
csp := cfg.CSP
|
||||||
|
if csp == "" {
|
||||||
|
csp = defaultCSP()
|
||||||
|
}
|
||||||
|
|
||||||
|
hstsValue := "max-age=" + strconv.Itoa(hstsAge)
|
||||||
|
if cfg.HSTSPreloadDomains {
|
||||||
|
hstsValue += "; includeSubDomains; preload"
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||||
|
w.Header().Set("X-Frame-Options", frameOpts)
|
||||||
|
w.Header().Set("Referrer-Policy", refPol)
|
||||||
|
w.Header().Set("Permissions-Policy", "camera=(), microphone=(), geolocation=()")
|
||||||
|
w.Header().Set("Content-Security-Policy", csp)
|
||||||
|
|
||||||
|
if hstsAge > 0 {
|
||||||
|
w.Header().Set("Strict-Transport-Security", hstsValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultCSP returns a restrictive Content-Security-Policy for the
|
||||||
|
// metasearch engine.
|
||||||
|
func defaultCSP() string {
|
||||||
|
return strings.Join([]string{
|
||||||
|
"default-src 'self'",
|
||||||
|
"script-src 'self'",
|
||||||
|
"style-src 'self' 'unsafe-inline'",
|
||||||
|
"img-src 'self' https: data:",
|
||||||
|
"connect-src 'self'",
|
||||||
|
"font-src 'self'",
|
||||||
|
"frame-ancestors 'none'",
|
||||||
|
"base-uri 'self'",
|
||||||
|
"form-action 'self'",
|
||||||
|
}, "; ")
|
||||||
|
}
|
||||||
|
|
@ -26,6 +26,28 @@ import (
|
||||||
|
|
||||||
var languageCodeRe = regexp.MustCompile(`^[a-z]{2,3}(-[a-zA-Z]{2})?$`)
|
var languageCodeRe = regexp.MustCompile(`^[a-z]{2,3}(-[a-zA-Z]{2})?$`)
|
||||||
|
|
||||||
|
// maxQueryLength is the maximum allowed length for the search query.
|
||||||
|
const maxQueryLength = 1024
|
||||||
|
|
||||||
|
// knownEngineNames is the allowlist of valid engine identifiers.
var knownEngineNames = map[string]bool{
	"wikipedia": true, "arxiv": true, "crossref": true,
	"braveapi": true, "brave": true, "qwant": true,
	"duckduckgo": true, "github": true, "reddit": true,
	"bing": true, "google": true, "youtube": true,
}

// validateEngines filters engine names against the known registry,
// returning the lowercase form of every recognized name in input order.
// Unknown names are silently dropped so arbitrary engine identifiers
// cannot be injected into upstream requests.
func validateEngines(engines []string) []string {
	out := make([]string, 0, len(engines))
	for _, e := range engines {
		// Lowercase once per entry; the original called strings.ToLower
		// twice for every matching name.
		name := strings.ToLower(e)
		if knownEngineNames[name] {
			out = append(out, name)
		}
	}
	return out
}
|
||||||
|
|
||||||
func ParseSearchRequest(r *http.Request) (SearchRequest, error) {
|
func ParseSearchRequest(r *http.Request) (SearchRequest, error) {
|
||||||
// Supports both GET and POST and relies on form values for routing.
|
// Supports both GET and POST and relies on form values for routing.
|
||||||
if err := r.ParseForm(); err != nil {
|
if err := r.ParseForm(); err != nil {
|
||||||
|
|
@ -50,6 +72,9 @@ func ParseSearchRequest(r *http.Request) (SearchRequest, error) {
|
||||||
if strings.TrimSpace(q) == "" {
|
if strings.TrimSpace(q) == "" {
|
||||||
return SearchRequest{}, errors.New("missing required parameter: q")
|
return SearchRequest{}, errors.New("missing required parameter: q")
|
||||||
}
|
}
|
||||||
|
if len(q) > maxQueryLength {
|
||||||
|
return SearchRequest{}, errors.New("query exceeds maximum length")
|
||||||
|
}
|
||||||
|
|
||||||
pageno := 1
|
pageno := 1
|
||||||
if s := strings.TrimSpace(r.FormValue("pageno")); s != "" {
|
if s := strings.TrimSpace(r.FormValue("pageno")); s != "" {
|
||||||
|
|
@ -105,6 +130,8 @@ func ParseSearchRequest(r *http.Request) (SearchRequest, error) {
|
||||||
|
|
||||||
// engines is an explicit list of engine names.
|
// engines is an explicit list of engine names.
|
||||||
engines := splitCSV(strings.TrimSpace(r.FormValue("engines")))
|
engines := splitCSV(strings.TrimSpace(r.FormValue("engines")))
|
||||||
|
// Validate engine names against known registry to prevent injection.
|
||||||
|
engines = validateEngines(engines)
|
||||||
|
|
||||||
// categories and category_<name> params mirror the webadapter parsing.
|
// categories and category_<name> params mirror the webadapter parsing.
|
||||||
// We don't validate against a registry here; we just preserve the requested values.
|
// We don't validate against a registry here; we just preserve the requested values.
|
||||||
|
|
|
||||||
|
|
@ -44,6 +44,9 @@ func NewClient(baseURL string, timeout time.Duration) (*Client, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid upstream base URL: %w", err)
|
return nil, fmt.Errorf("invalid upstream base URL: %w", err)
|
||||||
}
|
}
|
||||||
|
if u.Scheme != "http" && u.Scheme != "https" {
|
||||||
|
return nil, fmt.Errorf("upstream URL must use http or https, got %q", u.Scheme)
|
||||||
|
}
|
||||||
// Normalize: trim trailing slash to make URL concatenation predictable.
|
// Normalize: trim trailing slash to make URL concatenation predictable.
|
||||||
base := strings.TrimRight(u.String(), "/")
|
base := strings.TrimRight(u.String(), "/")
|
||||||
|
|
||||||
|
|
@ -108,7 +111,7 @@ func (c *Client) SearchJSON(ctx context.Context, req contracts.SearchRequest, en
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return contracts.SearchResponse{}, fmt.Errorf("upstream search failed: status=%d body=%q", resp.StatusCode, string(body))
|
return contracts.SearchResponse{}, fmt.Errorf("upstream search failed with status %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decode upstream JSON into our contract types.
|
// Decode upstream JSON into our contract types.
|
||||||
|
|
|
||||||
127
internal/util/validate.go
Normal file
127
internal/util/validate.go
Normal file
|
|
@ -0,0 +1,127 @@
|
||||||
|
// kafka — a privacy-respecting metasearch engine
|
||||||
|
// Copyright (C) 2026-present metamorphosis-dev
|
||||||
|
//
|
||||||
|
// This program is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU Affero General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License,// or (at your option) any later version.
|
||||||
|
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SafeURLScheme validates that a URL is well-formed and uses an acceptable scheme.
// Returns the parsed URL on success, or an error.
func SafeURLScheme(raw string) (*url.URL, error) {
	parsed, err := url.Parse(raw)
	if err != nil {
		return nil, err
	}
	switch parsed.Scheme {
	case "http", "https":
		return parsed, nil
	default:
		return nil, fmt.Errorf("URL must use http or https, got %q", parsed.Scheme)
	}
}
|
||||||
|
|
||||||
|
// privateNetworks holds the non-routable ranges checked by isPrivateIPAddr.
// Parsed once at package init instead of being rebuilt on every call.
var privateNetworks = []*net.IPNet{
	// Loopback
	mustParseCIDR("127.0.0.0/8"),
	mustParseCIDR("::1/128"),
	// RFC 1918
	mustParseCIDR("10.0.0.0/8"),
	mustParseCIDR("172.16.0.0/12"),
	mustParseCIDR("192.168.0.0/16"),
	// RFC 6598 (Carrier-grade NAT)
	mustParseCIDR("100.64.0.0/10"),
	// Link-local
	mustParseCIDR("169.254.0.0/16"),
	mustParseCIDR("fe80::/10"),
	// IPv6 unique local
	mustParseCIDR("fc00::/7"),
	// IPv4-mapped IPv6 loopback
	mustParseCIDR("::ffff:127.0.0.0/104"),
}

// IsPrivateIP returns true if the IP address is in a private, loopback,
// link-local, or otherwise non-routable range. host may be a bare host
// or a host:port pair, and may be an IP literal or a resolvable name.
// If the name cannot be resolved, it is treated as private (fail closed).
func IsPrivateIP(host string) bool {
	// Strip port if present.
	h, _, err := net.SplitHostPort(host)
	if err != nil {
		h = host
	}

	// Fast path: an IP literal needs no resolver round-trip.
	if ip := net.ParseIP(h); ip != nil {
		return isPrivateIPAddr(ip)
	}

	// Resolve hostname to IPs.
	ips, err := net.LookupIP(h)
	if err != nil || len(ips) == 0 {
		// If we can't resolve, reject to be safe.
		return true
	}

	for _, ip := range ips {
		if isPrivateIPAddr(ip) {
			return true
		}
	}
	return false
}

// isPrivateIPAddr reports whether ip falls inside any of privateNetworks.
// IPNet.Contains normalizes IPv4-mapped IPv6 addresses, so ::ffff:10.x.y.z
// is matched by the 10.0.0.0/8 entry.
func isPrivateIPAddr(ip net.IP) bool {
	for _, n := range privateNetworks {
		if n.Contains(ip) {
			return true
		}
	}
	return false
}

// mustParseCIDR parses s or panics; for use only with compile-time
// constant CIDR strings in the privateNetworks table.
func mustParseCIDR(s string) *net.IPNet {
	_, network, err := net.ParseCIDR(s)
	if err != nil {
		panic(fmt.Sprintf("validate: invalid CIDR %q: %v", s, err))
	}
	return network
}
|
||||||
|
|
||||||
|
// ValidatePublicURL checks that a URL is well-formed, uses http or https,
|
||||||
|
// and does not point to a private/reserved IP range.
|
||||||
|
func ValidatePublicURL(raw string) error {
|
||||||
|
u, err := url.Parse(raw)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid URL: %w", err)
|
||||||
|
}
|
||||||
|
if u.Scheme != "http" && u.Scheme != "https" {
|
||||||
|
return fmt.Errorf("URL must use http or https, got %q", u.Scheme)
|
||||||
|
}
|
||||||
|
if u.Host == "" {
|
||||||
|
return fmt.Errorf("URL must have a host")
|
||||||
|
}
|
||||||
|
if IsPrivateIP(u.Host) {
|
||||||
|
return fmt.Errorf("URL points to a private or reserved address: %s", u.Host)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizeResultURL ensures a URL is safe for rendering in an href attribute.
// It rejects javascript:, data:, vbscript: and other dangerous schemes.
// The input is returned unchanged when the scheme is http, https, or empty
// (relative URL); anything else — including unparseable input — yields "".
func SanitizeResultURL(raw string) string {
	if raw == "" {
		return ""
	}
	parsed, err := url.Parse(raw)
	if err != nil {
		return ""
	}
	scheme := strings.ToLower(parsed.Scheme)
	if scheme == "http" || scheme == "https" || scheme == "" {
		return raw
	}
	return ""
}
|
||||||
|
|
@ -18,6 +18,7 @@ package views
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"embed"
|
"embed"
|
||||||
|
"encoding/xml"
|
||||||
"html/template"
|
"html/template"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
@ -25,6 +26,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/metamorphosis-dev/kafka/internal/contracts"
|
"github.com/metamorphosis-dev/kafka/internal/contracts"
|
||||||
|
"github.com/metamorphosis-dev/kafka/internal/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed all:templates
|
//go:embed all:templates
|
||||||
|
|
@ -122,15 +124,20 @@ func StaticFS() (fs.FS, error) {
|
||||||
return fs.Sub(staticFS, "static")
|
return fs.Sub(staticFS, "static")
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenSearchXML returns the OpenSearch description XML with {baseUrl}
|
// OpenSearchXML returns the OpenSearch description XML with the base URL
|
||||||
// replaced by the provided base URL.
|
// safely embedded via xml.EscapeText (no raw string interpolation).
|
||||||
func OpenSearchXML(baseURL string) ([]byte, error) {
|
func OpenSearchXML(baseURL string) ([]byte, error) {
|
||||||
tmplFS, _ := fs.Sub(templatesFS, "templates")
|
tmplFS, _ := fs.Sub(templatesFS, "templates")
|
||||||
data, err := fs.ReadFile(tmplFS, "opensearch.xml")
|
data, err := fs.ReadFile(tmplFS, "opensearch.xml")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
result := strings.ReplaceAll(string(data), "{baseUrl}", baseURL)
|
|
||||||
|
var buf strings.Builder
|
||||||
|
xml.Escape(&buf, []byte(baseURL))
|
||||||
|
escapedBaseURL := buf.String()
|
||||||
|
|
||||||
|
result := strings.ReplaceAll(string(data), "{baseUrl}", escapedBaseURL)
|
||||||
return []byte(result), nil
|
return []byte(result), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -190,6 +197,12 @@ func FromResponse(resp contracts.SearchResponse, query string, pageno int, activ
|
||||||
if r.Template == "videos" {
|
if r.Template == "videos" {
|
||||||
tmplName = "video_item"
|
tmplName = "video_item"
|
||||||
}
|
}
|
||||||
|
// Sanitize URLs to prevent javascript:/data: scheme injection.
|
||||||
|
if r.URL != nil {
|
||||||
|
safe := util.SanitizeResultURL(*r.URL)
|
||||||
|
r.URL = &safe
|
||||||
|
}
|
||||||
|
r.Thumbnail = util.SanitizeResultURL(r.Thumbnail)
|
||||||
pd.Results[i] = ResultView{MainResult: r, TemplateName: tmplName}
|
pd.Results[i] = ResultView{MainResult: r, TemplateName: tmplName}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -213,7 +226,7 @@ func FromResponse(resp contracts.SearchResponse, query string, pageno int, activ
|
||||||
iv.Title = v
|
iv.Title = v
|
||||||
}
|
}
|
||||||
if v, ok := ib["img_src"].(string); ok {
|
if v, ok := ib["img_src"].(string); ok {
|
||||||
iv.ImgSrc = v
|
iv.ImgSrc = util.SanitizeResultURL(v)
|
||||||
}
|
}
|
||||||
if iv.Title != "" || iv.Content != "" {
|
if iv.Title != "" || iv.Content != "" {
|
||||||
pd.Infoboxes = append(pd.Infoboxes, iv)
|
pd.Infoboxes = append(pd.Infoboxes, iv)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue