refactor: pipe the read and write process
this refactor simplify the process logic and fix several bugs and performance issue. bug fixed: - cors headers not being sent in some situation performance: - perform upstream reqeust while clien is uploading content
This commit is contained in:
96
main.go
96
main.go
@@ -1,9 +1,11 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -87,7 +89,7 @@ func main() {
|
||||
}
|
||||
errText := strings.Join(c.Errors.Errors(), "\n")
|
||||
c.JSON(-1, gin.H{
|
||||
"error": errText,
|
||||
"openai-api-route error": errText,
|
||||
})
|
||||
})
|
||||
|
||||
@@ -101,6 +103,7 @@ func main() {
|
||||
})
|
||||
|
||||
engine.POST("/v1/*any", func(c *gin.Context) {
|
||||
var err error
|
||||
hostname, _ := os.Hostname()
|
||||
if config.Hostname != "" {
|
||||
hostname = config.Hostname
|
||||
@@ -123,33 +126,56 @@ func main() {
|
||||
}
|
||||
log.Println("Received authorization '" + authorization + "'")
|
||||
|
||||
for index, upstream := range config.Upstreams {
|
||||
availUpstreams := make([]OPENAI_UPSTREAM, 0)
|
||||
for _, upstream := range config.Upstreams {
|
||||
if upstream.SK == "" {
|
||||
sendCORSHeaders(c)
|
||||
c.AbortWithError(500, fmt.Errorf("[processRequest.begin]: invaild SK (secret key) '%s'", upstream.SK))
|
||||
c.AbortWithError(500, fmt.Errorf("[processRequest.begin]: invaild SK (secret key) %s", upstream.SK))
|
||||
continue
|
||||
}
|
||||
|
||||
shouldResponse := index == len(config.Upstreams)-1
|
||||
|
||||
// check authorization header
|
||||
if !*noauth && !upstream.Noauth {
|
||||
if checkAuth(authorization, upstream.Authorization) != nil {
|
||||
if shouldResponse {
|
||||
c.Header("Content-Type", "application/json")
|
||||
sendCORSHeaders(c)
|
||||
c.AbortWithError(403, fmt.Errorf("[processRequest.begin]: wrong authorization header"))
|
||||
}
|
||||
log.Println("[auth] Authorization header check failed for", upstream.SK, authorization)
|
||||
continue
|
||||
}
|
||||
log.Println("[auth] Authorization header check pass for", upstream.SK, authorization)
|
||||
}
|
||||
|
||||
if len(config.Upstreams) == 1 {
|
||||
availUpstreams = append(availUpstreams, upstream)
|
||||
}
|
||||
|
||||
if len(availUpstreams) == 0 {
|
||||
sendCORSHeaders(c)
|
||||
c.AbortWithError(500, fmt.Errorf("[processRequest.begin]: no available upstream for your token"))
|
||||
}
|
||||
log.Println("[processRequest.begin]: availUpstreams", len(availUpstreams))
|
||||
|
||||
bufIO := bytes.NewBuffer(make([]byte, 0, 1024))
|
||||
wrapedBody := false
|
||||
|
||||
for index, _upstream := range availUpstreams {
|
||||
|
||||
// copy
|
||||
upstream := _upstream
|
||||
record.UpstreamEndpoint = upstream.Endpoint
|
||||
record.UpstreamSK = upstream.SK
|
||||
|
||||
shouldResponse := index == len(config.Upstreams)-1
|
||||
|
||||
if len(availUpstreams) == 1 {
|
||||
// [todo] copy problem
|
||||
upstream.Timeout = 120
|
||||
}
|
||||
|
||||
// buffer for incoming request
|
||||
if !wrapedBody {
|
||||
log.Println("[processRequest.begin]: wrap request body")
|
||||
c.Request.Body = io.NopCloser(io.TeeReader(c.Request.Body, bufIO))
|
||||
wrapedBody = true
|
||||
} else {
|
||||
log.Println("[processRequest.begin]: reuse request body")
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bufIO.Bytes()))
|
||||
}
|
||||
|
||||
if upstream.Type == "replicate" {
|
||||
err = processReplicateRequest(c, &upstream, &record, shouldResponse)
|
||||
} else if upstream.Type == "openai" {
|
||||
@@ -158,19 +184,39 @@ func main() {
|
||||
err = fmt.Errorf("[processRequest.begin]: unsupported upstream type '%s'", upstream.Type)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if err == http.ErrAbortHandler {
|
||||
abortErr := "[processRequest.done]: AbortHandler, client's connection lost?, no upstream will try, stop here"
|
||||
log.Println(abortErr)
|
||||
record.Response += abortErr
|
||||
record.Status = 500
|
||||
break
|
||||
}
|
||||
log.Println("[processRequest.done]: Error from upstream", upstream.Endpoint, "should retry", err)
|
||||
continue
|
||||
if err == nil {
|
||||
log.Println("[processRequest.done]: Success from upstream", upstream.Endpoint)
|
||||
break
|
||||
}
|
||||
|
||||
break
|
||||
if err == http.ErrAbortHandler {
|
||||
abortErr := "[processRequest.done]: AbortHandler, client's connection lost?, no upstream will try, stop here"
|
||||
log.Println(abortErr)
|
||||
record.Response += abortErr
|
||||
record.Status = 500
|
||||
break
|
||||
}
|
||||
log.Println("[processRequest.done]: Error from upstream", upstream.Endpoint, "should retry", err, "should response:", shouldResponse)
|
||||
|
||||
// error process, break
|
||||
if shouldResponse {
|
||||
c.Header("Content-Type", "application/json")
|
||||
sendCORSHeaders(c)
|
||||
c.AbortWithError(500, err)
|
||||
}
|
||||
}
|
||||
|
||||
// parse and record request body
|
||||
requestBodyBytes := bufIO.Bytes()
|
||||
if len(requestBodyBytes) < 1024*1024 && (strings.HasPrefix(c.Request.Header.Get("Content-Type"), "application/json") ||
|
||||
strings.HasPrefix(c.Request.Header.Get("Content-Type"), "text/")) {
|
||||
record.Body = string(requestBodyBytes)
|
||||
}
|
||||
requestBody, err := ParseRequestBody(requestBodyBytes)
|
||||
if err != nil {
|
||||
log.Println("[processRequest.done]: Error to parse request body:", err)
|
||||
} else {
|
||||
record.Model = requestBody.Model
|
||||
}
|
||||
|
||||
log.Println("[final]: Record result:", record.Status, record.Response)
|
||||
|
||||
Reference in New Issue
Block a user