add allow & deny model list, fix cors on error
This commit is contained in:
@@ -11,6 +11,10 @@ dbaddr: ./db.sqlite
|
|||||||
upstreams:
|
upstreams:
|
||||||
- sk: "secret_key_1"
|
- sk: "secret_key_1"
|
||||||
endpoint: "https://api.openai.com/v2"
|
endpoint: "https://api.openai.com/v2"
|
||||||
|
allow: ["gpt-3.5-trubo"] # 可选的模型白名单
|
||||||
- sk: "secret_key_2"
|
- sk: "secret_key_2"
|
||||||
endpoint: "https://api.openai.com/v1"
|
endpoint: "https://api.openai.com/v1"
|
||||||
timeout: 30
|
timeout: 30
|
||||||
|
allow: ["gpt-3.5-trubo"] # 可选的模型白名单
|
||||||
|
deny: ["gpt-4"] # 可选的模型黑名单
|
||||||
|
# 若白名单和黑名单同时设置,先判断白名单,再判断黑名单
|
||||||
6
cors.go
6
cors.go
@@ -20,3 +20,9 @@ func corsMiddleware() gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func sendCORSHeaders(c *gin.Context) {
|
||||||
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
|
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
|
||||||
|
c.Header("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
|
||||||
|
}
|
||||||
|
|||||||
10
main.go
10
main.go
@@ -110,20 +110,13 @@ func main() {
|
|||||||
Authorization: c.Request.Header.Get("Authorization"),
|
Authorization: c.Request.Header.Get("Authorization"),
|
||||||
UserAgent: c.Request.Header.Get("User-Agent"),
|
UserAgent: c.Request.Header.Get("User-Agent"),
|
||||||
}
|
}
|
||||||
/*
|
|
||||||
defer func() {
|
|
||||||
if err := recover(); err != nil {
|
|
||||||
log.Println("Error:", err)
|
|
||||||
c.AbortWithError(500, fmt.Errorf("%s", err))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
*/
|
|
||||||
|
|
||||||
// check authorization header
|
// check authorization header
|
||||||
if !*noauth {
|
if !*noauth {
|
||||||
err := handleAuth(c)
|
err := handleAuth(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Header("Content-Type", "application/json")
|
c.Header("Content-Type", "application/json")
|
||||||
|
sendCORSHeaders(c)
|
||||||
c.AbortWithError(403, err)
|
c.AbortWithError(403, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -131,6 +124,7 @@ func main() {
|
|||||||
|
|
||||||
for index, upstream := range config.Upstreams {
|
for index, upstream := range config.Upstreams {
|
||||||
if upstream.Endpoint == "" || upstream.SK == "" {
|
if upstream.Endpoint == "" || upstream.SK == "" {
|
||||||
|
sendCORSHeaders(c)
|
||||||
c.AbortWithError(500, fmt.Errorf("[processRequest.begin]: invaild upstream '%s' '%s'", upstream.SK, upstream.Endpoint))
|
c.AbortWithError(500, fmt.Errorf("[processRequest.begin]: invaild upstream '%s' '%s'", upstream.SK, upstream.Endpoint))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
43
process.go
43
process.go
@@ -18,7 +18,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
|
func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
|
||||||
var errCtx error
|
var errCtx []error
|
||||||
|
|
||||||
record.UpstreamEndpoint = upstream.Endpoint
|
record.UpstreamEndpoint = upstream.Endpoint
|
||||||
record.UpstreamSK = upstream.SK
|
record.UpstreamSK = upstream.SK
|
||||||
@@ -52,7 +52,7 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
|
|||||||
// read request body
|
// read request body
|
||||||
inBody, err = io.ReadAll(in.Body)
|
inBody, err = io.ReadAll(in.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errCtx = errors.New("[proxy.rewrite]: reverse proxy middleware failed to read request body " + err.Error())
|
errCtx = append(errCtx, errors.New("[proxy.rewrite]: reverse proxy middleware failed to read request body "+err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,6 +62,29 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
|
|||||||
// record if parse success
|
// record if parse success
|
||||||
if requestBodyOK == nil {
|
if requestBodyOK == nil {
|
||||||
record.Model = requestBody.Model
|
record.Model = requestBody.Model
|
||||||
|
// check allow list
|
||||||
|
if len(upstream.Allow) > 0 {
|
||||||
|
isAllow := false
|
||||||
|
for _, allow := range upstream.Allow {
|
||||||
|
if allow == requestBody.Model {
|
||||||
|
isAllow = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !isAllow {
|
||||||
|
errCtx = append(errCtx, errors.New("[proxy.rewrite]: model not allowed"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// check block list
|
||||||
|
if len(upstream.Deny) > 0 {
|
||||||
|
for _, deny := range upstream.Deny {
|
||||||
|
if deny == requestBody.Model {
|
||||||
|
errCtx = append(errCtx, errors.New("[proxy.rewrite]: model denied"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// set timeout, default is 60 second
|
// set timeout, default is 60 second
|
||||||
@@ -82,10 +105,12 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
|
|||||||
time.Sleep(timeout)
|
time.Sleep(timeout)
|
||||||
if !haveResponse {
|
if !haveResponse {
|
||||||
log.Println("[proxy.timeout]: Timeout upstream", upstream.Endpoint, timeout)
|
log.Println("[proxy.timeout]: Timeout upstream", upstream.Endpoint, timeout)
|
||||||
errCtx = errors.New("timeout")
|
errTimeout := errors.New("[proxy.timeout]: Timeout upstream")
|
||||||
|
errCtx = append(errCtx, errTimeout)
|
||||||
if shouldResponse {
|
if shouldResponse {
|
||||||
c.Header("Content-Type", "application/json")
|
c.Header("Content-Type", "application/json")
|
||||||
c.AbortWithError(502, errCtx)
|
sendCORSHeaders(c)
|
||||||
|
c.AbortWithError(502, errTimeout)
|
||||||
}
|
}
|
||||||
cancel()
|
cancel()
|
||||||
}
|
}
|
||||||
@@ -150,13 +175,16 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
|
|||||||
record.ResponseTime = time.Now().Sub(record.CreatedAt)
|
record.ResponseTime = time.Now().Sub(record.CreatedAt)
|
||||||
log.Println("[proxy.errorHandler]", err, upstream.SK, upstream.Endpoint)
|
log.Println("[proxy.errorHandler]", err, upstream.SK, upstream.Endpoint)
|
||||||
|
|
||||||
errCtx = err
|
errCtx = append(errCtx, err)
|
||||||
|
|
||||||
// abort to error handle
|
// abort to error handle
|
||||||
if shouldResponse {
|
if shouldResponse {
|
||||||
c.Header("Content-Type", "application/json")
|
c.Header("Content-Type", "application/json")
|
||||||
|
sendCORSHeaders(c)
|
||||||
|
for _, err := range errCtx {
|
||||||
c.AbortWithError(502, err)
|
c.AbortWithError(502, err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
log.Println("[proxy.errorHandler]: response is", r.Response)
|
log.Println("[proxy.errorHandler]: response is", r.Response)
|
||||||
|
|
||||||
@@ -180,10 +208,11 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
|
|||||||
}
|
}
|
||||||
|
|
||||||
// return context error
|
// return context error
|
||||||
if errCtx != nil {
|
if len(errCtx) > 0 {
|
||||||
|
log.Println("[proxy.serve]: error from ServeHTTP:", errCtx)
|
||||||
// fix inrequest body
|
// fix inrequest body
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(inBody))
|
c.Request.Body = io.NopCloser(bytes.NewReader(inBody))
|
||||||
return errCtx
|
return errCtx[len(errCtx)-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := io.ReadAll(io.NopCloser(&buf))
|
resp, err := io.ReadAll(io.NopCloser(&buf))
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ type OPENAI_UPSTREAM struct {
|
|||||||
SK string `yaml:"sk"`
|
SK string `yaml:"sk"`
|
||||||
Endpoint string `yaml:"endpoint"`
|
Endpoint string `yaml:"endpoint"`
|
||||||
Timeout int64 `yaml:"timeout"`
|
Timeout int64 `yaml:"timeout"`
|
||||||
|
Allow []string `yaml:"allow"`
|
||||||
|
Deny []string `yaml:"deny"`
|
||||||
URL *url.URL
|
URL *url.URL
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user