package middleware

import (
	"net/http"
	"net/http/httptest"
	"testing"
)

// TestCORS_WildcardOrigin checks that, with AllowedOrigins set to "*", a GET
// carrying an arbitrary Origin reaches the wrapped handler (200) and the
// response carries Access-Control-Allow-Origin: *.
func TestCORS_WildcardOrigin(t *testing.T) {
	h := CORS(CORSConfig{AllowedOrigins: []string{"*"}})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/search?q=test", nil)
	req.Header.Set("Origin", "https://evil.com")
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Errorf("expected 200, got %d", rec.Code)
	}
	if got := rec.Header().Get("Access-Control-Allow-Origin"); got != "*" {
		t.Errorf("expected wildcard origin, got %s", got)
	}
}

// TestCORS_SpecificOrigin checks both sides of an explicit allow-list: the
// listed origin is echoed back in Access-Control-Allow-Origin, while an
// unlisted origin gets no such header.
func TestCORS_SpecificOrigin(t *testing.T) {
	h := CORS(CORSConfig{AllowedOrigins: []string{"https://example.com"}})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	// Allowed origin.
	req := httptest.NewRequest(http.MethodGet, "/search?q=test", nil)
	req.Header.Set("Origin", "https://example.com")
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)

	if got := rec.Header().Get("Access-Control-Allow-Origin"); got != "https://example.com" {
		t.Errorf("expected https://example.com, got %s", got)
	}

	// Disallowed origin — header should not be set.
	req2 := httptest.NewRequest(http.MethodGet, "/search?q=test", nil)
	req2.Header.Set("Origin", "https://evil.com")
	rec2 := httptest.NewRecorder()
	h.ServeHTTP(rec2, req2)

	if got := rec2.Header().Get("Access-Control-Allow-Origin"); got != "" {
		t.Errorf("expected no CORS header for disallowed origin, got %s", got)
	}
}

// TestCORS_Preflight sends an OPTIONS request through a middleware built from
// a zero-value CORSConfig and expects it to be answered without invoking the
// wrapped handler: status 204, a non-empty Access-Control-Allow-Methods, and
// Access-Control-Max-Age of "3600".
func TestCORS_Preflight(t *testing.T) {
	h := CORS(CORSConfig{})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Reaching the wrapped handler at all is the failure condition here.
		t.Error("handler should not be called for preflight")
	}))

	req := httptest.NewRequest(http.MethodOptions, "/search", nil)
	req.Header.Set("Origin", "https://example.com")
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)

	if rec.Code != http.StatusNoContent {
		t.Errorf("expected 204 for preflight, got %d", rec.Code)
	}
	if rec.Header().Get("Access-Control-Allow-Methods") == "" {
		t.Error("expected Access-Control-Allow-Methods header")
	}
	if got := rec.Header().Get("Access-Control-Max-Age"); got != "3600" {
		t.Errorf("expected Max-Age 3600, got %s", got)
	}
}

// TestCORS_NoOriginHeader checks that a request without an Origin header is
// passed to the wrapped handler and that no Access-Control-Allow-Origin
// header appears on the response.
func TestCORS_NoOriginHeader(t *testing.T) {
	called := false
	h := CORS(CORSConfig{AllowedOrigins: []string{"https://example.com"}})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		called = true
		w.WriteHeader(http.StatusOK)
	}))

	// No Origin header — should pass through without CORS headers.
	req := httptest.NewRequest(http.MethodGet, "/search?q=test", nil)
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)

	if !called {
		t.Error("handler should be called")
	}
	if got := rec.Header().Get("Access-Control-Allow-Origin"); got != "" {
		t.Errorf("expected no CORS header without Origin, got %s", got)
	}
}

// TestCORS_CustomMethodsAndHeaders checks that AllowedMethods, AllowedHeaders
// and MaxAge supplied in CORSConfig are reflected verbatim in the
// corresponding Access-Control-* headers of an OPTIONS response.
func TestCORS_CustomMethodsAndHeaders(t *testing.T) {
	h := CORS(CORSConfig{
		AllowedOrigins: []string{"*"},
		AllowedMethods: []string{http.MethodGet},
		AllowedHeaders: []string{"X-Custom"},
		MaxAge:         7200,
	})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))

	// Note: this request deliberately mirrors the original test and carries
	// no Origin header.
	req := httptest.NewRequest(http.MethodOptions, "/search", nil)
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)

	if got := rec.Header().Get("Access-Control-Allow-Methods"); got != "GET" {
		t.Errorf("expected 'GET', got %s", got)
	}
	if got := rec.Header().Get("Access-Control-Allow-Headers"); got != "X-Custom" {
		t.Errorf("expected 'X-Custom', got %s", got)
	}
	if got := rec.Header().Get("Access-Control-Max-Age"); got != "7200" {
		t.Errorf("expected '7200', got %s", got)
	}
}