- Create package-level httpClient with 300s timeout
- Reuse client instead of creating new one per request
- Prevents resource exhaustion under load
- Reduces connection overhead
248 lines
7.5 KiB
Go
248 lines
7.5 KiB
Go
package main
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"math/rand"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// Config describes the application's settings. The struct tags name the YAML
// key each field corresponds to.
type Config struct {
	// Port is read from the "port" key — presumably the port the server
	// listens on; confirm against the startup code.
	Port int `yaml:"port"`

	// UpstreamURL is read from the "upstream_url" key. callUpstream appends
	// "/v1/messages" to it to form the request URL.
	UpstreamURL string `yaml:"upstream_url"`
}
|
|
|
|
// config is the package-level application configuration. Code in this file
// reads config.UpstreamURL from it. Nothing visible here assigns it —
// presumably it is populated at startup before any handler runs (TODO confirm).
var config *Config
|
|
|
|
// httpClient is the one HTTP client used for every upstream call. Sharing a
// single client lets connections be reused across requests. Its timeout caps
// each upstream exchange at five minutes.
var httpClient = &http.Client{
	Timeout: 5 * time.Minute,
}
|
|
|
|
// blockedHeaders is the set of request headers that must never be passed on
// to the upstream API, because they can expose internal URLs, session state,
// credentials, or client network details. Keys are written in canonical MIME
// header form.
var blockedHeaders = map[string]bool{
	// Credentials and session state.
	"Authorization": true, // the bearer token is sent upstream as x-api-key instead
	"Cookie":        true, // session cookies stay on this side of the proxy

	// Internal locations.
	"Referer":          true, // may contain internal URLs
	"X-Forwarded-Host": true, // may contain internal hostnames

	// Client network identity.
	"X-Forwarded-For": true, // client IP
	"X-Real-Ip":       true, // client IP
}
|
|
|
|
// ClaudeCodeHeaders builds the header set sent with every upstream request so
// that it looks like traffic from the claude-code CLI.
//
// apiKey is placed in the x-api-key header and sessionID in the
// X-Claude-Code-Session-Id header; the remaining values are fixed. A new map
// is returned on every call.
func ClaudeCodeHeaders(apiKey, sessionID string) map[string]string {
	headers := make(map[string]string, 7)

	// Caller-specific values.
	headers["x-api-key"] = apiKey
	headers["X-Claude-Code-Session-Id"] = sessionID

	// Fixed client identity.
	headers["User-Agent"] = "claude-cli/1.0.18 (pro, cli)"
	headers["x-app"] = "cli"

	// Fixed API negotiation values.
	headers["anthropic-version"] = "2023-06-01"
	headers["anthropic-beta"] = "claude-code-20250219,interleaved-thinking-2025-05-14,prompt-caching-scope-2026-01-05,context-management-2025-06-27,redact-thinking-2026-02-12"
	headers["content-type"] = "application/json"

	return headers
}
|
|
|
|
func handleModels(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodGet {
|
|
writeError(w, http.StatusMethodNotAllowed, "Method not allowed", "invalid_request_error", "method_not_allowed")
|
|
return
|
|
}
|
|
|
|
models := []map[string]interface{}{
|
|
{"id": "glm-4.7", "object": "model", "created": 1234567890, "owned_by": "zhipu"},
|
|
{"id": "glm-4.6", "object": "model", "created": 1234567890, "owned_by": "zhipu"},
|
|
}
|
|
|
|
response := map[string]interface{}{
|
|
"object": "list",
|
|
"data": models,
|
|
}
|
|
|
|
w.Header().Set("content-type", "application/json")
|
|
json.NewEncoder(w).Encode(response)
|
|
}
|
|
|
|
func handleChatCompletions(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
writeError(w, http.StatusMethodNotAllowed, "Method not allowed", "invalid_request_error", "method_not_allowed")
|
|
return
|
|
}
|
|
|
|
// Extract Bearer token
|
|
authHeader := r.Header.Get("Authorization")
|
|
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
|
|
writeError(w, http.StatusUnauthorized, "Missing or invalid Authorization header", "authentication_error", "missing_authorization")
|
|
return
|
|
}
|
|
apiKey := strings.TrimPrefix(authHeader, "Bearer ")
|
|
|
|
// Read body
|
|
body, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
writeError(w, http.StatusBadRequest, "Failed to read request body", "invalid_request_error", "body_read_error")
|
|
return
|
|
}
|
|
|
|
// Decode request
|
|
var req ChatCompletionRequest
|
|
if err := json.Unmarshal(body, &req); err != nil {
|
|
writeError(w, http.StatusBadRequest, "Invalid JSON in request body", "invalid_request_error", "json_decode_error")
|
|
return
|
|
}
|
|
|
|
// Get session ID from context (set by main)
|
|
sessionID := r.Context().Value(sessionIDKey).(string)
|
|
|
|
// Convert to Anthropic format — always non-streaming to upstream
|
|
// (ZAI's streaming returns empty for GLM models)
|
|
anthropicReq := ConvertOpenAIRequest(&req)
|
|
anthropicReq.Stream = false
|
|
|
|
reqBody, _ := json.Marshal(anthropicReq)
|
|
log.Printf("[debug] Sending to upstream %s, model=%s, body=%s", config.UpstreamURL, req.Model, string(reqBody))
|
|
|
|
// Non-streaming request to upstream
|
|
upstreamResp, err := callUpstream(anthropicReq, apiKey, sessionID)
|
|
if err != nil {
|
|
writeError(w, http.StatusBadGateway, fmt.Sprintf("Upstream request failed: %v", err), "upstream_error", "proxy_error")
|
|
return
|
|
}
|
|
defer upstreamResp.Body.Close()
|
|
|
|
if upstreamResp.StatusCode != http.StatusOK {
|
|
respBody, _ := io.ReadAll(upstreamResp.Body)
|
|
log.Printf("[debug] Upstream error status %d: %s", upstreamResp.StatusCode, string(respBody))
|
|
writeError(w, http.StatusBadGateway, fmt.Sprintf("Upstream returned error: %s", string(respBody)), "upstream_error", fmt.Sprintf("status_%d", upstreamResp.StatusCode))
|
|
return
|
|
}
|
|
|
|
// Read the full Anthropic response
|
|
respBody, err := io.ReadAll(upstreamResp.Body)
|
|
if err != nil {
|
|
writeError(w, http.StatusBadGateway, "Failed to read upstream response", "upstream_error", "body_read_error")
|
|
return
|
|
}
|
|
log.Printf("[debug] Upstream response: %s", string(respBody))
|
|
|
|
var anthropicResp AnthropicResponse
|
|
if err := json.Unmarshal(respBody, &anthropicResp); err != nil {
|
|
writeError(w, http.StatusBadGateway, "Failed to parse upstream response", "upstream_error", "json_decode_error")
|
|
return
|
|
}
|
|
|
|
isStream := req.Stream != nil && *req.Stream
|
|
|
|
if isStream {
|
|
// Convert the non-streaming response to SSE chunks for the client
|
|
w.Header().Set("content-type", "text/event-stream")
|
|
w.Header().Set("cache-control", "no-cache")
|
|
w.Header().Set("connection", "keep-alive")
|
|
|
|
created := time.Now().Unix()
|
|
chunkID := "chatcmpl-" + randomString(8)
|
|
|
|
// Extract text content
|
|
var textContent string
|
|
for _, block := range anthropicResp.Content {
|
|
if block.Type == "text" {
|
|
textContent += block.Text
|
|
}
|
|
}
|
|
|
|
// Send text as chunks (simulate streaming)
|
|
if textContent != "" {
|
|
chunk := StreamChunk{
|
|
ID: chunkID,
|
|
Object: "chat.completion.chunk",
|
|
Created: created,
|
|
Model: req.Model,
|
|
Choices: []StreamChoice{
|
|
{
|
|
Index: 0,
|
|
Delta: Delta{
|
|
Role: "assistant",
|
|
Content: textContent,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
data, _ := json.Marshal(chunk)
|
|
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
w.(http.Flusher).Flush()
|
|
}
|
|
|
|
// Send finish chunk
|
|
finishChunk := StreamChunk{
|
|
ID: chunkID,
|
|
Object: "chat.completion.chunk",
|
|
Created: created,
|
|
Model: req.Model,
|
|
Choices: []StreamChoice{
|
|
{
|
|
Index: 0,
|
|
Delta: Delta{},
|
|
FinishReason: mapStopReason(anthropicResp.StopReason),
|
|
},
|
|
},
|
|
}
|
|
data, _ := json.Marshal(finishChunk)
|
|
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
w.(http.Flusher).Flush()
|
|
|
|
fmt.Fprintf(w, "data: [DONE]\n\n")
|
|
w.(http.Flusher).Flush()
|
|
} else {
|
|
// Non-streaming: convert directly
|
|
openAIResp := ConvertAnthropicResponse(&anthropicResp, req.Model)
|
|
w.Header().Set("content-type", "application/json")
|
|
json.NewEncoder(w).Encode(openAIResp)
|
|
}
|
|
}
|
|
|
|
func callUpstream(req *AnthropicRequest, apiKey, sessionID string) (*http.Response, error) {
|
|
bodyBytes, err := json.Marshal(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
|
}
|
|
|
|
upstreamURL := config.UpstreamURL + "/v1/messages"
|
|
httpReq, err := http.NewRequest(http.MethodPost, upstreamURL, strings.NewReader(string(bodyBytes)))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
headers := ClaudeCodeHeaders(apiKey, sessionID)
|
|
for k, v := range headers {
|
|
httpReq.Header.Set(k, v)
|
|
}
|
|
|
|
return httpClient.Do(httpReq)
|
|
}
|
|
|
|
// writeError sends a JSON error response with HTTP status code.
//
// The body has the shape
//
//	{"error":{"code":errCode,"message":message,"type":errType}}
//
// followed by a newline. The content-type header is set before the status
// line is written.
func writeError(w http.ResponseWriter, code int, message, errType, errCode string) {
	// Fields are declared in alphabetical key order so the encoded output
	// matches the key order of an encoded map.
	type errorDetail struct {
		Code    string `json:"code"`
		Message string `json:"message"`
		Type    string `json:"type"`
	}
	type errorEnvelope struct {
		Error errorDetail `json:"error"`
	}

	w.Header().Set("content-type", "application/json")
	w.WriteHeader(code)

	envelope := errorEnvelope{
		Error: errorDetail{
			Code:    errCode,
			Message: message,
			Type:    errType,
		},
	}
	json.NewEncoder(w).Encode(envelope)
}
|
|
|
|
// randomString returns a string of n characters drawn from [a-zA-Z0-9]. It
// returns the empty string when n is zero or negative.
//
// Characters come from math/rand's shared top-level generator rather than a
// generator created and time-seeded on every call: two calls that observed
// the same clock reading would otherwise return the same string. The
// top-level functions are safe for concurrent use.
//
// NOTE(review): the shared generator is randomly seeded at startup from
// Go 1.20 onward; on older toolchains it must be seeded once at program
// start — confirm the module's Go version.
//
// The output is not cryptographically secure and must not be used for
// secrets.
func randomString(n int) string {
	const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	if n <= 0 {
		return ""
	}
	b := make([]byte, n)
	for i := range b {
		b[i] = letters[rand.Intn(len(letters))]
	}
	return string(b)
}
|