add allow & deny model list, fix cors on error

This commit is contained in:
2024-01-16 17:11:25 +08:00
parent 8672899a58
commit b1ab9c3d7b
5 changed files with 55 additions and 20 deletions

View File

@@ -11,6 +11,10 @@ dbaddr: ./db.sqlite
upstreams: upstreams:
- sk: "secret_key_1" - sk: "secret_key_1"
endpoint: "https://api.openai.com/v2" endpoint: "https://api.openai.com/v2"
allow: ["gpt-3.5-trubo"] # 可选的模型白名单
- sk: "secret_key_2" - sk: "secret_key_2"
endpoint: "https://api.openai.com/v1" endpoint: "https://api.openai.com/v1"
timeout: 30 timeout: 30
allow: ["gpt-3.5-trubo"] # 可选的模型白名单
deny: ["gpt-4"] # 可选的模型黑名单
# 若白名单和黑名单同时设置,先判断白名单,再判断黑名单

View File

@@ -20,3 +20,9 @@ func corsMiddleware() gin.HandlerFunc {
} }
} }
} }
func sendCORSHeaders(c *gin.Context) {
c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
c.Header("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
}

10
main.go
View File

@@ -110,20 +110,13 @@ func main() {
Authorization: c.Request.Header.Get("Authorization"), Authorization: c.Request.Header.Get("Authorization"),
UserAgent: c.Request.Header.Get("User-Agent"), UserAgent: c.Request.Header.Get("User-Agent"),
} }
/*
defer func() {
if err := recover(); err != nil {
log.Println("Error:", err)
c.AbortWithError(500, fmt.Errorf("%s", err))
}
}()
*/
// check authorization header // check authorization header
if !*noauth { if !*noauth {
err := handleAuth(c) err := handleAuth(c)
if err != nil { if err != nil {
c.Header("Content-Type", "application/json") c.Header("Content-Type", "application/json")
sendCORSHeaders(c)
c.AbortWithError(403, err) c.AbortWithError(403, err)
return return
} }
@@ -131,6 +124,7 @@ func main() {
for index, upstream := range config.Upstreams { for index, upstream := range config.Upstreams {
if upstream.Endpoint == "" || upstream.SK == "" { if upstream.Endpoint == "" || upstream.SK == "" {
sendCORSHeaders(c)
c.AbortWithError(500, fmt.Errorf("[processRequest.begin]: invaild upstream '%s' '%s'", upstream.SK, upstream.Endpoint)) c.AbortWithError(500, fmt.Errorf("[processRequest.begin]: invaild upstream '%s' '%s'", upstream.SK, upstream.Endpoint))
continue continue
} }

View File

@@ -18,7 +18,7 @@ import (
) )
func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error { func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
var errCtx error var errCtx []error
record.UpstreamEndpoint = upstream.Endpoint record.UpstreamEndpoint = upstream.Endpoint
record.UpstreamSK = upstream.SK record.UpstreamSK = upstream.SK
@@ -52,7 +52,7 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
// read request body // read request body
inBody, err = io.ReadAll(in.Body) inBody, err = io.ReadAll(in.Body)
if err != nil { if err != nil {
errCtx = errors.New("[proxy.rewrite]: reverse proxy middleware failed to read request body " + err.Error()) errCtx = append(errCtx, errors.New("[proxy.rewrite]: reverse proxy middleware failed to read request body "+err.Error()))
return return
} }
@@ -62,6 +62,29 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
// record if parse success // record if parse success
if requestBodyOK == nil { if requestBodyOK == nil {
record.Model = requestBody.Model record.Model = requestBody.Model
// check allow list
if len(upstream.Allow) > 0 {
isAllow := false
for _, allow := range upstream.Allow {
if allow == requestBody.Model {
isAllow = true
break
}
}
if !isAllow {
errCtx = append(errCtx, errors.New("[proxy.rewrite]: model not allowed"))
return
}
}
// check block list
if len(upstream.Deny) > 0 {
for _, deny := range upstream.Deny {
if deny == requestBody.Model {
errCtx = append(errCtx, errors.New("[proxy.rewrite]: model denied"))
return
}
}
}
} }
// set timeout, default is 60 second // set timeout, default is 60 second
@@ -82,10 +105,12 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
time.Sleep(timeout) time.Sleep(timeout)
if !haveResponse { if !haveResponse {
log.Println("[proxy.timeout]: Timeout upstream", upstream.Endpoint, timeout) log.Println("[proxy.timeout]: Timeout upstream", upstream.Endpoint, timeout)
errCtx = errors.New("timeout") errTimeout := errors.New("[proxy.timeout]: Timeout upstream")
errCtx = append(errCtx, errTimeout)
if shouldResponse { if shouldResponse {
c.Header("Content-Type", "application/json") c.Header("Content-Type", "application/json")
c.AbortWithError(502, errCtx) sendCORSHeaders(c)
c.AbortWithError(502, errTimeout)
} }
cancel() cancel()
} }
@@ -150,13 +175,16 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
record.ResponseTime = time.Now().Sub(record.CreatedAt) record.ResponseTime = time.Now().Sub(record.CreatedAt)
log.Println("[proxy.errorHandler]", err, upstream.SK, upstream.Endpoint) log.Println("[proxy.errorHandler]", err, upstream.SK, upstream.Endpoint)
errCtx = err errCtx = append(errCtx, err)
// abort to error handle // abort to error handle
if shouldResponse { if shouldResponse {
c.Header("Content-Type", "application/json") c.Header("Content-Type", "application/json")
sendCORSHeaders(c)
for _, err := range errCtx {
c.AbortWithError(502, err) c.AbortWithError(502, err)
} }
}
log.Println("[proxy.errorHandler]: response is", r.Response) log.Println("[proxy.errorHandler]: response is", r.Response)
@@ -180,10 +208,11 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
} }
// return context error // return context error
if errCtx != nil { if len(errCtx) > 0 {
log.Println("[proxy.serve]: error from ServeHTTP:", errCtx)
// fix inrequest body // fix inrequest body
c.Request.Body = io.NopCloser(bytes.NewReader(inBody)) c.Request.Body = io.NopCloser(bytes.NewReader(inBody))
return errCtx return errCtx[len(errCtx)-1]
} }
resp, err := io.ReadAll(io.NopCloser(&buf)) resp, err := io.ReadAll(io.NopCloser(&buf))

View File

@@ -20,6 +20,8 @@ type OPENAI_UPSTREAM struct {
SK string `yaml:"sk"` SK string `yaml:"sk"`
Endpoint string `yaml:"endpoint"` Endpoint string `yaml:"endpoint"`
Timeout int64 `yaml:"timeout"` Timeout int64 `yaml:"timeout"`
Allow []string `yaml:"allow"`
Deny []string `yaml:"deny"`
URL *url.URL URL *url.URL
} }