Compare commits
18 Commits
584103eba3
...
v0.3.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
6aab4fb44b
|
|||
|
b932331bdc
|
|||
|
a2c6fa32ed
|
|||
|
9d8e76bd2d
|
|||
|
ebc04228c5
|
|||
|
3c4c2b5660
|
|||
|
a3fff93f2e
|
|||
|
c90a18d380
|
|||
|
d9a42842b2
|
|||
|
5a78c61e5f
|
|||
|
eced585361
|
|||
|
98a15052a2
|
|||
|
7572ecf19b
|
|||
|
9f2bb46233
|
|||
|
d8948d065a
|
|||
|
2c75c392a8
|
|||
|
acc153ddca
|
|||
|
471627712b
|
15
.vscode/launch.json
vendored
Normal file
15
.vscode/launch.json
vendored
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
{
|
||||||
|
// Use IntelliSense to learn about possible attributes.
|
||||||
|
// Hover to view descriptions of existing attributes.
|
||||||
|
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||||
|
"version": "0.2.0",
|
||||||
|
"configurations": [
|
||||||
|
{
|
||||||
|
"name": "Launch Package",
|
||||||
|
"type": "go",
|
||||||
|
"request": "launch",
|
||||||
|
"mode": "auto",
|
||||||
|
"program": "${fileDirname}"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
53
README.md
53
README.md
@@ -1,13 +1,16 @@
|
|||||||
# openai-api-route 文档
|
# openai-api-route 文档
|
||||||
|
|
||||||
这是一个 OpenAI API 负载均衡的简易工具,使用 golang 原生 reverse proxy 方法转发请求到 OpenAI 上游
|
这是一个 OpenAI API 负载均衡的简易工具,使用 golang 原生 reverse proxy 方法转发请求到 OpenAI 上游。遇到上游返回报错或请求超时会自动按顺序选择下一个上游进行重试,直到所有上游都请求失败。
|
||||||
|
|
||||||
功能包括:
|
功能包括:
|
||||||
|
|
||||||
- 更改 Authorization 验证头
|
- 自定义 Authorization 验证头
|
||||||
- 多种负载均衡策略
|
- 支持所有类型的接口 (`/v1/*`)
|
||||||
- 记录完整的请求内容、IP 地址、响应时间以及 GPT 回复文本
|
- 提供 Prometheus Metrics 统计接口 (`/v1/metrics`)
|
||||||
- 上游返回错误时发送 飞书 或 Matrix 消息通知
|
- 按照定义顺序请求 OpenAI 上游
|
||||||
|
- 识别 ChatCompletions Stream 请求,针对 Stream 请求使用 5 秒超时。对于其他请求使用60秒超时。
|
||||||
|
- 记录完整的请求内容、使用的上游、IP 地址、响应时间以及 GPT 回复文本
|
||||||
|
- 请求出错时发送 飞书 或 Matrix 消息通知
|
||||||
|
|
||||||
本文档详细介绍了如何使用负载均衡和能力 API 的方法和端点。
|
本文档详细介绍了如何使用负载均衡和能力 API 的方法和端点。
|
||||||
|
|
||||||
@@ -24,37 +27,37 @@
|
|||||||
3. 打开终端,并进入到仓库目录中。
|
3. 打开终端,并进入到仓库目录中。
|
||||||
|
|
||||||
4. 在终端中执行以下命令来编译代码:
|
4. 在终端中执行以下命令来编译代码:
|
||||||
|
|
||||||
```
|
```
|
||||||
make
|
make
|
||||||
```
|
```
|
||||||
|
|
||||||
这将会编译代码并生成可执行文件。
|
这将会编译代码并生成可执行文件。
|
||||||
|
|
||||||
5. 编译成功后,您可以直接运行以下命令来启动负载均衡 API:
|
5. 编译成功后,您可以直接运行以下命令来启动负载均衡 API:
|
||||||
|
|
||||||
```
|
```
|
||||||
./openai-api-route
|
./openai-api-route
|
||||||
```
|
```
|
||||||
|
|
||||||
默认情况下,API 将会在本地的 8888 端口进行监听。
|
默认情况下,API 将会在本地的 8888 端口进行监听。
|
||||||
|
|
||||||
如果您希望使用不同的监听地址,可以使用 `-addr` 参数来指定,例如:
|
如果您希望使用不同的监听地址,可以使用 `-addr` 参数来指定,例如:
|
||||||
|
|
||||||
```
|
```
|
||||||
./openai-api-route -addr 0.0.0.0:8080
|
./openai-api-route -addr 0.0.0.0:8080
|
||||||
```
|
```
|
||||||
|
|
||||||
这将会将监听地址设置为 0.0.0.0:8080。
|
这将会将监听地址设置为 0.0.0.0:8080。
|
||||||
|
|
||||||
6. 如果数据库不存在,系统会自动创建一个名为 `db.sqlite` 的数据库文件。
|
6. 如果数据库不存在,系统会自动创建一个名为 `db.sqlite` 的数据库文件。
|
||||||
|
|
||||||
如果您希望使用不同的数据库地址,可以使用 `-database` 参数来指定,例如:
|
如果您希望使用不同的数据库地址,可以使用 `-database` 参数来指定,例如:
|
||||||
|
|
||||||
```
|
```
|
||||||
./openai-api-route -database /path/to/database.db
|
./openai-api-route -database /path/to/database.db
|
||||||
```
|
```
|
||||||
|
|
||||||
这将会将数据库地址设置为 `/path/to/database.db`。
|
这将会将数据库地址设置为 `/path/to/database.db`。
|
||||||
|
|
||||||
7. 现在,您已经成功编译并运行了负载均衡和能力 API。您可以根据需要添加上游、管理上游,并使用 API 进行相关操作。
|
7. 现在,您已经成功编译并运行了负载均衡和能力 API。您可以根据需要添加上游、管理上游,并使用 API 进行相关操作。
|
||||||
@@ -91,22 +94,4 @@ Usage of ./openai-api-route:
|
|||||||
./openai-api-route -add -sk sk-xxxxx -endpoint https://api.openai.com/v1
|
./openai-api-route -add -sk sk-xxxxx -endpoint https://api.openai.com/v1
|
||||||
```
|
```
|
||||||
|
|
||||||
您也可以使用 `/admin/upstreams` 的 HTTP 接口进行控制。
|
另外,您还可以直接编辑数据库中的 `openai_upstreams` 表进行 OpenAI 上游的增删改查管理。改动的上游需要重启负载均衡服务后才能生效。
|
||||||
|
|
||||||
另外,您还可以直接编辑数据库中的 `openai_upstreams` 表。
|
|
||||||
|
|
||||||
## 身份验证
|
|
||||||
|
|
||||||
### 身份验证中间件流程
|
|
||||||
|
|
||||||
1. 从请求头中获取`Authorization`字段的值。
|
|
||||||
2. 检查`Authorization`字段的值是否以`"Bearer"`开头。
|
|
||||||
- 如果不是,则返回错误信息:"authorization header should start with 'Bearer'"(HTTP 状态码 403)。
|
|
||||||
3. 去除`Authorization`字段值开头的`"Bearer"`和前后的空格。
|
|
||||||
4. 将剩余的值与预先设置的身份验证配置进行比较。
|
|
||||||
- 如果不匹配,则返回错误信息:"wrong authorization header"(HTTP 状态码 403)。
|
|
||||||
5. 如果身份验证通过,则返回`nil`。
|
|
||||||
|
|
||||||
## 上游管理
|
|
||||||
|
|
||||||
没什么好说的,直接操作数据库 `openai_upstreams` 表,改动立即生效
|
|
||||||
|
|||||||
205
main.go
205
main.go
@@ -1,16 +1,9 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
|
||||||
"net/http/httputil"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -41,6 +34,11 @@ func main() {
|
|||||||
log.Fatal("Failed to connect to database")
|
log.Fatal("Failed to connect to database")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// load all upstreams
|
||||||
|
upstreams := make([]OPENAI_UPSTREAM, 0)
|
||||||
|
db.Find(&upstreams)
|
||||||
|
log.Println("Load upstreams number:", len(upstreams))
|
||||||
|
|
||||||
err = initconfig(db)
|
err = initconfig(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
@@ -80,7 +78,7 @@ func main() {
|
|||||||
|
|
||||||
// metrics
|
// metrics
|
||||||
m := ginmetrics.GetMonitor()
|
m := ginmetrics.GetMonitor()
|
||||||
// m.SetMetricPath("/debug/metrics")
|
m.SetMetricPath("/v1/metrics")
|
||||||
m.Use(engine)
|
m.Use(engine)
|
||||||
|
|
||||||
// error handle middleware
|
// error handle middleware
|
||||||
@@ -127,186 +125,25 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// get load balance policy
|
for index, upstream := range upstreams {
|
||||||
policy := ConfigKV{Value: "main"}
|
if upstream.Endpoint == "" || upstream.SK == "" {
|
||||||
db.Take(&policy, "key = ?", "policy")
|
c.AbortWithError(500, fmt.Errorf("invaild upstream '%s' '%s'", upstream.SK, upstream.Endpoint))
|
||||||
log.Println("policy is", policy.Value)
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
upstream := OPENAI_UPSTREAM{}
|
shouldResponse := index == len(upstreams)-1
|
||||||
|
|
||||||
// choose openai upstream
|
if len(upstreams) == 1 {
|
||||||
switch policy.Value {
|
upstream.Timeout = 120
|
||||||
case "main":
|
}
|
||||||
db.Order("failed_count, success_count desc").First(&upstream)
|
|
||||||
case "random":
|
|
||||||
// randomly select one upstream
|
|
||||||
db.Order("random()").Take(&upstream)
|
|
||||||
case "random_available":
|
|
||||||
// randomly select one non-failed upstream
|
|
||||||
db.Where("failed_count = ?", 0).Order("random()").Take(&upstream)
|
|
||||||
case "round_robin":
|
|
||||||
// iterates each upstream
|
|
||||||
db.Order("last_call_success_time").First(&upstream)
|
|
||||||
case "round_robin_available":
|
|
||||||
db.Where("failed_count = ?", 0).Order("last_call_success_time").First(&upstream)
|
|
||||||
default:
|
|
||||||
c.AbortWithError(500, fmt.Errorf("unknown load balance policy '%s'", policy.Value))
|
|
||||||
}
|
|
||||||
|
|
||||||
// do check
|
err = processRequest(c, &upstream, &record, shouldResponse)
|
||||||
log.Println("upstream is", upstream.SK, upstream.Endpoint)
|
|
||||||
if upstream.Endpoint == "" || upstream.SK == "" {
|
|
||||||
c.AbortWithError(500, fmt.Errorf("invaild upstream from '%s' policy", policy.Value))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
record.UpstreamID = upstream.ID
|
|
||||||
|
|
||||||
// reverse proxy
|
|
||||||
remote, err := url.Parse(upstream.Endpoint)
|
|
||||||
if err != nil {
|
|
||||||
c.AbortWithError(500, errors.New("can't parse reverse proxy remote URL"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
proxy := httputil.NewSingleHostReverseProxy(remote)
|
|
||||||
proxy.Director = nil
|
|
||||||
proxy.Rewrite = func(proxyRequest *httputil.ProxyRequest) {
|
|
||||||
in := proxyRequest.In
|
|
||||||
out := proxyRequest.Out
|
|
||||||
|
|
||||||
// read request body
|
|
||||||
body, err := io.ReadAll(in.Body)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithError(502, errors.New("reverse proxy middleware failed to read request body "+err.Error()))
|
log.Println("Error from upstream", upstream.Endpoint, "should retry", err)
|
||||||
return
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// record chat message from user
|
break
|
||||||
record.Body = string(body)
|
|
||||||
|
|
||||||
out.Body = io.NopCloser(bytes.NewReader(body))
|
|
||||||
|
|
||||||
out.Host = remote.Host
|
|
||||||
out.URL.Scheme = remote.Scheme
|
|
||||||
out.URL.Host = remote.Host
|
|
||||||
out.URL.Path = in.URL.Path
|
|
||||||
out.Header = http.Header{}
|
|
||||||
out.Header.Set("Host", remote.Host)
|
|
||||||
if upstream.SK == "asis" {
|
|
||||||
out.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
|
||||||
} else {
|
|
||||||
out.Header.Set("Authorization", "Bearer "+upstream.SK)
|
|
||||||
}
|
|
||||||
out.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
|
||||||
}
|
|
||||||
var buf bytes.Buffer
|
|
||||||
var contentType string
|
|
||||||
proxy.ModifyResponse = func(r *http.Response) error {
|
|
||||||
record.Status = r.StatusCode
|
|
||||||
r.Header.Del("Access-Control-Allow-Origin")
|
|
||||||
r.Header.Del("Access-Control-Allow-Methods")
|
|
||||||
r.Header.Del("Access-Control-Allow-Headers")
|
|
||||||
r.Header.Set("Access-Control-Allow-Origin", "*")
|
|
||||||
r.Header.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
|
|
||||||
r.Header.Set("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
|
|
||||||
|
|
||||||
if r.StatusCode != 200 {
|
|
||||||
body, err := io.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
record.Response = "failed to read response from upstream " + err.Error()
|
|
||||||
return errors.New(record.Response)
|
|
||||||
}
|
|
||||||
record.Response = fmt.Sprintf("openai-api-route upstream return '%s' with '%s'", r.Status, string(body))
|
|
||||||
record.Status = r.StatusCode
|
|
||||||
return fmt.Errorf(record.Response)
|
|
||||||
}
|
|
||||||
// count success
|
|
||||||
r.Body = io.NopCloser(io.TeeReader(r.Body, &buf))
|
|
||||||
contentType = r.Header.Get("content-type")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
|
||||||
log.Println("Error", err, upstream.SK, upstream.Endpoint)
|
|
||||||
|
|
||||||
log.Println("debug", r)
|
|
||||||
|
|
||||||
// abort to error handle
|
|
||||||
c.AbortWithError(502, err)
|
|
||||||
|
|
||||||
log.Println("response is", r.Response)
|
|
||||||
|
|
||||||
if record.Status == 0 {
|
|
||||||
record.Status = 502
|
|
||||||
}
|
|
||||||
if record.Response == "" {
|
|
||||||
record.Response = err.Error()
|
|
||||||
}
|
|
||||||
if r.Response != nil {
|
|
||||||
record.Status = r.Response.StatusCode
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func() {
|
|
||||||
defer func() {
|
|
||||||
if err := recover(); err != nil {
|
|
||||||
log.Println("Panic recover :", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
proxy.ServeHTTP(c.Writer, c.Request)
|
|
||||||
}()
|
|
||||||
|
|
||||||
resp, err := io.ReadAll(io.NopCloser(&buf))
|
|
||||||
if err != nil {
|
|
||||||
record.Response = "failed to read response from upstream " + err.Error()
|
|
||||||
log.Println(record.Response)
|
|
||||||
} else {
|
|
||||||
|
|
||||||
// record response
|
|
||||||
// stream mode
|
|
||||||
if strings.HasPrefix(contentType, "text/event-stream") {
|
|
||||||
for _, line := range strings.Split(string(resp), "\n") {
|
|
||||||
chunk := StreamModeChunk{}
|
|
||||||
line = strings.TrimPrefix(line, "data:")
|
|
||||||
line = strings.TrimSpace(line)
|
|
||||||
if line == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
err := json.Unmarshal([]byte(line), &chunk)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(chunk.Choices) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
record.Response += chunk.Choices[0].Delta.Content
|
|
||||||
}
|
|
||||||
} else if strings.HasPrefix(contentType, "application/json") {
|
|
||||||
var fetchResp FetchModeResponse
|
|
||||||
err := json.Unmarshal(resp, &fetchResp)
|
|
||||||
if err != nil {
|
|
||||||
log.Println("Error parsing fetch response:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !strings.HasPrefix(fetchResp.Model, "gpt-") {
|
|
||||||
log.Println("Not GPT model, skip recording response:", fetchResp.Model)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(fetchResp.Choices) == 0 {
|
|
||||||
log.Println("Error: fetch response choice length is 0")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
record.Response = fetchResp.Choices[0].Message.Content
|
|
||||||
} else {
|
|
||||||
log.Println("Unknown content type", contentType)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(record.Body) > 1024*512 {
|
|
||||||
record.Body = ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Println("Record result:", record.Status, record.Response)
|
log.Println("Record result:", record.Status, record.Response)
|
||||||
@@ -314,8 +151,8 @@ func main() {
|
|||||||
if db.Create(&record).Error != nil {
|
if db.Create(&record).Error != nil {
|
||||||
log.Println("Error to save record:", record)
|
log.Println("Error to save record:", record)
|
||||||
}
|
}
|
||||||
if record.Status != 200 && record.Response != "context canceled" {
|
if record.Status != 200 {
|
||||||
errMessage := fmt.Sprintf("IP: %s request %s error %d with %s", record.IP, upstream.Endpoint, record.Status, record.Response)
|
errMessage := fmt.Sprintf("IP: %s request %s error %d with %s", record.IP, record.Model, record.Status, record.Response)
|
||||||
go sendFeishuMessage(errMessage)
|
go sendFeishuMessage(errMessage)
|
||||||
go sendMatrixMessage(errMessage)
|
go sendMatrixMessage(errMessage)
|
||||||
}
|
}
|
||||||
|
|||||||
223
process.go
Normal file
223
process.go
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
|
||||||
|
var errCtx error
|
||||||
|
|
||||||
|
record.UpstreamID = upstream.ID
|
||||||
|
record.Response = ""
|
||||||
|
record.Authorization = upstream.SK
|
||||||
|
// [TODO] record request body
|
||||||
|
|
||||||
|
// reverse proxy
|
||||||
|
remote, err := url.Parse(upstream.Endpoint)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithError(500, errors.New("can't parse reverse proxy remote URL"))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
haveResponse := false
|
||||||
|
|
||||||
|
proxy := httputil.NewSingleHostReverseProxy(remote)
|
||||||
|
proxy.Director = nil
|
||||||
|
var inBody []byte
|
||||||
|
proxy.Rewrite = func(proxyRequest *httputil.ProxyRequest) {
|
||||||
|
in := proxyRequest.In
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
proxyRequest.Out = proxyRequest.Out.WithContext(ctx)
|
||||||
|
|
||||||
|
out := proxyRequest.Out
|
||||||
|
|
||||||
|
// read request body
|
||||||
|
inBody, err = io.ReadAll(in.Body)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithError(502, errors.New("reverse proxy middleware failed to read request body "+err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// record chat message from user
|
||||||
|
record.Body = string(inBody)
|
||||||
|
requestBody, requestBodyOK := ParseRequestBody(inBody)
|
||||||
|
// record if parse success
|
||||||
|
if requestBodyOK == nil {
|
||||||
|
record.Model = requestBody.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
// set timeout, default is 60 second
|
||||||
|
timeout := 60 * time.Second
|
||||||
|
if requestBodyOK == nil && requestBody.Stream {
|
||||||
|
timeout = 5 * time.Second
|
||||||
|
}
|
||||||
|
if len(inBody) > 1024*128 {
|
||||||
|
timeout = 20 * time.Second
|
||||||
|
}
|
||||||
|
if upstream.Timeout > 0 {
|
||||||
|
// convert upstream.Timeout(second) to nanosecond
|
||||||
|
timeout = time.Duration(upstream.Timeout) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
// timeout out request
|
||||||
|
go func() {
|
||||||
|
time.Sleep(timeout)
|
||||||
|
if !haveResponse {
|
||||||
|
log.Println("Timeout upstream", upstream.Endpoint)
|
||||||
|
errCtx = errors.New("timeout")
|
||||||
|
if shouldResponse {
|
||||||
|
c.AbortWithError(502, errCtx)
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
out.Body = io.NopCloser(bytes.NewReader(inBody))
|
||||||
|
|
||||||
|
out.Host = remote.Host
|
||||||
|
out.URL.Scheme = remote.Scheme
|
||||||
|
out.URL.Host = remote.Host
|
||||||
|
out.URL.Path = in.URL.Path
|
||||||
|
out.Header = http.Header{}
|
||||||
|
out.Header.Set("Host", remote.Host)
|
||||||
|
if upstream.SK == "asis" {
|
||||||
|
out.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
||||||
|
} else {
|
||||||
|
out.Header.Set("Authorization", "Bearer "+upstream.SK)
|
||||||
|
}
|
||||||
|
out.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||||
|
}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
var contentType string
|
||||||
|
proxy.ModifyResponse = func(r *http.Response) error {
|
||||||
|
haveResponse = true
|
||||||
|
record.Status = r.StatusCode
|
||||||
|
if !shouldResponse && r.StatusCode != 200 {
|
||||||
|
log.Println("upstream return not 200 and should not response", r.StatusCode)
|
||||||
|
return errors.New("upstream return not 200 and should not response")
|
||||||
|
}
|
||||||
|
r.Header.Del("Access-Control-Allow-Origin")
|
||||||
|
r.Header.Del("Access-Control-Allow-Methods")
|
||||||
|
r.Header.Del("Access-Control-Allow-Headers")
|
||||||
|
r.Header.Set("Access-Control-Allow-Origin", "*")
|
||||||
|
r.Header.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
|
||||||
|
r.Header.Set("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
|
||||||
|
|
||||||
|
if r.StatusCode != 200 {
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
record.Response = "failed to read response from upstream " + err.Error()
|
||||||
|
return errors.New(record.Response)
|
||||||
|
}
|
||||||
|
record.Response = fmt.Sprintf("openai-api-route upstream return '%s' with '%s'", r.Status, string(body))
|
||||||
|
record.Status = r.StatusCode
|
||||||
|
return fmt.Errorf(record.Response)
|
||||||
|
}
|
||||||
|
// count success
|
||||||
|
r.Body = io.NopCloser(io.TeeReader(r.Body, &buf))
|
||||||
|
contentType = r.Header.Get("content-type")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
|
haveResponse = true
|
||||||
|
log.Println("Error", err, upstream.SK, upstream.Endpoint)
|
||||||
|
|
||||||
|
errCtx = err
|
||||||
|
|
||||||
|
// abort to error handle
|
||||||
|
if shouldResponse {
|
||||||
|
c.AbortWithError(502, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("response is", r.Response)
|
||||||
|
|
||||||
|
if record.Status == 0 {
|
||||||
|
record.Status = 502
|
||||||
|
}
|
||||||
|
if record.Response == "" {
|
||||||
|
record.Response = err.Error()
|
||||||
|
}
|
||||||
|
if r.Response != nil {
|
||||||
|
record.Status = r.Response.StatusCode
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy.ServeHTTP(c.Writer, c.Request)
|
||||||
|
|
||||||
|
// return context error
|
||||||
|
if errCtx != nil {
|
||||||
|
// fix inrequest body
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewReader(inBody))
|
||||||
|
return errCtx
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := io.ReadAll(io.NopCloser(&buf))
|
||||||
|
if err != nil {
|
||||||
|
record.Response = "failed to read response from upstream " + err.Error()
|
||||||
|
log.Println(record.Response)
|
||||||
|
} else {
|
||||||
|
|
||||||
|
// record response
|
||||||
|
// stream mode
|
||||||
|
if strings.HasPrefix(contentType, "text/event-stream") {
|
||||||
|
for _, line := range strings.Split(string(resp), "\n") {
|
||||||
|
chunk := StreamModeChunk{}
|
||||||
|
line = strings.TrimPrefix(line, "data:")
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err := json.Unmarshal([]byte(line), &chunk)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(chunk.Choices) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
record.Response += chunk.Choices[0].Delta.Content
|
||||||
|
}
|
||||||
|
} else if strings.HasPrefix(contentType, "application/json") {
|
||||||
|
var fetchResp FetchModeResponse
|
||||||
|
err := json.Unmarshal(resp, &fetchResp)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Error parsing fetch response:", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(fetchResp.Model, "gpt-") {
|
||||||
|
log.Println("Not GPT model, skip recording response:", fetchResp.Model)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(fetchResp.Choices) == 0 {
|
||||||
|
log.Println("Error: fetch response choice length is 0")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
record.Response = fetchResp.Choices[0].Message.Content
|
||||||
|
} else {
|
||||||
|
log.Println("Unknown content type", contentType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(record.Body) > 1024*512 {
|
||||||
|
record.Body = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -15,6 +15,7 @@ type Record struct {
|
|||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
IP string
|
IP string
|
||||||
Body string `gorm:"serializer:json"`
|
Body string `gorm:"serializer:json"`
|
||||||
|
Model string
|
||||||
Response string
|
Response string
|
||||||
ElapsedTime time.Duration
|
ElapsedTime time.Duration
|
||||||
Status int
|
Status int
|
||||||
|
|||||||
24
request_body.go
Normal file
24
request_body.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestBody struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Stream bool `json:"stream"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseRequestBody(data []byte) (RequestBody, error) {
|
||||||
|
ret := RequestBody{
|
||||||
|
Stream: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
var requestBody RequestBody
|
||||||
|
err := json.Unmarshal(data, &requestBody)
|
||||||
|
if err != nil {
|
||||||
|
return ret, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return requestBody, nil
|
||||||
|
}
|
||||||
@@ -9,4 +9,5 @@ type OPENAI_UPSTREAM struct {
|
|||||||
gorm.Model
|
gorm.Model
|
||||||
SK string `gorm:"index:idx_sk_endpoint,unique"` // key
|
SK string `gorm:"index:idx_sk_endpoint,unique"` // key
|
||||||
Endpoint string `gorm:"index:idx_sk_endpoint,unique"` // endpoint
|
Endpoint string `gorm:"index:idx_sk_endpoint,unique"` // endpoint
|
||||||
|
Timeout int64 // timeout in seconds
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user