loop through upstreams
This commit is contained in:
51
main.go
51
main.go
@@ -33,6 +33,11 @@ func main() {
|
|||||||
log.Fatal("Failed to connect to database")
|
log.Fatal("Failed to connect to database")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// load all upstreams
|
||||||
|
upstreams := make([]OPENAI_UPSTREAM, 0)
|
||||||
|
db.Find(&upstreams)
|
||||||
|
log.Println("Load upstreams number:", len(upstreams))
|
||||||
|
|
||||||
err = initconfig(db)
|
err = initconfig(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
@@ -114,48 +119,30 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// get load balance policy
|
for index, upstream := range upstreams {
|
||||||
policy := ConfigKV{Value: "main"}
|
if upstream.Endpoint == "" || upstream.SK == "" {
|
||||||
db.Take(&policy, "key = ?", "policy")
|
c.AbortWithError(500, fmt.Errorf("invaild upstream '%s' '%s'", upstream.SK, upstream.Endpoint))
|
||||||
log.Println("policy is", policy.Value)
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
upstream := OPENAI_UPSTREAM{}
|
shouldResponse := index == len(upstreams)-1
|
||||||
|
|
||||||
// choose openai upstream
|
err = processRequest(c, &upstream, &record, shouldResponse)
|
||||||
switch policy.Value {
|
if err != nil {
|
||||||
case "main":
|
log.Println("Error from upstream, should retry", upstream.SK, err)
|
||||||
db.Order("failed_count, success_count desc").First(&upstream)
|
continue
|
||||||
case "random":
|
}
|
||||||
// randomly select one upstream
|
|
||||||
db.Order("random()").Take(&upstream)
|
break
|
||||||
case "random_available":
|
|
||||||
// randomly select one non-failed upstream
|
|
||||||
db.Where("failed_count = ?", 0).Order("random()").Take(&upstream)
|
|
||||||
case "round_robin":
|
|
||||||
// iterates each upstream
|
|
||||||
db.Order("last_call_success_time").First(&upstream)
|
|
||||||
case "round_robin_available":
|
|
||||||
db.Where("failed_count = ?", 0).Order("last_call_success_time").First(&upstream)
|
|
||||||
default:
|
|
||||||
c.AbortWithError(500, fmt.Errorf("unknown load balance policy '%s'", policy.Value))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// do check
|
|
||||||
log.Println("upstream is", upstream.SK, upstream.Endpoint)
|
|
||||||
if upstream.Endpoint == "" || upstream.SK == "" {
|
|
||||||
c.AbortWithError(500, fmt.Errorf("invaild upstream from '%s' policy", policy.Value))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
processRequest(c, &upstream, &record)
|
|
||||||
|
|
||||||
log.Println("Record result:", record.Status, record.Response)
|
log.Println("Record result:", record.Status, record.Response)
|
||||||
record.ElapsedTime = time.Now().Sub(record.CreatedAt)
|
record.ElapsedTime = time.Now().Sub(record.CreatedAt)
|
||||||
if db.Create(&record).Error != nil {
|
if db.Create(&record).Error != nil {
|
||||||
log.Println("Error to save record:", record)
|
log.Println("Error to save record:", record)
|
||||||
}
|
}
|
||||||
if record.Status != 200 && record.Response != "context canceled" {
|
if record.Status != 200 && record.Response != "context canceled" {
|
||||||
errMessage := fmt.Sprintf("IP: %s request %s error %d with %s", record.IP, upstream.Endpoint, record.Status, record.Response)
|
errMessage := fmt.Sprintf("IP: %s request all upstreams error %d with %s", record.IP, record.Status, record.Response)
|
||||||
go sendFeishuMessage(errMessage)
|
go sendFeishuMessage(errMessage)
|
||||||
go sendMatrixMessage(errMessage)
|
go sendMatrixMessage(errMessage)
|
||||||
}
|
}
|
||||||
|
|||||||
35
process.go
35
process.go
@@ -15,9 +15,13 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record) error {
|
func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
|
||||||
|
var errCtx error
|
||||||
|
|
||||||
record.UpstreamID = upstream.ID
|
record.UpstreamID = upstream.ID
|
||||||
|
record.Response = ""
|
||||||
|
record.Authorization = upstream.SK
|
||||||
|
// [TODO] record request body
|
||||||
|
|
||||||
// reverse proxy
|
// reverse proxy
|
||||||
remote, err := url.Parse(upstream.Endpoint)
|
remote, err := url.Parse(upstream.Endpoint)
|
||||||
@@ -27,21 +31,22 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record) e
|
|||||||
}
|
}
|
||||||
proxy := httputil.NewSingleHostReverseProxy(remote)
|
proxy := httputil.NewSingleHostReverseProxy(remote)
|
||||||
proxy.Director = nil
|
proxy.Director = nil
|
||||||
|
var inBody []byte
|
||||||
proxy.Rewrite = func(proxyRequest *httputil.ProxyRequest) {
|
proxy.Rewrite = func(proxyRequest *httputil.ProxyRequest) {
|
||||||
in := proxyRequest.In
|
in := proxyRequest.In
|
||||||
out := proxyRequest.Out
|
out := proxyRequest.Out
|
||||||
|
|
||||||
// read request body
|
// read request body
|
||||||
body, err := io.ReadAll(in.Body)
|
inBody, err = io.ReadAll(in.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithError(502, errors.New("reverse proxy middleware failed to read request body "+err.Error()))
|
c.AbortWithError(502, errors.New("reverse proxy middleware failed to read request body "+err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// record chat message from user
|
// record chat message from user
|
||||||
record.Body = string(body)
|
record.Body = string(inBody)
|
||||||
|
|
||||||
out.Body = io.NopCloser(bytes.NewReader(body))
|
out.Body = io.NopCloser(bytes.NewReader(inBody))
|
||||||
|
|
||||||
out.Host = remote.Host
|
out.Host = remote.Host
|
||||||
out.URL.Scheme = remote.Scheme
|
out.URL.Scheme = remote.Scheme
|
||||||
@@ -60,6 +65,10 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record) e
|
|||||||
var contentType string
|
var contentType string
|
||||||
proxy.ModifyResponse = func(r *http.Response) error {
|
proxy.ModifyResponse = func(r *http.Response) error {
|
||||||
record.Status = r.StatusCode
|
record.Status = r.StatusCode
|
||||||
|
if !shouldResponse && r.StatusCode != 200 {
|
||||||
|
log.Println("upstream return not 200 and should not response", r.StatusCode)
|
||||||
|
return errors.New("upstream return not 200 and should not response")
|
||||||
|
}
|
||||||
r.Header.Del("Access-Control-Allow-Origin")
|
r.Header.Del("Access-Control-Allow-Origin")
|
||||||
r.Header.Del("Access-Control-Allow-Methods")
|
r.Header.Del("Access-Control-Allow-Methods")
|
||||||
r.Header.Del("Access-Control-Allow-Headers")
|
r.Header.Del("Access-Control-Allow-Headers")
|
||||||
@@ -87,6 +96,8 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record) e
|
|||||||
|
|
||||||
log.Println("debug", r)
|
log.Println("debug", r)
|
||||||
|
|
||||||
|
errCtx = err
|
||||||
|
|
||||||
// abort to error handle
|
// abort to error handle
|
||||||
c.AbortWithError(502, err)
|
c.AbortWithError(502, err)
|
||||||
|
|
||||||
@@ -104,14 +115,14 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record) e
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func() {
|
proxy.ServeHTTP(c.Writer, c.Request)
|
||||||
defer func() {
|
|
||||||
if err := recover(); err != nil {
|
// return context error
|
||||||
log.Println("Panic recover :", err)
|
if errCtx != nil {
|
||||||
}
|
// fix inrequest body
|
||||||
}()
|
c.Request.Body = io.NopCloser(bytes.NewReader(inBody))
|
||||||
proxy.ServeHTTP(c.Writer, c.Request)
|
return errCtx
|
||||||
}()
|
}
|
||||||
|
|
||||||
resp, err := io.ReadAll(io.NopCloser(&buf))
|
resp, err := io.ReadAll(io.NopCloser(&buf))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user