refactor: pipe the read and write process

This refactor simplifies the processing logic and fixes several bugs and
performance issues.

bug fixed:
- CORS headers not being sent in some situations
performance:
- perform the upstream request while the client is still uploading content
This commit is contained in:
2024-05-27 14:47:00 +08:00
parent 45bba95f5d
commit 495f32610b
2 changed files with 157 additions and 248 deletions

96
main.go
View File

@@ -1,9 +1,11 @@
package main
import (
"bytes"
"encoding/json"
"flag"
"fmt"
"io"
"log"
"net/http"
"os"
@@ -87,7 +89,7 @@ func main() {
}
errText := strings.Join(c.Errors.Errors(), "\n")
c.JSON(-1, gin.H{
"error": errText,
"openai-api-route error": errText,
})
})
@@ -101,6 +103,7 @@ func main() {
})
engine.POST("/v1/*any", func(c *gin.Context) {
var err error
hostname, _ := os.Hostname()
if config.Hostname != "" {
hostname = config.Hostname
@@ -123,33 +126,56 @@ func main() {
}
log.Println("Received authorization '" + authorization + "'")
for index, upstream := range config.Upstreams {
availUpstreams := make([]OPENAI_UPSTREAM, 0)
for _, upstream := range config.Upstreams {
if upstream.SK == "" {
sendCORSHeaders(c)
c.AbortWithError(500, fmt.Errorf("[processRequest.begin]: invaild SK (secret key) '%s'", upstream.SK))
c.AbortWithError(500, fmt.Errorf("[processRequest.begin]: invaild SK (secret key) %s", upstream.SK))
continue
}
shouldResponse := index == len(config.Upstreams)-1
// check authorization header
if !*noauth && !upstream.Noauth {
if checkAuth(authorization, upstream.Authorization) != nil {
if shouldResponse {
c.Header("Content-Type", "application/json")
sendCORSHeaders(c)
c.AbortWithError(403, fmt.Errorf("[processRequest.begin]: wrong authorization header"))
}
log.Println("[auth] Authorization header check failed for", upstream.SK, authorization)
continue
}
log.Println("[auth] Authorization header check pass for", upstream.SK, authorization)
}
if len(config.Upstreams) == 1 {
availUpstreams = append(availUpstreams, upstream)
}
if len(availUpstreams) == 0 {
sendCORSHeaders(c)
c.AbortWithError(500, fmt.Errorf("[processRequest.begin]: no available upstream for your token"))
}
log.Println("[processRequest.begin]: availUpstreams", len(availUpstreams))
bufIO := bytes.NewBuffer(make([]byte, 0, 1024))
wrapedBody := false
for index, _upstream := range availUpstreams {
// copy
upstream := _upstream
record.UpstreamEndpoint = upstream.Endpoint
record.UpstreamSK = upstream.SK
shouldResponse := index == len(config.Upstreams)-1
if len(availUpstreams) == 1 {
// [todo] copy problem
upstream.Timeout = 120
}
// buffer for incoming request
if !wrapedBody {
log.Println("[processRequest.begin]: wrap request body")
c.Request.Body = io.NopCloser(io.TeeReader(c.Request.Body, bufIO))
wrapedBody = true
} else {
log.Println("[processRequest.begin]: reuse request body")
c.Request.Body = io.NopCloser(bytes.NewReader(bufIO.Bytes()))
}
if upstream.Type == "replicate" {
err = processReplicateRequest(c, &upstream, &record, shouldResponse)
} else if upstream.Type == "openai" {
@@ -158,19 +184,39 @@ func main() {
err = fmt.Errorf("[processRequest.begin]: unsupported upstream type '%s'", upstream.Type)
}
if err != nil {
if err == http.ErrAbortHandler {
abortErr := "[processRequest.done]: AbortHandler, client's connection lost?, no upstream will try, stop here"
log.Println(abortErr)
record.Response += abortErr
record.Status = 500
break
}
log.Println("[processRequest.done]: Error from upstream", upstream.Endpoint, "should retry", err)
continue
if err == nil {
log.Println("[processRequest.done]: Success from upstream", upstream.Endpoint)
break
}
break
if err == http.ErrAbortHandler {
abortErr := "[processRequest.done]: AbortHandler, client's connection lost?, no upstream will try, stop here"
log.Println(abortErr)
record.Response += abortErr
record.Status = 500
break
}
log.Println("[processRequest.done]: Error from upstream", upstream.Endpoint, "should retry", err, "should response:", shouldResponse)
// error process, break
if shouldResponse {
c.Header("Content-Type", "application/json")
sendCORSHeaders(c)
c.AbortWithError(500, err)
}
}
// parse and record request body
requestBodyBytes := bufIO.Bytes()
if len(requestBodyBytes) < 1024*1024 && (strings.HasPrefix(c.Request.Header.Get("Content-Type"), "application/json") ||
strings.HasPrefix(c.Request.Header.Get("Content-Type"), "text/")) {
record.Body = string(requestBodyBytes)
}
requestBody, err := ParseRequestBody(requestBodyBytes)
if err != nil {
log.Println("[processRequest.done]: Error to parse request body:", err)
} else {
record.Model = requestBody.Model
}
log.Println("[final]: Record result:", record.Status, record.Response)