diff --git a/config.yaml b/config.yaml index 9025674..28beb39 100644 --- a/config.yaml +++ b/config.yaml @@ -1,2 +1,8 @@ port: 8080 upstream_url: "https://api.z.ai/api/anthropic" + +models: + - id: "glm-4.7" + owned_by: "zhipu" + - id: "glm-4.6" + owned_by: "zhipu" diff --git a/converter.go b/converter.go index f82ed72..d7b9bcf 100644 --- a/converter.go +++ b/converter.go @@ -53,8 +53,18 @@ func extractSystemMessage(messages []Message) (string, []Message) { for _, msg := range messages { if msg.Role == "system" { - if content, ok := msg.Content.(string); ok { + switch content := msg.Content.(type) { + case string: systemParts = append(systemParts, content) + case []interface{}: + // Extract text from content array + for _, part := range content { + if partMap, ok := part.(map[string]interface{}); ok { + if text, ok := partMap["text"].(string); ok { + systemParts = append(systemParts, text) + } + } + } } } else { rest = append(rest, msg) diff --git a/handler.go b/handler.go index 63e4b15..4324de8 100644 --- a/handler.go +++ b/handler.go @@ -11,10 +11,16 @@ import ( "time" ) +type ModelConfig struct { + ID string `yaml:"id"` + OwnedBy string `yaml:"owned_by"` +} + // Config holds the application configuration type Config struct { - Port int `yaml:"port"` - UpstreamURL string `yaml:"upstream_url"` + Port int `yaml:"port"` + UpstreamURL string `yaml:"upstream_url"` + Models []ModelConfig `yaml:"models"` } var config *Config @@ -56,9 +62,14 @@ func handleModels(w http.ResponseWriter, r *http.Request) { return } - models := []map[string]interface{}{ - {"id": "glm-4.7", "object": "model", "created": 1234567890, "owned_by": "zhipu"}, - {"id": "glm-4.6", "object": "model", "created": 1234567890, "owned_by": "zhipu"}, + models := make([]map[string]interface{}, len(config.Models)) + for i, model := range config.Models { + models[i] = map[string]interface{}{ + "id": model.ID, + "object": "model", + "created": 1234567890, + "owned_by": model.OwnedBy, + } } response := map[string]interface{}{ @@ -117,7 +128,7 @@ func handleChatCompletions(w http.ResponseWriter, r *http.Request) { anthropicReq.Stream = false reqBody, _ := json.Marshal(anthropicReq) - log.Printf("[debug] Sending to upstream %s, model=%s, body=%s", config.UpstreamURL, req.Model, string(reqBody)) + log.Printf("[debug] Sending to upstream %s, model=%s, body_size=%d bytes", config.UpstreamURL, req.Model, len(reqBody)) // Non-streaming request to upstream upstreamResp, err := callUpstream(anthropicReq, apiKey, sessionID) @@ -129,7 +140,7 @@ func handleChatCompletions(w http.ResponseWriter, r *http.Request) { if upstreamResp.StatusCode != http.StatusOK { respBody, _ := io.ReadAll(upstreamResp.Body) - log.Printf("[debug] Upstream error status %d: %s", upstreamResp.StatusCode, string(respBody)) + log.Printf("[debug] Upstream error status %d, body_size=%d bytes", upstreamResp.StatusCode, len(respBody)) writeError(w, http.StatusBadGateway, fmt.Sprintf("Upstream returned error: %s", string(respBody)), "upstream_error", fmt.Sprintf("status_%d", upstreamResp.StatusCode)) return } @@ -140,7 +151,7 @@ func handleChatCompletions(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusBadGateway, "Failed to read upstream response", "upstream_error", "body_read_error") return } - log.Printf("[debug] Upstream response: %s", string(respBody)) + log.Printf("[debug] Upstream response received, body_size=%d bytes", len(respBody)) var anthropicResp AnthropicResponse if err := json.Unmarshal(respBody, &anthropicResp); err != nil { diff --git a/main.go b/main.go index 69f422c..8149af5 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,9 @@ import ( "log" "net/http" "os" + "os/signal" + "syscall" + "time" "github.com/google/uuid" "gopkg.in/yaml.v3" @@ -38,10 +41,28 @@ func main() { }) addr := fmt.Sprintf(":%d", config.Port) - log.Printf("Starting proxx on %s, upstream: %s", addr, config.UpstreamURL) - if err := http.ListenAndServe(addr, nil); err != nil { - log.Fatalf("Server failed: %v", err) + + server := &http.Server{Addr: addr} + + go func() { + log.Printf("Starting proxx on %s, upstream: %s", addr, config.UpstreamURL) + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatalf("Server failed: %v", err) + } + }() + + // Wait for interrupt signal + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + <-sigChan + + log.Println("Shutting down server...") + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := server.Shutdown(ctx); err != nil { + log.Printf("Server shutdown error: %v", err) } + log.Println("Server stopped") } // contextKey is a custom type for context keys diff --git a/proxx b/proxx index 62b35b4..93268a2 100755 Binary files a/proxx and b/proxx differ