diff --git a/main.go b/main.go index 37bb864..3224af5 100644 --- a/main.go +++ b/main.go @@ -1,9 +1,11 @@ package main import ( + "bytes" "encoding/json" "flag" "fmt" + "io" "log" "net/http" "os" @@ -87,7 +89,7 @@ func main() { } errText := strings.Join(c.Errors.Errors(), "\n") c.JSON(-1, gin.H{ - "error": errText, + "openai-api-route error": errText, }) }) @@ -101,6 +103,7 @@ func main() { }) engine.POST("/v1/*any", func(c *gin.Context) { + var err error hostname, _ := os.Hostname() if config.Hostname != "" { hostname = config.Hostname @@ -123,33 +126,56 @@ func main() { } log.Println("Received authorization '" + authorization + "'") - for index, upstream := range config.Upstreams { + availUpstreams := make([]OPENAI_UPSTREAM, 0) + for _, upstream := range config.Upstreams { if upstream.SK == "" { sendCORSHeaders(c) - c.AbortWithError(500, fmt.Errorf("[processRequest.begin]: invaild SK (secret key) '%s'", upstream.SK)) + c.AbortWithError(500, fmt.Errorf("[processRequest.begin]: invaild SK (secret key) %s", upstream.SK)) continue } - shouldResponse := index == len(config.Upstreams)-1 - - // check authorization header if !*noauth && !upstream.Noauth { if checkAuth(authorization, upstream.Authorization) != nil { - if shouldResponse { - c.Header("Content-Type", "application/json") - sendCORSHeaders(c) - c.AbortWithError(403, fmt.Errorf("[processRequest.begin]: wrong authorization header")) - } - log.Println("[auth] Authorization header check failed for", upstream.SK, authorization) continue } - log.Println("[auth] Authorization header check pass for", upstream.SK, authorization) } - if len(config.Upstreams) == 1 { + availUpstreams = append(availUpstreams, upstream) + } + + if len(availUpstreams) == 0 { + sendCORSHeaders(c) + c.AbortWithError(500, fmt.Errorf("[processRequest.begin]: no available upstream for your token")) + } + log.Println("[processRequest.begin]: availUpstreams", len(availUpstreams)) + + bufIO := bytes.NewBuffer(make([]byte, 0, 1024)) + wrapedBody := false + + for index, _upstream := range availUpstreams { + + // copy + upstream := _upstream + record.UpstreamEndpoint = upstream.Endpoint + record.UpstreamSK = upstream.SK + + shouldResponse := index == len(config.Upstreams)-1 + + if len(availUpstreams) == 1 { + // [todo] copy problem upstream.Timeout = 120 } + // buffer for incoming request + if !wrapedBody { + log.Println("[processRequest.begin]: wrap request body") + c.Request.Body = io.NopCloser(io.TeeReader(c.Request.Body, bufIO)) + wrapedBody = true + } else { + log.Println("[processRequest.begin]: reuse request body") + c.Request.Body = io.NopCloser(bytes.NewReader(bufIO.Bytes())) + } + if upstream.Type == "replicate" { err = processReplicateRequest(c, &upstream, &record, shouldResponse) } else if upstream.Type == "openai" { @@ -158,19 +184,39 @@ func main() { err = fmt.Errorf("[processRequest.begin]: unsupported upstream type '%s'", upstream.Type) } - if err != nil { - if err == http.ErrAbortHandler { - abortErr := "[processRequest.done]: AbortHandler, client's connection lost?, no upstream will try, stop here" - log.Println(abortErr) - record.Response += abortErr - record.Status = 500 - break - } - log.Println("[processRequest.done]: Error from upstream", upstream.Endpoint, "should retry", err) - continue + if err == nil { + log.Println("[processRequest.done]: Success from upstream", upstream.Endpoint) + break } - break + if err == http.ErrAbortHandler { + abortErr := "[processRequest.done]: AbortHandler, client's connection lost?, no upstream will try, stop here" + log.Println(abortErr) + record.Response += abortErr + record.Status = 500 + break + } + log.Println("[processRequest.done]: Error from upstream", upstream.Endpoint, "should retry", err, "should response:", shouldResponse) + + // error process, break + if shouldResponse { + c.Header("Content-Type", "application/json") + sendCORSHeaders(c) + c.AbortWithError(500, err) + } + } + + // parse and record request body + requestBodyBytes := bufIO.Bytes() + if len(requestBodyBytes) < 1024*1024 && (strings.HasPrefix(c.Request.Header.Get("Content-Type"), "application/json") || + strings.HasPrefix(c.Request.Header.Get("Content-Type"), "text/")) { + record.Body = string(requestBodyBytes) + } + requestBody, err := ParseRequestBody(requestBodyBytes) + if err != nil { + log.Println("[processRequest.done]: Error to parse request body:", err) + } else { + record.Model = requestBody.Model } log.Println("[final]: Record result:", record.Status, record.Response) diff --git a/process.go b/process.go index 86bd07d..fdac7ea 100644 --- a/process.go +++ b/process.go @@ -8,22 +8,14 @@ import ( "io" "log" "net/http" - "net/http/httputil" "net/url" "strings" "time" "github.com/gin-gonic/gin" - "golang.org/x/net/context" ) func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error { - var errCtx []error - - record.UpstreamEndpoint = upstream.Endpoint - record.UpstreamSK = upstream.SK - record.Response = "" - // [TODO] record request body // reverse proxy remote, err := url.Parse(upstream.Endpoint) @@ -32,231 +24,102 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s } path := strings.TrimPrefix(c.Request.URL.Path, "/v1") - // recoognize whisper url remote.Path = upstream.URL.Path + path log.Println("[proxy.begin]:", remote) log.Println("[proxy.begin]: shouldResposne:", shouldResponse) - haveResponse := false + client := &http.Client{} + request := &http.Request{} + request.ContentLength = c.Request.ContentLength + request.Method = c.Request.Method + request.URL = remote - proxy := httputil.NewSingleHostReverseProxy(remote) - proxy.Director = nil - var inBody []byte - proxy.Rewrite = func(proxyRequest *httputil.ProxyRequest) { - - in := proxyRequest.In - - ctx, cancel := context.WithCancel(context.Background()) - proxyRequest.Out = proxyRequest.Out.WithContext(ctx) - - out := proxyRequest.Out - - // read request body - inBody, err = io.ReadAll(in.Body) - if err != nil { - errCtx = append(errCtx, errors.New("[proxy.rewrite]: reverse proxy middleware failed to read request body "+err.Error())) - return - } - - // record chat message from user - requestBody, requestBodyOK := ParseRequestBody(inBody) - // record if parse success - if requestBodyOK == nil && requestBody.Model != "" { - record.Model = requestBody.Model - record.Body = string(inBody) - } - - // check allow list - if len(upstream.Allow) > 0 { - isAllow := false - for _, allow := range upstream.Allow { - if allow == record.Model { - isAllow = true - break - } - } - if !isAllow { - errCtx = append(errCtx, errors.New("[proxy.rewrite]: model '"+record.Model+"' not allowed")) - return - } - } - // check block list - if len(upstream.Deny) > 0 { - for _, deny := range upstream.Deny { - if deny == record.Model { - errCtx = append(errCtx, errors.New("[proxy.rewrite]: model '"+record.Model+"' denied")) - return - } - } - } - - // set timeout, default is 60 second - timeout := time.Duration(upstream.Timeout) * time.Second - if requestBodyOK == nil && requestBody.Stream { - timeout = time.Duration(upstream.StreamTimeout) * time.Second - } - - // timeout out request - go func() { - time.Sleep(timeout) - if !haveResponse { - log.Println("[proxy.timeout]: Timeout upstream", upstream.Endpoint, timeout) - errTimeout := errors.New("[proxy.timeout]: Timeout upstream") - errCtx = append(errCtx, errTimeout) - if shouldResponse { - c.Header("Content-Type", "application/json") - sendCORSHeaders(c) - c.AbortWithError(502, errTimeout) - } - cancel() - } - }() - - out.Body = io.NopCloser(bytes.NewReader(inBody)) - - out.Host = remote.Host - out.URL.Scheme = remote.Scheme - out.URL.Host = remote.Host - - if !upstream.KeepHeader { - out.Header = http.Header{} - } - out.Header.Set("Host", remote.Host) - if upstream.SK == "asis" { - out.Header.Set("Authorization", c.Request.Header.Get("Authorization")) - } else { - out.Header.Set("Authorization", "Bearer "+upstream.SK) - } - out.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - } - var buf bytes.Buffer - var contentType string - proxy.ModifyResponse = func(r *http.Response) error { - haveResponse = true - record.ResponseTime = time.Since(record.CreatedAt) - record.Status = r.StatusCode - - // remove response's cors headers - r.Header.Del("Access-Control-Allow-Origin") - r.Header.Del("Access-Control-Allow-Methods") - r.Header.Del("Access-Control-Allow-Headers") - r.Header.Del("access-control-allow-origin") - r.Header.Del("access-control-allow-methods") - r.Header.Del("access-control-allow-headers") - - if !shouldResponse && r.StatusCode != 200 { - log.Println("[proxy.modifyResponse]: upstream return not 200 and should not response", r.StatusCode) - return errors.New("upstream return not 200 and should not response") - } - - if r.StatusCode != 200 { - body, err := io.ReadAll(r.Body) - if err != nil { - errRet := errors.New("[proxy.modifyResponse]: failed to read response from upstream " + err.Error()) - return errRet - } - errRet := fmt.Errorf("[error]: openai-api-route upstream return '%s' with '%s'", r.Status, string(body)) - log.Println(errRet) - record.Status = r.StatusCode - return errRet - } - // handle reverse proxy cors header if upstream do not set that - sendCORSHeaders(c) - // count success - r.Body = io.NopCloser(io.TeeReader(r.Body, &buf)) - contentType = r.Header.Get("content-type") - return nil - } - proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { - haveResponse = true - record.ResponseTime = time.Since(record.CreatedAt) - log.Println("[proxy.errorHandler]", err, upstream.SK, upstream.Endpoint, errCtx) - - errCtx = append(errCtx, err) - - // abort to error handle - if shouldResponse { - c.Header("Content-Type", "application/json") - sendCORSHeaders(c) - for _, err := range errCtx { - c.AbortWithError(502, err) - } - } - - log.Println("[proxy.errorHandler]: response is", r.Response) - - if record.Status == 0 { - record.Status = 502 - } - record.Response += "[proxy.ErrorHandler]: " + err.Error() - if r.Response != nil { - record.Status = r.Response.StatusCode - } - - } - - err = ServeHTTP(proxy, c.Writer, c.Request) - if err != nil { - log.Println("[proxy.serve]: error from ServeHTTP:", err) - // panic means client has abort the http connection - // since the connection is lost, we return - // and the reverse process should not try the next upsteam - return http.ErrAbortHandler - } - - // return context error - if len(errCtx) > 0 { - log.Println("[proxy.serve]: error from ServeHTTP:", errCtx) - // fix inrequest body - c.Request.Body = io.NopCloser(bytes.NewReader(inBody)) - return errCtx[len(errCtx)-1] - } - - resp, err := io.ReadAll(io.NopCloser(&buf)) - if err != nil { - record.Response = "failed to read response from upstream " + err.Error() - log.Println(record.Response) + // process header + if upstream.KeepHeader { + request.Header = c.Request.Header } else { + request.Header = http.Header{} + } - // record response - // stream mode - if strings.HasPrefix(contentType, "text/event-stream") { - for _, line := range strings.Split(string(resp), "\n") { - chunk := StreamModeChunk{} - line = strings.TrimPrefix(line, "data:") - line = strings.TrimSpace(line) - if line == "" { - continue - } + // process header authorization + if upstream.SK == "asis" { + request.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + } else { + request.Header.Set("Authorization", "Bearer "+upstream.SK) + } + request.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + request.Header.Set("Host", remote.Host) + request.Header.Set("Content-Length", c.Request.Header.Get("Content-Length")) - err := json.Unmarshal([]byte(line), &chunk) - if err != nil { - log.Println("[proxy.parseChunkError]:", err) - continue - } + request.Body = c.Request.Body - if len(chunk.Choices) == 0 { - continue - } - record.Response += chunk.Choices[0].Delta.Content - } - } else if strings.HasPrefix(contentType, "text") { - record.Response = string(resp) - } else if strings.HasPrefix(contentType, "application/json") { - // fallback record response - if len(resp) < 1024*128 { - record.Response = string(resp) - } - var fetchResp FetchModeResponse - err := json.Unmarshal(resp, &fetchResp) - if err == nil { - if len(fetchResp.Choices) > 0 { - record.Response = fetchResp.Choices[0].Message.Content - } - } - } else { - log.Println("[proxy.record]: Unknown content type", contentType) + resp, err := client.Do(request) + if err != nil { + body := []byte{} + if resp != nil && resp.Body != nil { + body, _ = io.ReadAll(resp.Body) } + return errors.New(err.Error() + " " + string(body)) + } + + defer resp.Body.Close() + + record.Status = resp.StatusCode + + if resp.StatusCode != 200 { + body, _ := io.ReadAll(resp.Body) + record.Status = resp.StatusCode + errRet := fmt.Errorf("[error]: openai-api-route upstream return '%s' with '%s'", resp.Status, string(body)) + log.Println(errRet) + return errRet + } + + // copy response header + for k, v := range resp.Header { + c.Header(k, v[0]) + } + sendCORSHeaders(c) + + respBodyBuffer := bytes.NewBuffer(make([]byte, 0, 4*1024)) + respBodyTeeReader := io.TeeReader(resp.Body, respBodyBuffer) + record.ResponseTime = time.Since(record.CreatedAt) + io.Copy(c.Writer, respBodyTeeReader) + record.ElapsedTime = time.Since(record.CreatedAt) + + // parse and record response + if strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") { + var fetchResp FetchModeResponse + err := json.NewDecoder(respBodyBuffer).Decode(&fetchResp) + if err == nil { + if len(fetchResp.Choices) > 0 { + record.Response = fetchResp.Choices[0].Message.Content + } + } + } else if strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") { + lines := bytes.Split(respBodyBuffer.Bytes(), []byte("\n")) + for _, line := range lines { + line = bytes.TrimSpace(line) + line = bytes.TrimPrefix(line, []byte("data:")) + line = bytes.TrimSpace(line) + if len(line) == 0 { + continue + } + chunk := StreamModeChunk{} + err = json.Unmarshal(line, &chunk) + if err != nil { + log.Println("[proxy.parseChunkError]:", err) + break + } + if len(chunk.Choices) == 0 { + continue + } + record.Response += chunk.Choices[0].Delta.Content + } + } else if strings.HasPrefix(resp.Header.Get("Content-Type"), "text") { + body, _ := io.ReadAll(respBodyBuffer) + record.Response = string(body) + } else { + log.Println("[proxy.record]: Unknown content type", resp.Header.Get("Content-Type")) } return nil