add allow & deny model list, fix cors on error
This commit is contained in:
45
process.go
45
process.go
@@ -18,7 +18,7 @@ import (
|
||||
)
|
||||
|
||||
func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
|
||||
var errCtx error
|
||||
var errCtx []error
|
||||
|
||||
record.UpstreamEndpoint = upstream.Endpoint
|
||||
record.UpstreamSK = upstream.SK
|
||||
@@ -52,7 +52,7 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
|
||||
// read request body
|
||||
inBody, err = io.ReadAll(in.Body)
|
||||
if err != nil {
|
||||
errCtx = errors.New("[proxy.rewrite]: reverse proxy middleware failed to read request body " + err.Error())
|
||||
errCtx = append(errCtx, errors.New("[proxy.rewrite]: reverse proxy middleware failed to read request body "+err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -62,6 +62,29 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
|
||||
// record if parse success
|
||||
if requestBodyOK == nil {
|
||||
record.Model = requestBody.Model
|
||||
// check allow list
|
||||
if len(upstream.Allow) > 0 {
|
||||
isAllow := false
|
||||
for _, allow := range upstream.Allow {
|
||||
if allow == requestBody.Model {
|
||||
isAllow = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isAllow {
|
||||
errCtx = append(errCtx, errors.New("[proxy.rewrite]: model not allowed"))
|
||||
return
|
||||
}
|
||||
}
|
||||
// check block list
|
||||
if len(upstream.Deny) > 0 {
|
||||
for _, deny := range upstream.Deny {
|
||||
if deny == requestBody.Model {
|
||||
errCtx = append(errCtx, errors.New("[proxy.rewrite]: model denied"))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// set timeout, default is 60 second
|
||||
@@ -82,10 +105,12 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
|
||||
time.Sleep(timeout)
|
||||
if !haveResponse {
|
||||
log.Println("[proxy.timeout]: Timeout upstream", upstream.Endpoint, timeout)
|
||||
errCtx = errors.New("timeout")
|
||||
errTimeout := errors.New("[proxy.timeout]: Timeout upstream")
|
||||
errCtx = append(errCtx, errTimeout)
|
||||
if shouldResponse {
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.AbortWithError(502, errCtx)
|
||||
sendCORSHeaders(c)
|
||||
c.AbortWithError(502, errTimeout)
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
@@ -150,12 +175,15 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
|
||||
record.ResponseTime = time.Now().Sub(record.CreatedAt)
|
||||
log.Println("[proxy.errorHandler]", err, upstream.SK, upstream.Endpoint)
|
||||
|
||||
errCtx = err
|
||||
errCtx = append(errCtx, err)
|
||||
|
||||
// abort to error handle
|
||||
if shouldResponse {
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.AbortWithError(502, err)
|
||||
sendCORSHeaders(c)
|
||||
for _, err := range errCtx {
|
||||
c.AbortWithError(502, err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Println("[proxy.errorHandler]: response is", r.Response)
|
||||
@@ -180,10 +208,11 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
|
||||
}
|
||||
|
||||
// return context error
|
||||
if errCtx != nil {
|
||||
if len(errCtx) > 0 {
|
||||
log.Println("[proxy.serve]: error from ServeHTTP:", errCtx)
|
||||
// fix inrequest body
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(inBody))
|
||||
return errCtx
|
||||
return errCtx[len(errCtx)-1]
|
||||
}
|
||||
|
||||
resp, err := io.ReadAll(io.NopCloser(&buf))
|
||||
|
||||
Reference in New Issue
Block a user