Compare commits

...

9 Commits

Author SHA1 Message Date
d9a42842b2 1m timeout for non stream request 2023-11-01 15:55:54 +08:00
5a78c61e5f timeout for out going request 2023-11-01 15:25:59 +08:00
eced585361 Merge remote-tracking branch 'shanghai/master' 2023-10-31 18:40:33 +08:00
98a15052a2 fix: should response 2023-10-31 18:35:29 +08:00
7572ecf19b fix: retry with 502 2023-10-31 18:34:58 +08:00
9f2bb46233 update print 2023-10-31 18:34:41 +08:00
d8948d065a add vscode debug 2023-10-31 18:10:09 +08:00
2c75c392a8 loop through upstreams 2023-10-31 18:10:01 +08:00
acc153ddca 拆分 process 2023-10-31 10:13:20 +08:00
5 changed files with 279 additions and 183 deletions

15
.vscode/launch.json vendored Normal file
View File

@@ -0,0 +1,15 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Launch Package",
"type": "go",
"request": "launch",
"mode": "auto",
"program": "${fileDirname}"
}
]
}

199
main.go
View File

@@ -1,16 +1,9 @@
package main
import (
"bytes"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"log"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"
@@ -41,6 +34,11 @@ func main() {
log.Fatal("Failed to connect to database")
}
// load all upstreams
upstreams := make([]OPENAI_UPSTREAM, 0)
db.Find(&upstreams)
log.Println("Load upstreams number:", len(upstreams))
err = initconfig(db)
if err != nil {
log.Fatal(err)
@@ -127,186 +125,21 @@ func main() {
}
}
// get load balance policy
policy := ConfigKV{Value: "main"}
db.Take(&policy, "key = ?", "policy")
log.Println("policy is", policy.Value)
for index, upstream := range upstreams {
if upstream.Endpoint == "" || upstream.SK == "" {
c.AbortWithError(500, fmt.Errorf("invaild upstream '%s' '%s'", upstream.SK, upstream.Endpoint))
continue
}
upstream := OPENAI_UPSTREAM{}
shouldResponse := index == len(upstreams)-1
// choose openai upstream
switch policy.Value {
case "main":
db.Order("failed_count, success_count desc").First(&upstream)
case "random":
// randomly select one upstream
db.Order("random()").Take(&upstream)
case "random_available":
// randomly select one non-failed upstream
db.Where("failed_count = ?", 0).Order("random()").Take(&upstream)
case "round_robin":
// iterates each upstream
db.Order("last_call_success_time").First(&upstream)
case "round_robin_available":
db.Where("failed_count = ?", 0).Order("last_call_success_time").First(&upstream)
default:
c.AbortWithError(500, fmt.Errorf("unknown load balance policy '%s'", policy.Value))
}
// do check
log.Println("upstream is", upstream.SK, upstream.Endpoint)
if upstream.Endpoint == "" || upstream.SK == "" {
c.AbortWithError(500, fmt.Errorf("invaild upstream from '%s' policy", policy.Value))
return
}
record.UpstreamID = upstream.ID
// reverse proxy
remote, err := url.Parse(upstream.Endpoint)
if err != nil {
c.AbortWithError(500, errors.New("can't parse reverse proxy remote URL"))
return
}
proxy := httputil.NewSingleHostReverseProxy(remote)
proxy.Director = nil
proxy.Rewrite = func(proxyRequest *httputil.ProxyRequest) {
in := proxyRequest.In
out := proxyRequest.Out
// read request body
body, err := io.ReadAll(in.Body)
err = processRequest(c, &upstream, &record, shouldResponse)
if err != nil {
c.AbortWithError(502, errors.New("reverse proxy middleware failed to read request body "+err.Error()))
return
log.Println("Error from upstream", upstream.Endpoint, "should retry", err)
continue
}
// record chat message from user
record.Body = string(body)
out.Body = io.NopCloser(bytes.NewReader(body))
out.Host = remote.Host
out.URL.Scheme = remote.Scheme
out.URL.Host = remote.Host
out.URL.Path = in.URL.Path
out.Header = http.Header{}
out.Header.Set("Host", remote.Host)
if upstream.SK == "asis" {
out.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
} else {
out.Header.Set("Authorization", "Bearer "+upstream.SK)
}
out.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
}
var buf bytes.Buffer
var contentType string
proxy.ModifyResponse = func(r *http.Response) error {
record.Status = r.StatusCode
r.Header.Del("Access-Control-Allow-Origin")
r.Header.Del("Access-Control-Allow-Methods")
r.Header.Del("Access-Control-Allow-Headers")
r.Header.Set("Access-Control-Allow-Origin", "*")
r.Header.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
r.Header.Set("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
if r.StatusCode != 200 {
body, err := io.ReadAll(r.Body)
if err != nil {
record.Response = "failed to read response from upstream " + err.Error()
return errors.New(record.Response)
}
record.Response = fmt.Sprintf("openai-api-route upstream return '%s' with '%s'", r.Status, string(body))
record.Status = r.StatusCode
return fmt.Errorf(record.Response)
}
// count success
r.Body = io.NopCloser(io.TeeReader(r.Body, &buf))
contentType = r.Header.Get("content-type")
return nil
}
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
log.Println("Error", err, upstream.SK, upstream.Endpoint)
log.Println("debug", r)
// abort to error handle
c.AbortWithError(502, err)
log.Println("response is", r.Response)
if record.Status == 0 {
record.Status = 502
}
if record.Response == "" {
record.Response = err.Error()
}
if r.Response != nil {
record.Status = r.Response.StatusCode
}
}
func() {
defer func() {
if err := recover(); err != nil {
log.Println("Panic recover :", err)
}
}()
proxy.ServeHTTP(c.Writer, c.Request)
}()
resp, err := io.ReadAll(io.NopCloser(&buf))
if err != nil {
record.Response = "failed to read response from upstream " + err.Error()
log.Println(record.Response)
} else {
// record response
// stream mode
if strings.HasPrefix(contentType, "text/event-stream") {
for _, line := range strings.Split(string(resp), "\n") {
chunk := StreamModeChunk{}
line = strings.TrimPrefix(line, "data:")
line = strings.TrimSpace(line)
if line == "" {
continue
}
err := json.Unmarshal([]byte(line), &chunk)
if err != nil {
log.Println(err)
continue
}
if len(chunk.Choices) == 0 {
continue
}
record.Response += chunk.Choices[0].Delta.Content
}
} else if strings.HasPrefix(contentType, "application/json") {
var fetchResp FetchModeResponse
err := json.Unmarshal(resp, &fetchResp)
if err != nil {
log.Println("Error parsing fetch response:", err)
return
}
if !strings.HasPrefix(fetchResp.Model, "gpt-") {
log.Println("Not GPT model, skip recording response:", fetchResp.Model)
return
}
if len(fetchResp.Choices) == 0 {
log.Println("Error: fetch response choice length is 0")
return
}
record.Response = fetchResp.Choices[0].Message.Content
} else {
log.Println("Unknown content type", contentType)
}
}
if len(record.Body) > 1024*512 {
record.Body = ""
break
}
log.Println("Record result:", record.Status, record.Response)
@@ -315,7 +148,7 @@ func main() {
log.Println("Error to save record:", record)
}
if record.Status != 200 && record.Response != "context canceled" {
errMessage := fmt.Sprintf("IP: %s request %s error %d with %s", record.IP, upstream.Endpoint, record.Status, record.Response)
errMessage := fmt.Sprintf("IP: %s request all upstreams error %d with %s", record.IP, record.Status, record.Response)
go sendFeishuMessage(errMessage)
go sendMatrixMessage(errMessage)
}

215
process.go Normal file
View File

@@ -0,0 +1,215 @@
package main
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/net/context"
)
func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
var errCtx error
record.UpstreamID = upstream.ID
record.Response = ""
record.Authorization = upstream.SK
// [TODO] record request body
// reverse proxy
remote, err := url.Parse(upstream.Endpoint)
if err != nil {
c.AbortWithError(500, errors.New("can't parse reverse proxy remote URL"))
return err
}
haveResponse := false
proxy := httputil.NewSingleHostReverseProxy(remote)
proxy.Director = nil
var inBody []byte
proxy.Rewrite = func(proxyRequest *httputil.ProxyRequest) {
in := proxyRequest.In
ctx, cancel := context.WithCancel(context.Background())
proxyRequest.Out = proxyRequest.Out.WithContext(ctx)
out := proxyRequest.Out
// read request body
inBody, err = io.ReadAll(in.Body)
if err != nil {
c.AbortWithError(502, errors.New("reverse proxy middleware failed to read request body "+err.Error()))
return
}
// record chat message from user
record.Body = string(inBody)
requestBody, requestBodyOK := ParseRequestBody(inBody)
// set timeout, default is 5 second
timeout := 5 * time.Second
if upstream.Timeout > 0 {
// convert upstream.Timeout(second) to nanosecond
timeout = time.Duration(upstream.Timeout) * time.Second
}
if requestBodyOK == nil && !requestBody.Stream {
timeout = 60 * time.Second
}
// timeout out request
go func() {
time.Sleep(timeout)
if !haveResponse {
log.Println("Timeout upstream", upstream.Endpoint)
errCtx = errors.New("timeout")
cancel()
}
}()
out.Body = io.NopCloser(bytes.NewReader(inBody))
out.Host = remote.Host
out.URL.Scheme = remote.Scheme
out.URL.Host = remote.Host
out.URL.Path = in.URL.Path
out.Header = http.Header{}
out.Header.Set("Host", remote.Host)
if upstream.SK == "asis" {
out.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
} else {
out.Header.Set("Authorization", "Bearer "+upstream.SK)
}
out.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
}
var buf bytes.Buffer
var contentType string
proxy.ModifyResponse = func(r *http.Response) error {
haveResponse = true
log.Println("haveResponse set to true")
record.Status = r.StatusCode
if !shouldResponse && r.StatusCode != 200 {
log.Println("upstream return not 200 and should not response", r.StatusCode)
return errors.New("upstream return not 200 and should not response")
}
r.Header.Del("Access-Control-Allow-Origin")
r.Header.Del("Access-Control-Allow-Methods")
r.Header.Del("Access-Control-Allow-Headers")
r.Header.Set("Access-Control-Allow-Origin", "*")
r.Header.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
r.Header.Set("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
if r.StatusCode != 200 {
body, err := io.ReadAll(r.Body)
if err != nil {
record.Response = "failed to read response from upstream " + err.Error()
return errors.New(record.Response)
}
record.Response = fmt.Sprintf("openai-api-route upstream return '%s' with '%s'", r.Status, string(body))
record.Status = r.StatusCode
return fmt.Errorf(record.Response)
}
// count success
r.Body = io.NopCloser(io.TeeReader(r.Body, &buf))
contentType = r.Header.Get("content-type")
return nil
}
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
log.Println("Error", err, upstream.SK, upstream.Endpoint)
log.Println("debug", r)
errCtx = err
// abort to error handle
if shouldResponse {
c.AbortWithError(502, err)
}
log.Println("response is", r.Response)
if record.Status == 0 {
record.Status = 502
}
if record.Response == "" {
record.Response = err.Error()
}
if r.Response != nil {
record.Status = r.Response.StatusCode
}
}
proxy.ServeHTTP(c.Writer, c.Request)
// return context error
if errCtx != nil {
// fix inrequest body
c.Request.Body = io.NopCloser(bytes.NewReader(inBody))
return errCtx
}
resp, err := io.ReadAll(io.NopCloser(&buf))
if err != nil {
record.Response = "failed to read response from upstream " + err.Error()
log.Println(record.Response)
} else {
// record response
// stream mode
if strings.HasPrefix(contentType, "text/event-stream") {
for _, line := range strings.Split(string(resp), "\n") {
chunk := StreamModeChunk{}
line = strings.TrimPrefix(line, "data:")
line = strings.TrimSpace(line)
if line == "" {
continue
}
err := json.Unmarshal([]byte(line), &chunk)
if err != nil {
log.Println(err)
continue
}
if len(chunk.Choices) == 0 {
continue
}
record.Response += chunk.Choices[0].Delta.Content
}
} else if strings.HasPrefix(contentType, "application/json") {
var fetchResp FetchModeResponse
err := json.Unmarshal(resp, &fetchResp)
if err != nil {
log.Println("Error parsing fetch response:", err)
return nil
}
if !strings.HasPrefix(fetchResp.Model, "gpt-") {
log.Println("Not GPT model, skip recording response:", fetchResp.Model)
return nil
}
if len(fetchResp.Choices) == 0 {
log.Println("Error: fetch response choice length is 0")
return nil
}
record.Response = fetchResp.Choices[0].Message.Content
} else {
log.Println("Unknown content type", contentType)
}
}
if len(record.Body) > 1024*512 {
record.Body = ""
}
return nil
}

32
request_body.go Normal file
View File

@@ -0,0 +1,32 @@
package main
import (
"encoding/json"
)
type Message struct {
Content string `json:"content"`
Role string `json:"role"`
}
type RequestBody struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Stream bool `json:"stream"`
Temperature float64 `json:"temperature"`
TopP int64 `json:"top_p"`
PresencePenalty float64 `json:"presence_penalty"`
FrequencyPenalty float64 `json:"frequency_penalty"`
}
func ParseRequestBody(data []byte) (RequestBody, error) {
ret := RequestBody{}
var requestBody RequestBody
err := json.Unmarshal(data, &requestBody)
if err != nil {
return ret, err
}
return requestBody, nil
}

View File

@@ -9,4 +9,5 @@ type OPENAI_UPSTREAM struct {
gorm.Model
SK string `gorm:"index:idx_sk_endpoint,unique"` // key
Endpoint string `gorm:"index:idx_sk_endpoint,unique"` // endpoint
Timeout int64 // timeout in seconds
}