An empty CORSConfig now emits no CORS headers, matching the security fix. The preflight test therefore configures an allowed origin explicitly to exercise preflight behavior.
116 lines
3.9 KiB
Go
116 lines
3.9 KiB
Go
package middleware
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
)
|
|
|
|
func TestCORS_WildcardOrigin(t *testing.T) {
|
|
h := CORS(CORSConfig{AllowedOrigins: []string{"*"}})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest("GET", "/search?q=test", nil)
|
|
req.Header.Set("Origin", "https://evil.com")
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Errorf("expected 200, got %d", rec.Code)
|
|
}
|
|
if rec.Header().Get("Access-Control-Allow-Origin") != "*" {
|
|
t.Errorf("expected wildcard origin, got %s", rec.Header().Get("Access-Control-Allow-Origin"))
|
|
}
|
|
}
|
|
|
|
func TestCORS_SpecificOrigin(t *testing.T) {
|
|
h := CORS(CORSConfig{AllowedOrigins: []string{"https://example.com"}})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
// Allowed origin.
|
|
req := httptest.NewRequest("GET", "/search?q=test", nil)
|
|
req.Header.Set("Origin", "https://example.com")
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, req)
|
|
|
|
if rec.Header().Get("Access-Control-Allow-Origin") != "https://example.com" {
|
|
t.Errorf("expected https://example.com, got %s", rec.Header().Get("Access-Control-Allow-Origin"))
|
|
}
|
|
|
|
// Disallowed origin — header should not be set.
|
|
req2 := httptest.NewRequest("GET", "/search?q=test", nil)
|
|
req2.Header.Set("Origin", "https://evil.com")
|
|
rec2 := httptest.NewRecorder()
|
|
h.ServeHTTP(rec2, req2)
|
|
|
|
if rec2.Header().Get("Access-Control-Allow-Origin") != "" {
|
|
t.Errorf("expected no CORS header for disallowed origin, got %s", rec2.Header().Get("Access-Control-Allow-Origin"))
|
|
}
|
|
}
|
|
|
|
func TestCORS_Preflight(t *testing.T) {
|
|
h := CORS(CORSConfig{AllowedOrigins: []string{"https://example.com"}})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
t.Error("handler should not be called for preflight")
|
|
}))
|
|
|
|
req := httptest.NewRequest("OPTIONS", "/search", nil)
|
|
req.Header.Set("Origin", "https://example.com")
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusNoContent {
|
|
t.Errorf("expected 204 for preflight, got %d", rec.Code)
|
|
}
|
|
if rec.Header().Get("Access-Control-Allow-Methods") == "" {
|
|
t.Error("expected Access-Control-Allow-Methods header")
|
|
}
|
|
if rec.Header().Get("Access-Control-Max-Age") != "3600" {
|
|
t.Errorf("expected Max-Age 3600, got %s", rec.Header().Get("Access-Control-Max-Age"))
|
|
}
|
|
}
|
|
|
|
func TestCORS_NoOriginHeader(t *testing.T) {
|
|
called := false
|
|
h := CORS(CORSConfig{AllowedOrigins: []string{"https://example.com"}})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
called = true
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
// No Origin header — should pass through without CORS headers.
|
|
req := httptest.NewRequest("GET", "/search?q=test", nil)
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, req)
|
|
|
|
if !called {
|
|
t.Error("handler should be called")
|
|
}
|
|
if rec.Header().Get("Access-Control-Allow-Origin") != "" {
|
|
t.Errorf("expected no CORS header without Origin, got %s", rec.Header().Get("Access-Control-Allow-Origin"))
|
|
}
|
|
}
|
|
|
|
func TestCORS_CustomMethodsAndHeaders(t *testing.T) {
|
|
h := CORS(CORSConfig{
|
|
AllowedOrigins: []string{"*"},
|
|
AllowedMethods: []string{"GET"},
|
|
AllowedHeaders: []string{"X-Custom"},
|
|
MaxAge: 7200,
|
|
})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
|
|
|
req := httptest.NewRequest("OPTIONS", "/search", nil)
|
|
req.Header.Set("Origin", "https://example.com")
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, req)
|
|
|
|
if rec.Header().Get("Access-Control-Allow-Methods") != "GET" {
|
|
t.Errorf("expected 'GET', got %s", rec.Header().Get("Access-Control-Allow-Methods"))
|
|
}
|
|
if rec.Header().Get("Access-Control-Allow-Headers") != "X-Custom" {
|
|
t.Errorf("expected 'X-Custom', got %s", rec.Header().Get("Access-Control-Allow-Headers"))
|
|
}
|
|
if rec.Header().Get("Access-Control-Max-Age") != "7200" {
|
|
t.Errorf("expected '7200', got %s", rec.Header().Get("Access-Control-Max-Age"))
|
|
}
|
|
}
|