From b1ab9c3d7b0866daa99e5b21e7d558e4df64f79c Mon Sep 17 00:00:00 2001 From: heimoshuiyu Date: Tue, 16 Jan 2024 17:11:25 +0800 Subject: [PATCH] add allow & deny model list, fix cors on error --- config.sample.yaml | 6 +++++- cors.go | 6 ++++++ main.go | 10 ++-------- process.go | 45 +++++++++++++++++++++++++++++++++++++-------- structure.go | 8 +++++--- 5 files changed, 55 insertions(+), 20 deletions(-) diff --git a/config.sample.yaml b/config.sample.yaml index 4971081..fac7a98 100644 --- a/config.sample.yaml +++ b/config.sample.yaml @@ -11,6 +11,10 @@ dbaddr: ./db.sqlite upstreams: - sk: "secret_key_1" endpoint: "https://api.openai.com/v2" + allow: ["gpt-3.5-trubo"] # 可选的模型白名单 - sk: "secret_key_2" endpoint: "https://api.openai.com/v1" - timeout: 30 \ No newline at end of file + timeout: 30 + allow: ["gpt-3.5-trubo"] # 可选的模型白名单 + deny: ["gpt-4"] # 可选的模型黑名单 + # 若白名单和黑名单同时设置,先判断白名单,再判断黑名单 \ No newline at end of file diff --git a/cors.go b/cors.go index 5a68e02..a863df6 100644 --- a/cors.go +++ b/cors.go @@ -20,3 +20,9 @@ func corsMiddleware() gin.HandlerFunc { } } } + +func sendCORSHeaders(c *gin.Context) { + c.Header("Access-Control-Allow-Origin", "*") + c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH") + c.Header("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type") +} diff --git a/main.go b/main.go index 6e0a0f4..d43e850 100644 --- a/main.go +++ b/main.go @@ -110,20 +110,13 @@ func main() { Authorization: c.Request.Header.Get("Authorization"), UserAgent: c.Request.Header.Get("User-Agent"), } - /* - defer func() { - if err := recover(); err != nil { - log.Println("Error:", err) - c.AbortWithError(500, fmt.Errorf("%s", err)) - } - }() - */ // check authorization header if !*noauth { err := handleAuth(c) if err != nil { c.Header("Content-Type", "application/json") + sendCORSHeaders(c) c.AbortWithError(403, err) return } @@ -131,6 +124,7 @@ func main() { for index, upstream := range config.Upstreams { if upstream.Endpoint == "" || upstream.SK == "" { + sendCORSHeaders(c) c.AbortWithError(500, fmt.Errorf("[processRequest.begin]: invaild upstream '%s' '%s'", upstream.SK, upstream.Endpoint)) continue } diff --git a/process.go b/process.go index 3f6af48..739311b 100644 --- a/process.go +++ b/process.go @@ -18,7 +18,7 @@ import ( ) func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error { - var errCtx error + var errCtx []error record.UpstreamEndpoint = upstream.Endpoint record.UpstreamSK = upstream.SK @@ -52,7 +52,7 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s // read request body inBody, err = io.ReadAll(in.Body) if err != nil { - errCtx = errors.New("[proxy.rewrite]: reverse proxy middleware failed to read request body " + err.Error()) + errCtx = append(errCtx, errors.New("[proxy.rewrite]: reverse proxy middleware failed to read request body "+err.Error())) return } @@ -62,6 +62,29 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s // record if parse success if requestBodyOK == nil { record.Model = requestBody.Model + // check allow list + if len(upstream.Allow) > 0 { + isAllow := false + for _, allow := range upstream.Allow { + if allow == requestBody.Model { + isAllow = true + break + } + } + if !isAllow { + errCtx = append(errCtx, errors.New("[proxy.rewrite]: model not allowed")) + return + } + } + // check block list + if len(upstream.Deny) > 0 { + for _, deny := range upstream.Deny { + if deny == requestBody.Model { + errCtx = append(errCtx, errors.New("[proxy.rewrite]: model denied")) + return + } + } + } } // set timeout, default is 60 second @@ -82,10 +105,12 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s time.Sleep(timeout) if !haveResponse { log.Println("[proxy.timeout]: Timeout upstream", upstream.Endpoint, timeout) - errCtx = errors.New("timeout") + errTimeout := errors.New("[proxy.timeout]: Timeout upstream") + errCtx = append(errCtx, errTimeout) if shouldResponse { c.Header("Content-Type", "application/json") - c.AbortWithError(502, errCtx) + sendCORSHeaders(c) + c.AbortWithError(502, errTimeout) } cancel() } @@ -150,12 +175,15 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s record.ResponseTime = time.Now().Sub(record.CreatedAt) log.Println("[proxy.errorHandler]", err, upstream.SK, upstream.Endpoint) - errCtx = err + errCtx = append(errCtx, err) // abort to error handle if shouldResponse { c.Header("Content-Type", "application/json") - c.AbortWithError(502, err) + sendCORSHeaders(c) + for _, err := range errCtx { + c.AbortWithError(502, err) + } } log.Println("[proxy.errorHandler]: response is", r.Response) @@ -180,10 +208,11 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s } // return context error - if errCtx != nil { + if len(errCtx) > 0 { + log.Println("[proxy.serve]: error from ServeHTTP:", errCtx) // fix inrequest body c.Request.Body = io.NopCloser(bytes.NewReader(inBody)) - return errCtx + return errCtx[len(errCtx)-1] } resp, err := io.ReadAll(io.NopCloser(&buf)) diff --git a/structure.go b/structure.go index 3d5b649..d1ffe6d 100644 --- a/structure.go +++ b/structure.go @@ -17,9 +17,11 @@ type Config struct { Upstreams []OPENAI_UPSTREAM `yaml:"upstreams"` } type OPENAI_UPSTREAM struct { - SK string `yaml:"sk"` - Endpoint string `yaml:"endpoint"` - Timeout int64 `yaml:"timeout"` + SK string `yaml:"sk"` + Endpoint string `yaml:"endpoint"` + Timeout int64 `yaml:"timeout"` + Allow []string `yaml:"allow"` + Deny []string `yaml:"deny"` URL *url.URL }