loop through upstreams

This commit is contained in:
2023-10-31 18:10:01 +08:00
parent acc153ddca
commit 2c75c392a8
2 changed files with 42 additions and 44 deletions

51
main.go
View File

@@ -33,6 +33,11 @@ func main() {
log.Fatal("Failed to connect to database") log.Fatal("Failed to connect to database")
} }
// load all upstreams
upstreams := make([]OPENAI_UPSTREAM, 0)
db.Find(&upstreams)
log.Println("Load upstreams number:", len(upstreams))
err = initconfig(db) err = initconfig(db)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
@@ -114,48 +119,30 @@ func main() {
} }
} }
// get load balance policy for index, upstream := range upstreams {
policy := ConfigKV{Value: "main"} if upstream.Endpoint == "" || upstream.SK == "" {
db.Take(&policy, "key = ?", "policy") c.AbortWithError(500, fmt.Errorf("invaild upstream '%s' '%s'", upstream.SK, upstream.Endpoint))
log.Println("policy is", policy.Value) continue
}
upstream := OPENAI_UPSTREAM{} shouldResponse := index == len(upstreams)-1
// choose openai upstream err = processRequest(c, &upstream, &record, shouldResponse)
switch policy.Value { if err != nil {
case "main": log.Println("Error from upstream, should retry", upstream.SK, err)
db.Order("failed_count, success_count desc").First(&upstream) continue
case "random": }
// randomly select one upstream
db.Order("random()").Take(&upstream) break
case "random_available":
// randomly select one non-failed upstream
db.Where("failed_count = ?", 0).Order("random()").Take(&upstream)
case "round_robin":
// iterates each upstream
db.Order("last_call_success_time").First(&upstream)
case "round_robin_available":
db.Where("failed_count = ?", 0).Order("last_call_success_time").First(&upstream)
default:
c.AbortWithError(500, fmt.Errorf("unknown load balance policy '%s'", policy.Value))
} }
// do check
log.Println("upstream is", upstream.SK, upstream.Endpoint)
if upstream.Endpoint == "" || upstream.SK == "" {
c.AbortWithError(500, fmt.Errorf("invaild upstream from '%s' policy", policy.Value))
return
}
processRequest(c, &upstream, &record)
log.Println("Record result:", record.Status, record.Response) log.Println("Record result:", record.Status, record.Response)
record.ElapsedTime = time.Now().Sub(record.CreatedAt) record.ElapsedTime = time.Now().Sub(record.CreatedAt)
if db.Create(&record).Error != nil { if db.Create(&record).Error != nil {
log.Println("Error to save record:", record) log.Println("Error to save record:", record)
} }
if record.Status != 200 && record.Response != "context canceled" { if record.Status != 200 && record.Response != "context canceled" {
errMessage := fmt.Sprintf("IP: %s request %s error %d with %s", record.IP, upstream.Endpoint, record.Status, record.Response) errMessage := fmt.Sprintf("IP: %s request all upstreams error %d with %s", record.IP, record.Status, record.Response)
go sendFeishuMessage(errMessage) go sendFeishuMessage(errMessage)
go sendMatrixMessage(errMessage) go sendMatrixMessage(errMessage)
} }

View File

@@ -15,9 +15,13 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record) error { func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
var errCtx error
record.UpstreamID = upstream.ID record.UpstreamID = upstream.ID
record.Response = ""
record.Authorization = upstream.SK
// [TODO] record request body
// reverse proxy // reverse proxy
remote, err := url.Parse(upstream.Endpoint) remote, err := url.Parse(upstream.Endpoint)
@@ -27,21 +31,22 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record) e
} }
proxy := httputil.NewSingleHostReverseProxy(remote) proxy := httputil.NewSingleHostReverseProxy(remote)
proxy.Director = nil proxy.Director = nil
var inBody []byte
proxy.Rewrite = func(proxyRequest *httputil.ProxyRequest) { proxy.Rewrite = func(proxyRequest *httputil.ProxyRequest) {
in := proxyRequest.In in := proxyRequest.In
out := proxyRequest.Out out := proxyRequest.Out
// read request body // read request body
body, err := io.ReadAll(in.Body) inBody, err = io.ReadAll(in.Body)
if err != nil { if err != nil {
c.AbortWithError(502, errors.New("reverse proxy middleware failed to read request body "+err.Error())) c.AbortWithError(502, errors.New("reverse proxy middleware failed to read request body "+err.Error()))
return return
} }
// record chat message from user // record chat message from user
record.Body = string(body) record.Body = string(inBody)
out.Body = io.NopCloser(bytes.NewReader(body)) out.Body = io.NopCloser(bytes.NewReader(inBody))
out.Host = remote.Host out.Host = remote.Host
out.URL.Scheme = remote.Scheme out.URL.Scheme = remote.Scheme
@@ -60,6 +65,10 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record) e
var contentType string var contentType string
proxy.ModifyResponse = func(r *http.Response) error { proxy.ModifyResponse = func(r *http.Response) error {
record.Status = r.StatusCode record.Status = r.StatusCode
if !shouldResponse && r.StatusCode != 200 {
log.Println("upstream return not 200 and should not response", r.StatusCode)
return errors.New("upstream return not 200 and should not response")
}
r.Header.Del("Access-Control-Allow-Origin") r.Header.Del("Access-Control-Allow-Origin")
r.Header.Del("Access-Control-Allow-Methods") r.Header.Del("Access-Control-Allow-Methods")
r.Header.Del("Access-Control-Allow-Headers") r.Header.Del("Access-Control-Allow-Headers")
@@ -87,6 +96,8 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record) e
log.Println("debug", r) log.Println("debug", r)
errCtx = err
// abort to error handle // abort to error handle
c.AbortWithError(502, err) c.AbortWithError(502, err)
@@ -104,14 +115,14 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record) e
} }
func() { proxy.ServeHTTP(c.Writer, c.Request)
defer func() {
if err := recover(); err != nil { // return context error
log.Println("Panic recover :", err) if errCtx != nil {
} // fix inrequest body
}() c.Request.Body = io.NopCloser(bytes.NewReader(inBody))
proxy.ServeHTTP(c.Writer, c.Request) return errCtx
}() }
resp, err := io.ReadAll(io.NopCloser(&buf)) resp, err := io.ReadAll(io.NopCloser(&buf))
if err != nil { if err != nil {