diff --git a/auth.go b/auth.go index f3bc7e1..43949f2 100644 --- a/auth.go +++ b/auth.go @@ -2,30 +2,14 @@ package main import ( "errors" - "log" "strings" - - "github.com/gin-gonic/gin" ) -func handleAuth(c *gin.Context) error { - var err error - - authorization := c.Request.Header.Get("Authorization") - if !strings.HasPrefix(authorization, "Bearer") { - err = errors.New("authorization header should start with 'Bearer'") - return err - } - - authorization = strings.Trim(authorization[len("Bearer"):], " ") - log.Println("Received authorization", authorization) - - for _, auth := range strings.Split(config.Authorization, ",") { - if authorization != strings.Trim(auth, " ") { - err = errors.New("wrong authorization header") - return err +func checkAuth(authorization string, config string) error { + for _, auth := range strings.Split(config, ",") { + if authorization == strings.Trim(auth, " ") { + return nil } } - - return nil + return errors.New("wrong authorization header") } diff --git a/main.go b/main.go index ea23e0a..52b4484 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "encoding/json" "flag" "fmt" "log" @@ -99,7 +100,7 @@ func main() { }) engine.POST("/v1/*any", func(c *gin.Context) { - hostname, err := os.Hostname() + hostname, _ := os.Hostname() if config.Hostname != "" { hostname = config.Hostname } @@ -112,16 +113,14 @@ func main() { Model: c.Request.URL.Path, } - // check authorization header - if !*noauth { - err := handleAuth(c) - if err != nil { - c.Header("Content-Type", "application/json") - sendCORSHeaders(c) - c.AbortWithError(403, err) - return - } + authorization := c.Request.Header.Get("Authorization") + if strings.HasPrefix(authorization, "Bearer") { + authorization = strings.Trim(authorization[len("Bearer"):], " ") + } else { + authorization = strings.Trim(authorization, " ") + log.Println("[auth] Warning: authorization header should start with 'Bearer'") } + log.Println("Received authorization '" + authorization + "'") for index, upstream := range config.Upstreams { if upstream.SK == "" { @@ -132,6 +131,20 @@ func main() { shouldResponse := index == len(config.Upstreams)-1 + // check authorization header + if !*noauth && !upstream.Noauth { + if checkAuth(authorization, upstream.Authorization) != nil { + if shouldResponse { + c.Header("Content-Type", "application/json") + sendCORSHeaders(c) + c.AbortWithError(403, fmt.Errorf("[processRequest.begin]: wrong authorization header")) + } + log.Println("[auth] Authorization header check failed for", upstream.SK, authorization) + continue + } + log.Println("[auth] Authorization header check pass for", upstream.SK, authorization) + } + if len(config.Upstreams) == 1 { upstream.Timeout = 120 } @@ -160,15 +173,16 @@ func main() { } log.Println("[final]: Record result:", record.Status, record.Response) - record.ElapsedTime = time.Now().Sub(record.CreatedAt) + record.ElapsedTime = time.Since(record.CreatedAt) // async record request go func() { + // encoder headers to record.Headers in json string + headers, _ := json.Marshal(c.Request.Header) + record.Headers = string(headers) + // turncate request if too long - if len(record.Body) > 1024*128 { - log.Println("[async.record]: Warning: Truncate request body") - record.Body = record.Body[:1024*128] - } + log.Println("[async.record]: body length:", len(record.Body)) if db.Create(&record).Error != nil { log.Println("[async.record]: Error to save record:", record) } diff --git a/process.go b/process.go index 2a55c46..48d9dfa 100644 --- a/process.go +++ b/process.go @@ -59,11 +59,11 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s } // record chat message from user - record.Body = string(inBody) requestBody, requestBodyOK := ParseRequestBody(inBody) // record if parse success if requestBodyOK == nil && requestBody.Model != "" { record.Model = requestBody.Model + record.Body = string(inBody) } // check allow list @@ -125,7 +125,9 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s out.URL.Scheme = remote.Scheme out.URL.Host = remote.Host - out.Header = http.Header{} + if !upstream.KeepHeader { + out.Header = http.Header{} + } out.Header.Set("Host", remote.Host) if upstream.SK == "asis" { out.Header.Set("Authorization", c.Request.Header.Get("Authorization")) @@ -138,7 +140,7 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s var contentType string proxy.ModifyResponse = func(r *http.Response) error { haveResponse = true - record.ResponseTime = time.Now().Sub(record.CreatedAt) + record.ResponseTime = time.Since(record.CreatedAt) record.Status = r.StatusCode // handle reverse proxy cors header if upstream do not set that @@ -163,7 +165,7 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s errRet := errors.New("[proxy.modifyResponse]: failed to read response from upstream " + err.Error()) return errRet } - errRet := errors.New(fmt.Sprintf("[error]: openai-api-route upstream return '%s' with '%s'", r.Status, string(body))) + errRet := fmt.Errorf("[error]: openai-api-route upstream return '%s' with '%s'", r.Status, string(body)) log.Println(errRet) record.Status = r.StatusCode return errRet @@ -175,7 +177,7 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s } proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { haveResponse = true - record.ResponseTime = time.Now().Sub(record.CreatedAt) + record.ResponseTime = time.Since(record.CreatedAt) log.Println("[proxy.errorHandler]", err, upstream.SK, upstream.Endpoint, errCtx) errCtx = append(errCtx, err) @@ -224,13 +226,6 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s log.Println(record.Response) } else { - // record request body - if strings.HasPrefix(c.Request.Header.Get("Content-Type"), "application/json") { - record.Body = string(inBody) - } else { - record.Body = "binary data" - } - // record response // stream mode if strings.HasPrefix(contentType, "text/event-stream") { @@ -256,21 +251,17 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s } else if strings.HasPrefix(contentType, "text") { record.Response = string(resp) } else if strings.HasPrefix(contentType, "application/json") { + // fallback record response + if len(resp) < 1024*128 { + record.Response = string(resp) + } var fetchResp FetchModeResponse err := json.Unmarshal(resp, &fetchResp) - if err != nil { - log.Println("[proxy.parseJSONError]: error parsing fetch response:", err) - return nil + if err == nil { + if len(fetchResp.Choices) > 0 { + record.Response = fetchResp.Choices[0].Message.Content + } } - if !strings.HasPrefix(fetchResp.Model, "gpt-") { - log.Println("[proxy.record]: Not GPT model, skip recording response:", fetchResp.Model) - return nil - } - if len(fetchResp.Choices) == 0 { - log.Println("[proxy.record]: Error: fetch response choice length is 0") - return nil - } - record.Response = fetchResp.Choices[0].Message.Content } else { log.Println("[proxy.record]: Unknown content type", contentType) } diff --git a/record.go b/record.go index 58b053f..1b45c11 100644 --- a/record.go +++ b/record.go @@ -19,6 +19,7 @@ type Record struct { Status int Authorization string // the autorization header send by client UserAgent string + Headers string } type StreamModeChunk struct { diff --git a/replicate.go b/replicate.go index a09f318..0671f24 100644 --- a/replicate.go +++ b/replicate.go @@ -35,7 +35,6 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record } // record request body - record.Body = string(inBody) // parse request body inRequest := &OpenAIChatRequest{} @@ -357,7 +356,7 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record FinishReason: "stop", }) - record.Body = strings.Join(result.Output, "") + record.Response = strings.Join(result.Output, "") record.Status = 200 // gin return diff --git a/structure.go b/structure.go index e7ac3c3..67dbeef 100644 --- a/structure.go +++ b/structure.go @@ -17,13 +17,16 @@ type Config struct { Upstreams []OPENAI_UPSTREAM `yaml:"upstreams"` } type OPENAI_UPSTREAM struct { - SK string `yaml:"sk"` - Endpoint string `yaml:"endpoint"` - Timeout int64 `yaml:"timeout"` - Allow []string `yaml:"allow"` - Deny []string `yaml:"deny"` - Type string `yaml:"type"` - URL *url.URL + SK string `yaml:"sk"` + Endpoint string `yaml:"endpoint"` + Timeout int64 `yaml:"timeout"` + Allow []string `yaml:"allow"` + Deny []string `yaml:"deny"` + Type string `yaml:"type"` + KeepHeader bool `yaml:"keep_header"` + Authorization string `yaml:"authorization"` + Noauth bool `yaml:"noauth"` + URL *url.URL } func readConfig(filepath string) Config { @@ -68,6 +71,10 @@ func readConfig(filepath string) Config { if (config.Upstreams[i].Type != "openai") && (config.Upstreams[i].Type != "replicate") { log.Fatalf("Unsupported upstream type '%s'", config.Upstreams[i].Type) } + // apply authorization from global config if not set + if config.Upstreams[i].Authorization == "" && !config.Upstreams[i].Noauth { + config.Upstreams[i].Authorization = config.Authorization + } } return config