diff --git a/main.go b/main.go index dcf1d74..07a07b8 100644 --- a/main.go +++ b/main.go @@ -33,6 +33,11 @@ func main() { log.Fatal("Failed to connect to database") } + // load all upstreams + upstreams := make([]OPENAI_UPSTREAM, 0) + db.Find(&upstreams) + log.Println("Load upstreams number:", len(upstreams)) + err = initconfig(db) if err != nil { log.Fatal(err) @@ -114,48 +119,30 @@ func main() { } } - // get load balance policy - policy := ConfigKV{Value: "main"} - db.Take(&policy, "key = ?", "policy") - log.Println("policy is", policy.Value) + for index, upstream := range upstreams { + if upstream.Endpoint == "" || upstream.SK == "" { + c.AbortWithError(500, fmt.Errorf("invaild upstream '%s' '%s'", upstream.SK, upstream.Endpoint)) + continue + } - upstream := OPENAI_UPSTREAM{} + shouldResponse := index == len(upstreams)-1 - // choose openai upstream - switch policy.Value { - case "main": - db.Order("failed_count, success_count desc").First(&upstream) - case "random": - // randomly select one upstream - db.Order("random()").Take(&upstream) - case "random_available": - // randomly select one non-failed upstream - db.Where("failed_count = ?", 0).Order("random()").Take(&upstream) - case "round_robin": - // iterates each upstream - db.Order("last_call_success_time").First(&upstream) - case "round_robin_available": - db.Where("failed_count = ?", 0).Order("last_call_success_time").First(&upstream) - default: - c.AbortWithError(500, fmt.Errorf("unknown load balance policy '%s'", policy.Value)) + err = processRequest(c, &upstream, &record, shouldResponse) + if err != nil { + log.Println("Error from upstream, should retry", upstream.SK, err) + continue + } + + break } - // do check - log.Println("upstream is", upstream.SK, upstream.Endpoint) - if upstream.Endpoint == "" || upstream.SK == "" { - c.AbortWithError(500, fmt.Errorf("invaild upstream from '%s' policy", policy.Value)) - return - } - - processRequest(c, &upstream, &record) - log.Println("Record result:", record.Status, record.Response) record.ElapsedTime = time.Now().Sub(record.CreatedAt) if db.Create(&record).Error != nil { log.Println("Error to save record:", record) } if record.Status != 200 && record.Response != "context canceled" { - errMessage := fmt.Sprintf("IP: %s request %s error %d with %s", record.IP, upstream.Endpoint, record.Status, record.Response) + errMessage := fmt.Sprintf("IP: %s request all upstreams error %d with %s", record.IP, record.Status, record.Response) go sendFeishuMessage(errMessage) go sendMatrixMessage(errMessage) } diff --git a/process.go b/process.go index 35f4866..26a5cb7 100644 --- a/process.go +++ b/process.go @@ -15,9 +15,13 @@ import ( "github.com/gin-gonic/gin" ) -func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record) error { +func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error { + var errCtx error record.UpstreamID = upstream.ID + record.Response = "" + record.Authorization = upstream.SK + // [TODO] record request body // reverse proxy remote, err := url.Parse(upstream.Endpoint) @@ -27,21 +31,22 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record) e } proxy := httputil.NewSingleHostReverseProxy(remote) proxy.Director = nil + var inBody []byte proxy.Rewrite = func(proxyRequest *httputil.ProxyRequest) { in := proxyRequest.In out := proxyRequest.Out // read request body - body, err := io.ReadAll(in.Body) + inBody, err = io.ReadAll(in.Body) if err != nil { c.AbortWithError(502, errors.New("reverse proxy middleware failed to read request body "+err.Error())) return } // record chat message from user - record.Body = string(body) + record.Body = string(inBody) - out.Body = io.NopCloser(bytes.NewReader(body)) + out.Body = io.NopCloser(bytes.NewReader(inBody)) out.Host = remote.Host out.URL.Scheme = remote.Scheme @@ -60,6 +65,10 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record) e var contentType string proxy.ModifyResponse = func(r *http.Response) error { record.Status = r.StatusCode + if !shouldResponse && r.StatusCode != 200 { + log.Println("upstream return not 200 and should not response", r.StatusCode) + return errors.New("upstream return not 200 and should not response") + } r.Header.Del("Access-Control-Allow-Origin") r.Header.Del("Access-Control-Allow-Methods") r.Header.Del("Access-Control-Allow-Headers") @@ -87,6 +96,8 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record) e log.Println("debug", r) + errCtx = err + // abort to error handle c.AbortWithError(502, err) @@ -104,14 +115,14 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record) e } - func() { - defer func() { - if err := recover(); err != nil { - log.Println("Panic recover :", err) - } - }() - proxy.ServeHTTP(c.Writer, c.Request) - }() + proxy.ServeHTTP(c.Writer, c.Request) + + // return context error + if errCtx != nil { + // fix inrequest body + c.Request.Body = io.NopCloser(bytes.NewReader(inBody)) + return errCtx + } resp, err := io.ReadAll(io.NopCloser(&buf)) if err != nil {