package main

import (
	"context"
	"errors"
	"fmt"
	"log"
	"net/http"
	"os"
	"os/signal"
	"syscall"
	"time"

	"github.com/google/uuid"
	"gopkg.in/yaml.v3"
)

// main loads config.yaml, registers the HTTP routes, serves on config.Port,
// and shuts the server down gracefully on SIGINT or SIGTERM.
func main() {
	// Load config.yaml from the current working directory.
	data, err := os.ReadFile("config.yaml")
	if err != nil {
		log.Fatalf("Failed to read config.yaml: %v", err)
	}
	cfg := &Config{}
	if err := yaml.Unmarshal(data, cfg); err != nil {
		log.Fatalf("Failed to parse config.yaml: %v", err)
	}
	// Publish the parsed configuration to the package-level config variable
	// (declared elsewhere in this package).
	config = cfg

	// A single session ID is generated at startup and attached to every
	// chat-completions request for the lifetime of the process.
	sessionID := uuid.New().String()

	// Register routes on the default mux (server.Handler is left nil below).
	http.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
		r = r.WithContext(contextWithSessionIDInContext(r.Context(), sessionID))
		handleChatCompletions(w, r)
	})
	http.HandleFunc("/v1/models", handleModels)
	http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		fmt.Fprint(w, "OK")
	})

	addr := fmt.Sprintf(":%d", config.Port)
	server := &http.Server{
		Addr: addr,
		// Bound how long a client may take to send request headers, so slow
		// or stalled clients cannot pin connections indefinitely. ReadTimeout
		// and WriteTimeout are deliberately left unset: proxied completion
		// responses may be long-lived (e.g. streamed) and must not be cut off.
		ReadHeaderTimeout: 10 * time.Second,
	}

	// Register for shutdown signals before the server starts. Otherwise a
	// signal arriving during startup would terminate the process without a
	// graceful shutdown.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

	// Serve in the background and report unexpected failures back to main,
	// so the process exits from a single place.
	serveErr := make(chan error, 1)
	go func() {
		log.Printf("Starting proxx on %s, upstream: %s", addr, config.UpstreamURL)
		if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
			serveErr <- err
		}
	}()

	// Block until either the server fails or a shutdown signal arrives.
	select {
	case err := <-serveErr:
		log.Fatalf("Server failed: %v", err)
	case <-sigChan:
	}

	log.Println("Shutting down server...")
	// Give in-flight requests up to 10 seconds to complete.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	if err := server.Shutdown(ctx); err != nil {
		log.Printf("Server shutdown error: %v", err)
	}
	log.Println("Server stopped")
}

// contextKey is a private type for context keys, so values stored by this
// package cannot collide with keys defined by other packages.
type contextKey string

// sessionIDKey is the context key under which the session ID is stored.
const sessionIDKey contextKey = "sessionID"

// contextWithSessionID returns a new root context carrying sessionID.
// It derives from context.Background(); context.WithValue panics on a nil
// parent, so a nil parent must never be passed.
func contextWithSessionID(sessionID string) context.Context {
	return contextWithSessionIDInContext(context.Background(), sessionID)
}

// contextWithSessionIDInContext returns a child of parent that carries
// sessionID under sessionIDKey. parent must be non-nil.
func contextWithSessionIDInContext(parent context.Context, sessionID string) context.Context {
	return context.WithValue(parent, sessionIDKey, sessionID)
}