55 Commits

Author SHA1 Message Date
495f32610b refactor: pipe the read and write process
this refactor simplifies the process logic and fixes several bugs and
performance issues.

bug fixed:
- cors headers not being sent in some situation
performance:
- perform upstream request while client is uploading content
2024-05-27 14:47:00 +08:00
45bba95f5d Merge remote-tracking branch 'comp/master' 2024-04-09 19:07:45 +08:00
75ff8fbc2e Refactor CORS handling and remove response's CORS headers 2024-04-09 19:07:28 +08:00
66758e0008 更新.gitlab-ci.yml文件 2024-04-08 10:29:08 +00:00
40fc2067a5 Merge remote-tracking branch 'comp/master' 2024-04-08 18:22:10 +08:00
1a56101ca8 add dockerignore 2024-04-08 18:21:55 +08:00
e373e3ac63 更新.gitlab-ci.yml文件 2024-04-08 07:48:05 +00:00
24a2e609f8 更新.gitlab-ci.yml文件 2024-04-08 05:11:47 +00:00
e442303847 更新.gitlab-ci.yml文件 2024-04-08 03:52:19 +00:00
8b95fbb5da 更新.gitlab-ci.yml文件 2024-04-08 03:40:32 +00:00
34aa4babc4 更新.gitlab-ci.yml文件 2024-04-08 03:21:24 +00:00
e6ff1f5ca4 更新.gitlab-ci.yml文件 2024-04-08 03:10:24 +00:00
6b6f245e45 Update README.md with complex configuration example 2024-04-08 11:04:29 +08:00
995eea9d67 Update README.md 2024-02-18 17:28:16 +08:00
db7f0eb316 timeout 2024-02-18 16:45:37 +08:00
990628b455 bro gooooo 2024-02-17 00:30:27 +08:00
e8b89fc41a record all json request body 2024-02-16 22:42:48 +08:00
46ee30ced7 use path as default model name 2024-02-16 22:31:32 +08:00
f2e32340e3 fix typo 2024-02-16 17:47:40 +08:00
ca386f8302 record whisper response 2024-02-15 16:52:33 +08:00
3385f9af08 fix: replicate mistral prompt 2024-01-23 16:38:45 +08:00
8fa7fa79be less: replicate log 2024-01-23 15:56:48 +08:00
49169452fe fix: read config upstream type default value 2024-01-23 15:52:46 +08:00
33f341026f fix: replicate response format 2024-01-23 15:43:48 +08:00
b1a9d6b685 add: support replicate 2024-01-23 15:20:22 +08:00
1fc17daa35 deny allow list if unknown model 2024-01-21 15:37:26 +08:00
b1ab9c3d7b add allow & deny model list, fix cors on error 2024-01-16 17:11:25 +08:00
8672899a58 change log format 2024-01-16 16:23:27 +08:00
3a59433f66 append error 2024-01-16 12:03:23 +08:00
873548a7d0 fix: body larger than 1024*128 2024-01-16 11:32:06 +08:00
2a2d907b0d fix: podman image build 2024-01-11 16:43:40 +08:00
2bbe98e694 fix duplicated cors headers 2024-01-04 19:03:58 +08:00
9fdbf259c0 truncate requset body if too long 2024-01-04 18:41:26 +08:00
97926087bb async record request 2024-01-04 18:37:01 +08:00
fc5a8d55fa fix: process error content type 2023-12-22 15:08:24 +08:00
b1e3a97aad fix: cors and content-type on error 2023-12-22 14:24:11 +08:00
04a2e4c12d use cors middleware 2023-12-22 13:23:41 +08:00
b8ebbed5d6 fix: recognize '/v1' prefix 2023-12-07 00:51:21 +08:00
412aefdacc fix: record resp time 2023-11-30 10:24:00 +08:00
2c3532f12f add hostname 2023-11-30 10:18:01 +08:00
7d93332e51 fix: record body json 2023-11-30 10:06:10 +08:00
0dbd898532 update config yaml 2023-11-29 14:40:20 +08:00
7b74818676 update config to file 2023-11-29 14:31:11 +08:00
0a2a7376f1 update dep 2023-11-28 14:36:14 +08:00
44a966e6f4 clean code 2023-11-28 14:04:42 +08:00
87244d4dc2 update docs 2023-11-28 10:42:06 +08:00
11bf18391e update docs 2023-11-28 10:22:12 +08:00
0785d43ff1 support postgres 2023-11-28 10:19:07 +08:00
de3bea06a7 convert config to yaml file 2023-11-27 18:08:15 +08:00
3cc507a767 add docker file 2023-11-27 17:53:28 +08:00
dad4ad2b97 update readme.md 2023-11-27 17:28:29 +08:00
fb19d8a353 upstreams.yaml 2023-11-27 17:19:08 +08:00
4125c78f33 record response time 2023-11-17 14:53:49 +08:00
31eed99025 support multiple auth header 2023-11-16 15:48:09 +08:00
6c9eab09e2 record authorization 2023-11-16 15:42:37 +08:00
18 changed files with 1158 additions and 498 deletions

4
.dockerignore Normal file
View File

@@ -0,0 +1,4 @@
openai-api-route
db.sqlite
/config.yaml
/.*

1
.gitignore vendored
View File

@@ -1,2 +1,3 @@
openai-api-route
db.sqlite
/config.yaml

52
.gitlab-ci.yml Normal file
View File

@@ -0,0 +1,52 @@
# To contribute improvements to CI/CD templates, please follow the Development guide at:
# https://docs.gitlab.com/ee/development/cicd/templates.html
# This specific template is located at:
# https://gitlab.com/gitlab-org/gitlab/-/blob/master/lib/gitlab/ci/templates/Docker.gitlab-ci.yml
# Build a Docker image with CI/CD and push to the GitLab registry.
# Docker-in-Docker documentation: https://docs.gitlab.com/ee/ci/docker/using_docker_build.html
#
# This template uses one generic job with conditional builds
# for the default branch and all other (MR) branches.
docker-build:
# Use the official docker image.
image: docker:cli
stage: build
services:
- docker:dind
variables:
CI_REGISTRY: registry.waykey.net:7999
CI_REGISTRY_IMAGE: $CI_REGISTRY/spiderman/datamining/openai-api-route
DOCKER_IMAGE_NAME: $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_SLUG
before_script:
- docker login -u "$CI_REGISTRY_USER" -p "$CI_REGISTRY_PASSWORD" $CI_REGISTRY
# All branches are tagged with $DOCKER_IMAGE_NAME (defaults to commit ref slug)
# Default branch is also tagged with `latest`
script:
- docker build --pull -t "$DOCKER_IMAGE_NAME" .
- docker push "$DOCKER_IMAGE_NAME"
- |
if [[ "$CI_COMMIT_BRANCH" == "$CI_DEFAULT_BRANCH" ]]; then
docker tag "$DOCKER_IMAGE_NAME" "$CI_REGISTRY_IMAGE:latest"
docker push "$CI_REGISTRY_IMAGE:latest"
fi
# Run this job in a branch where a Dockerfile exists
rules:
- if: $CI_COMMIT_BRANCH
exists:
- Dockerfile
deploy:
environment: production
image: kroniak/ssh-client
stage: deploy
before_script:
- chmod 600 $CI_SSH_PRIVATE_KEY
script:
- ssh -o StrictHostKeyChecking=no -i $CI_SSH_PRIVATE_KEY root@192.168.1.13 "cd /mnt/data/srv/openai-api-route && podman-compose pull && podman-compose down && podman-compose up -d"
rules:
- if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH
exists:
- Dockerfile

17
Dockerfile Normal file
View File

@@ -0,0 +1,17 @@
FROM docker.io/golang:1.21 as builder
WORKDIR /app
COPY go.mod go.sum ./
RUN go mod download
COPY . .
RUN make
FROM alpine
COPY --from=builder /app/openai-api-route /openai-api-route
ENTRYPOINT ["/openai-api-route"]

183
README.md
View File

@@ -7,15 +7,134 @@
- 自定义 Authorization 验证头
- 支持所有类型的接口 (`/v1/*`)
- 提供 Prometheus Metrics 统计接口 (`/v1/metrics`)
- 按照定义顺序请求 OpenAI 上游
- 识别 ChatCompletions Stream 请求,针对 Stream 请求使用 5 秒超时。对于其他请求使用60秒超时。
- 记录完整的请求内容、使用的上游、IP 地址、响应时间以及 GPT 回复文本
- 请求出错时发送 飞书 或 Matrix 消息通知
- 按照定义顺序请求 OpenAI 上游,出错或超时自动按顺序尝试下一个
- 识别 ChatCompletions Stream 请求,针对 Stream 请求使用更短的超时。具体超时策略请参阅 [超时策略](#超时策略) 一节
- 有选择地记录请求内容、请求头、使用的上游、IP 地址、响应时间以及响应等内容。具体记录策略请参阅 [记录策略](#记录策略) 一节
- 请求出错时发送 飞书 或 Matrix 平台的消息通知
- 支持 Replicate 平台上的 mistral 模型beta
本文档详细介绍了如何使用负载均衡和能力 API 的方法和端点。
## 配置文件
默认情况下程序会使用当前目录下的 `config.yaml` 文件,您可以通过使用 `-config your-config.yaml` 参数指定配置文件路径。
以下是一个配置文件示例,你可以在 `config.sample.yaml` 文件中找到同样的内容
```yaml
authorization: woshimima
# 默认超时时间,默认 120 秒,流式请求是 10 秒
timeout: 120
stream_timeout: 10
# 使用 sqlite 作为数据库储存请求记录
dbtype: sqlite
dbaddr: ./db.sqlite
# 使用 postgres 作为数据库储存请求记录
# dbtype: postgres
# dbaddr: "host=127.0.0.1 port=5432 user=postgres dbname=openai_api_route sslmode=disable password=woshimima"
upstreams:
- sk: hahaha
endpoint: "https://localhost:8888/v1"
allow:
# whisper 等非 JSON API 识别不到 model则使用 URL 路径作为模型名称
- /v1/audio/transcriptions
- sk: "secret_key_1"
endpoint: "https://api.openai.com/v2"
timeout: 120 # 请求超时时间默认120秒
stream_timeout: 10 # 如果识别到 stream: true, 则使用该超时时间
allow: # 可选的模型白名单
- gpt-3.5-turbo
- gpt-3.5-turbo-0613
# 您可以设置很多个上游,程序将依次按顺序尝试
- sk: "secret_key_2"
endpoint: "https://api.openai.com/v1"
timeout: 30
deny:
- gpt-4
- sk: "key_for_replicate"
type: replicate
allow:
- mistralai/mixtral-8x7b-instruct-v0.1
```
### 配置多个验证头
您可以使用英文逗号 `,` 分割多个验证头。每个验证头都是有效的,程序会记录每个请求使用的验证头
```yaml
authorization: woshimima,iampassword
```
您也可以为上游单独设置验证头
```yaml
authorization: woshimima,iampassword
upstreams:
- sk: key
authorization: woshimima
```
如此,只有携带 `woshimima` 验证头的用户可以使用该上游。
### 复杂配置示例
```yaml
# 默认验证头
authorization: woshimima
upstreams:
# 允许所有人使用的文字转语音
- sk: xxx
endpoint: http://localhost:5000/v1
noauth: true
allow:
- /v1/audio/transcriptions
# guest 专用的 gpt-3.5-turbo-0125 模型
- sk:
endpoint: https://api.xxx.local/v1
authorization: guest
allow:
- gpt-3.5-turbo-0125
```
## 部署方法
有两种推荐的部署方法:
1. 使用预先构建好的容器 `docker.io/heimoshuiyu/openai-api-route:latest`
2. 自行编译
### 使用容器运行
> 注意,如果您使用 sqlite 数据库,您可能还需要修改配置文件以将 SQLite 数据库文件放置在数据卷中。
```bash
docker run -d --name openai-api-route -v /path/to/config.yaml:/config.yaml docker.io/heimoshuiyu/openai-api-route:latest
```
使用 Docker Compose
```yaml
version: '3'
services:
openai-api-route:
image: docker.io/heimoshuiyu/openai-api-route:latest
ports:
- 8888:8888
volumes:
- ./config.yaml:/config.yaml
```
### 编译
以下是编译和运行该负载均衡 API 的步骤:
@@ -40,58 +159,28 @@
./openai-api-route
```
默认情况下API 将会在本地的 8888 端口进行监听。
## 模型允许与屏蔽列表
如果您希望使用不同的监听地址,可以使用 `-addr` 参数来指定,例如:
如果对某个上游设置了 allow 或 deny 列表,则负载均衡只允许或禁用用户使用这些模型。负载均衡程序会先判断白名单,再判断黑名单。
```
./openai-api-route -addr 0.0.0.0:8080
```
如果你混合使用 OpenAI 和 Replicate 平台的模型,你可能需要分别为 OpenAI 和 Replicate 上游设置他们各自的允许列表,否则用户请求 OpenAI 的模型时可能会发送到 Replicate 平台
这将会将监听地址设置为 0.0.0.0:8080。
## 超时策略
6. 如果数据库不存在,系统会自动创建一个名为 `db.sqlite` 的数据库文件
在处理上游请求时,超时策略是确保服务稳定性和响应性的关键因素。本服务通过配置文件中的 `Upstreams` 部分来定义多个上游服务器。每个上游服务器都有自己的 `Endpoint` 和 `SK`(可能是密钥或特殊标识)。服务会按照配置文件中的顺序依次尝试每个上游服务器,直到请求成功或所有上游服务器都已尝试
如果您希望使用不同的数据库地址,可以使用 `-database` 参数来指定,例如:
### 单一上游配置
```
./openai-api-route -database /path/to/database.db
```
当配置文件中只定义了一个上游服务器时,该上游的超时时间将被设置为 120 秒。这意味着,如果请求没有在 120 秒内得到上游服务器的响应,服务将会中止该请求并可能返回错误。
这将会将数据库地址设置为 `/path/to/database.db`。
### 多上游配置
7. 现在,您已经成功编译并运行了负载均衡和能力 API。您可以根据需要添加上游、管理上游并使用 API 进行相关操作
如果配置文件中定义了多个上游服务器,服务将会按照定义的顺序依次尝试每个上游。对于每个上游服务器,服务会检查其 `Endpoint` 和 `SK` 是否有效。如果任一字段为空,服务将返回 500 错误,并记录无效的上游信息
### 运行
### 超时策略细节
以下是运行命令的用法
服务在处理请求时会根据不同的条件设置不同的超时时间。超时时间是指服务等待上游服务器响应的最大时间。以下是超时时间的设置规则
```
Usage of ./openai-api-route:
-add
添加一个 OpenAI 上游
-addr string
监听地址(默认为 ":8888"
-database string
数据库地址(默认为 "./db.sqlite"
-endpoint string
OpenAI API 基地址(默认为 "https://api.openai.com/v1"
-list
列出所有上游
-noauth
不检查传入的授权头
-sk string
OpenAI API 密钥sk-xxxxx
```
1. **默认超时时间**:如果没有特殊条件,服务将使用默认的超时时间,即 60 秒。
您可以直接运行 `./openai-api-route` 命令,如果数据库不存在,系统会自动创建
### 上游管理
您可以使用以下命令添加一个上游:
```bash
./openai-api-route -add -sk sk-xxxxx -endpoint https://api.openai.com/v1
```
另外,您还可以直接编辑数据库中的 `openai_upstreams` 表进行 OpenAI 上游的增删改查管理。改动的上游需要重启负载均衡服务后才能生效。
2. **流式请求**:如果请求体被识别为流式(`requestBody.Stream` 为 `true`),并且请求体检查(`requestBodyOK`)没有发现问题,超时时间将被设置为 5 秒。这适用于那些预期会快速响应的流式请求

28
auth.go
View File

@@ -2,30 +2,14 @@ package main
import (
"errors"
"log"
"strings"
"github.com/gin-gonic/gin"
)
func handleAuth(c *gin.Context) error {
var err error
authorization := c.Request.Header.Get("Authorization")
if !strings.HasPrefix(authorization, "Bearer") {
err = errors.New("authorization header should start with 'Bearer'")
c.AbortWithError(403, err)
return err
func checkAuth(authorization string, config string) error {
for _, auth := range strings.Split(config, ",") {
if authorization == strings.Trim(auth, " ") {
return nil
}
}
authorization = strings.Trim(authorization[len("Bearer"):], " ")
log.Println("Received authorization", authorization)
if authorization != authConfig.Value {
err = errors.New("wrong authorization header")
c.AbortWithError(403, err)
return err
}
return nil
return errors.New("wrong authorization header")
}

View File

@@ -1,49 +0,0 @@
package main
import (
"errors"
"log"
"gorm.io/gorm"
)
// K-V struct to store program's config
type ConfigKV struct {
gorm.Model
Key string `gorm:"unique"`
Value string
}
// init db
func initconfig(db *gorm.DB) error {
var err error
err = db.AutoMigrate(&ConfigKV{})
if err != nil {
return err
}
// config list and their default values
configs := make(map[string]string)
configs["authorization"] = "woshimima"
configs["policy"] = "main"
for key, value := range configs {
kv := ConfigKV{}
err = db.Take(&kv, "key = ?", key).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Println("Missing config", key, "creating with value", value)
kv.Key = key
kv.Value = value
if err = db.Create(&kv).Error; err != nil {
return err
}
} else {
return err
}
}
}
return nil
}

23
config.sample.yaml Normal file
View File

@@ -0,0 +1,23 @@
authorization: woshimima
# 使用 sqlite 作为数据库储存请求记录
dbtype: sqlite
dbaddr: ./db.sqlite
# 使用 postgres 作为数据库储存请求记录
# dbtype: postgres
# dbaddr: "host=127.0.0.1 port=5432 user=postgres dbname=openai_api_route sslmode=disable password=woshimima"
upstreams:
- sk: "secret_key_1"
endpoint: "https://api.openai.com/v2"
allow: ["gpt-3.5-turbo"] # 可选的模型白名单
- sk: "secret_key_2"
endpoint: "https://api.openai.com/v1"
timeout: 30
allow: ["gpt-3.5-turbo"] # 可选的模型白名单
deny: ["gpt-4"] # 可选的模型黑名单
# 若白名单和黑名单同时设置,先判断白名单,再判断黑名单
- sk: "key_for_replicate"
type: replicate
allow: ["mistralai/mixtral-8x7b-instruct-v0.1"]

20
cors.go Normal file
View File

@@ -0,0 +1,20 @@
package main
import (
"log"
"github.com/gin-gonic/gin"
)
func sendCORSHeaders(c *gin.Context) {
log.Println("sendCORSHeaders")
if c.Writer.Header().Get("Access-Control-Allow-Origin") == "" {
c.Header("Access-Control-Allow-Origin", "*")
}
if c.Writer.Header().Get("Access-Control-Allow-Methods") == "" {
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
}
if c.Writer.Header().Get("Access-Control-Allow-Headers") == "" {
c.Header("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
}
}

View File

@@ -1,5 +0,0 @@
package main
// declare global variable
var authConfig ConfigKV

20
go.mod
View File

@@ -4,9 +4,12 @@ go 1.20
require (
github.com/gin-gonic/gin v1.9.1
github.com/google/uuid v1.3.0
github.com/penglongli/gin-metrics v0.1.10
golang.org/x/net v0.10.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/postgres v1.5.4
gorm.io/driver/sqlite v1.5.2
gorm.io/gorm v1.25.2
gorm.io/gorm v1.25.5
)
require (
@@ -22,6 +25,9 @@ require (
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.4.3 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
@@ -33,19 +39,17 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/penglongli/gin-metrics v0.1.10 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/prometheus/client_golang v1.12.0 // indirect
github.com/prometheus/client_model v0.2.0 // indirect
github.com/prometheus/common v0.32.1 // indirect
github.com/prometheus/procfs v0.7.3 // indirect
github.com/rogpeppe/go-internal v1.11.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.9.0 // indirect
golang.org/x/net v0.10.0 // indirect
golang.org/x/sys v0.8.0 // indirect
golang.org/x/text v0.9.0 // indirect
golang.org/x/crypto v0.14.0 // indirect
golang.org/x/sys v0.13.0 // indirect
golang.org/x/text v0.13.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

34
go.sum
View File

@@ -148,13 +148,17 @@ github.com/google/pprof v0.0.0-20200229191704-1ebb73c60ed3/go.mod h1:ZgVRPoUq/hf
github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY=
github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
@@ -178,7 +182,9 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxv
github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
@@ -232,6 +238,8 @@ github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1
github.com/prometheus/procfs v0.7.3 h1:4jVXhlkAyzOScmCkXBTOLRLTz8EeU+eyjrwB/EPq0VU=
github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
@@ -272,8 +280,8 @@ golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -389,8 +397,8 @@ golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@@ -398,8 +406,8 @@ golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
@@ -445,7 +453,6 @@ golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc
golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -527,11 +534,10 @@ google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQ
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
@@ -543,10 +549,12 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo=
gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0=
gorm.io/driver/sqlite v1.5.2 h1:TpQ+/dqCY4uCigCFyrfnrJnrW9zjpelWVoEVNy5qJkc=
gorm.io/driver/sqlite v1.5.2/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4=
gorm.io/gorm v1.25.2 h1:gs1o6Vsa+oVKG/a9ElL3XgyGfghFfkKA2SInQaCyMho=
gorm.io/gorm v1.25.2/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

240
main.go
View File

@@ -1,73 +1,70 @@
package main
import (
"bytes"
"encoding/json"
"flag"
"fmt"
"io"
"log"
"net/http"
"os"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/penglongli/gin-metrics/ginmetrics"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// global config
var config Config
func main() {
dbAddr := flag.String("database", "./db.sqlite", "Database address")
listenAddr := flag.String("addr", ":8888", "Listening address")
addMode := flag.Bool("add", false, "Add an OpenAI upstream")
configFile := flag.String("config", "./config.yaml", "Config file")
listMode := flag.Bool("list", false, "List all upstream")
sk := flag.String("sk", "", "OpenAI API key (sk-xxxxx)")
noauth := flag.Bool("noauth", false, "Do not check incoming authorization header")
endpoint := flag.String("endpoint", "https://api.openai.com/v1", "OpenAI API base")
flag.Parse()
log.Println("Service starting")
// connect to database
db, err := gorm.Open(sqlite.Open(*dbAddr), &gorm.Config{
PrepareStmt: true,
SkipDefaultTransaction: true,
})
if err != nil {
log.Fatal("Failed to connect to database")
}
log.Println("[main]: Service starting")
// load all upstreams
upstreams := make([]OPENAI_UPSTREAM, 0)
db.Find(&upstreams)
log.Println("Load upstreams number:", len(upstreams))
config = readConfig(*configFile)
log.Println("[main]: Load upstreams number:", len(config.Upstreams))
err = initconfig(db)
if err != nil {
log.Fatal(err)
}
db.AutoMigrate(&OPENAI_UPSTREAM{})
db.AutoMigrate(&Record{})
log.Println("Auto migrate database done")
if *addMode {
if *sk == "" {
log.Fatal("Missing --sk flag")
}
newUpstream := OPENAI_UPSTREAM{}
newUpstream.SK = *sk
newUpstream.Endpoint = *endpoint
err = db.Create(&newUpstream).Error
// connect to database
var db *gorm.DB
var err error
switch config.DBType {
case "sqlite":
db, err = gorm.Open(sqlite.Open(config.DBAddr), &gorm.Config{
PrepareStmt: true,
SkipDefaultTransaction: true,
})
if err != nil {
log.Fatal("Can not add upstream", err)
log.Fatalf("[main]: Error to connect sqlite database: %s", err)
}
log.Println("Successuflly add upstream", *sk, *endpoint)
return
case "postgres":
db, err = gorm.Open(postgres.Open(config.DBAddr), &gorm.Config{
PrepareStmt: true,
SkipDefaultTransaction: true,
})
if err != nil {
log.Fatalf("[main]: Error to connect postgres database: %s", err)
}
default:
log.Fatalf("[main]: Unsupported database type: '%s'", config.DBType)
}
db.AutoMigrate(&Record{})
log.Println("[main]: Auto migrate database done")
if *listMode {
result := make([]OPENAI_UPSTREAM, 0)
db.Find(&result)
fmt.Println("SK\tEndpoint")
for _, upstream := range result {
for _, upstream := range config.Upstreams {
fmt.Println(upstream.SK, upstream.Endpoint)
}
return
@@ -81,6 +78,9 @@ func main() {
m.SetMetricPath("/v1/metrics")
m.Use(engine)
// CORS middleware
// engine.Use(corsMiddleware())
// error handle middleware
engine.Use(func(c *gin.Context) {
c.Next()
@@ -89,74 +89,158 @@ func main() {
}
errText := strings.Join(c.Errors.Errors(), "\n")
c.JSON(-1, gin.H{
"error": errText,
"openai-api-route error": errText,
})
})
// CORS handler
engine.OPTIONS("/v1/*any", func(ctx *gin.Context) {
header := ctx.Writer.Header()
header.Set("Access-Control-Allow-Origin", "*")
header.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
header.Set("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
// set cros header
ctx.Header("Access-Control-Allow-Origin", "*")
ctx.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
ctx.Header("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
ctx.AbortWithStatus(200)
})
// get authorization config from db
db.Take(&authConfig, "key = ?", "authorization")
engine.POST("/v1/*any", func(c *gin.Context) {
var err error
hostname, _ := os.Hostname()
if config.Hostname != "" {
hostname = config.Hostname
}
record := Record{
IP: c.ClientIP(),
Hostname: hostname,
CreatedAt: time.Now(),
Authorization: c.Request.Header.Get("Authorization"),
}
defer func() {
if err := recover(); err != nil {
log.Println("Error:", err)
c.AbortWithError(500, fmt.Errorf("%s", err))
}
}()
// check authorization header
if !*noauth {
if handleAuth(c) != nil {
return
}
UserAgent: c.Request.Header.Get("User-Agent"),
Model: c.Request.URL.Path,
}
for index, upstream := range upstreams {
if upstream.Endpoint == "" || upstream.SK == "" {
c.AbortWithError(500, fmt.Errorf("invaild upstream '%s' '%s'", upstream.SK, upstream.Endpoint))
authorization := c.Request.Header.Get("Authorization")
if strings.HasPrefix(authorization, "Bearer") {
authorization = strings.Trim(authorization[len("Bearer"):], " ")
} else {
authorization = strings.Trim(authorization, " ")
log.Println("[auth] Warning: authorization header should start with 'Bearer'")
}
log.Println("Received authorization '" + authorization + "'")
availUpstreams := make([]OPENAI_UPSTREAM, 0)
for _, upstream := range config.Upstreams {
if upstream.SK == "" {
sendCORSHeaders(c)
c.AbortWithError(500, fmt.Errorf("[processRequest.begin]: invaild SK (secret key) %s", upstream.SK))
continue
}
shouldResponse := index == len(upstreams)-1
if !*noauth && !upstream.Noauth {
if checkAuth(authorization, upstream.Authorization) != nil {
continue
}
}
if len(upstreams) == 1 {
availUpstreams = append(availUpstreams, upstream)
}
if len(availUpstreams) == 0 {
sendCORSHeaders(c)
c.AbortWithError(500, fmt.Errorf("[processRequest.begin]: no available upstream for your token"))
}
log.Println("[processRequest.begin]: availUpstreams", len(availUpstreams))
bufIO := bytes.NewBuffer(make([]byte, 0, 1024))
wrapedBody := false
for index, _upstream := range availUpstreams {
// copy
upstream := _upstream
record.UpstreamEndpoint = upstream.Endpoint
record.UpstreamSK = upstream.SK
shouldResponse := index == len(config.Upstreams)-1
if len(availUpstreams) == 1 {
// [todo] copy problem
upstream.Timeout = 120
}
err = processRequest(c, &upstream, &record, shouldResponse)
if err != nil {
log.Println("Error from upstream", upstream.Endpoint, "should retry", err)
continue
// buffer for incoming request
if !wrapedBody {
log.Println("[processRequest.begin]: wrap request body")
c.Request.Body = io.NopCloser(io.TeeReader(c.Request.Body, bufIO))
wrapedBody = true
} else {
log.Println("[processRequest.begin]: reuse request body")
c.Request.Body = io.NopCloser(bytes.NewReader(bufIO.Bytes()))
}
break
if upstream.Type == "replicate" {
err = processReplicateRequest(c, &upstream, &record, shouldResponse)
} else if upstream.Type == "openai" {
err = processRequest(c, &upstream, &record, shouldResponse)
} else {
err = fmt.Errorf("[processRequest.begin]: unsupported upstream type '%s'", upstream.Type)
}
if err == nil {
log.Println("[processRequest.done]: Success from upstream", upstream.Endpoint)
break
}
if err == http.ErrAbortHandler {
abortErr := "[processRequest.done]: AbortHandler, client's connection lost?, no upstream will try, stop here"
log.Println(abortErr)
record.Response += abortErr
record.Status = 500
break
}
log.Println("[processRequest.done]: Error from upstream", upstream.Endpoint, "should retry", err, "should response:", shouldResponse)
// error process, break
if shouldResponse {
c.Header("Content-Type", "application/json")
sendCORSHeaders(c)
c.AbortWithError(500, err)
}
}
log.Println("Record result:", record.Status, record.Response)
record.ElapsedTime = time.Now().Sub(record.CreatedAt)
if db.Create(&record).Error != nil {
log.Println("Error to save record:", record)
// parse and record request body
requestBodyBytes := bufIO.Bytes()
if len(requestBodyBytes) < 1024*1024 && (strings.HasPrefix(c.Request.Header.Get("Content-Type"), "application/json") ||
strings.HasPrefix(c.Request.Header.Get("Content-Type"), "text/")) {
record.Body = string(requestBodyBytes)
}
requestBody, err := ParseRequestBody(requestBodyBytes)
if err != nil {
log.Println("[processRequest.done]: Error to parse request body:", err)
} else {
record.Model = requestBody.Model
}
log.Println("[final]: Record result:", record.Status, record.Response)
record.ElapsedTime = time.Since(record.CreatedAt)
// async record request
go func() {
// encoder headers to record.Headers in json string
headers, _ := json.Marshal(c.Request.Header)
record.Headers = string(headers)
// turncate request if too long
log.Println("[async.record]: body length:", len(record.Body))
if db.Create(&record).Error != nil {
log.Println("[async.record]: Error to save record:", record)
}
}()
if record.Status != 200 {
errMessage := fmt.Sprintf("IP: %s request %s error %d with %s", record.IP, record.Model, record.Status, record.Response)
errMessage := fmt.Sprintf("[result.error]: IP: %s request %s error %d with %s", record.IP, record.Model, record.Status, record.Response)
go sendFeishuMessage(errMessage)
go sendMatrixMessage(errMessage)
}
})
engine.Run(*listenAddr)
engine.Run(config.Address)
}

View File

@@ -8,215 +8,118 @@ import (
"io"
"log"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/net/context"
)
func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
var errCtx error
record.UpstreamID = upstream.ID
record.Response = ""
record.Authorization = upstream.SK
// [TODO] record request body
// reverse proxy
remote, err := url.Parse(upstream.Endpoint)
if err != nil {
c.AbortWithError(500, errors.New("can't parse reverse proxy remote URL"))
return err
}
haveResponse := false
path := strings.TrimPrefix(c.Request.URL.Path, "/v1")
remote.Path = upstream.URL.Path + path
log.Println("[proxy.begin]:", remote)
log.Println("[proxy.begin]: shouldResposne:", shouldResponse)
proxy := httputil.NewSingleHostReverseProxy(remote)
proxy.Director = nil
var inBody []byte
proxy.Rewrite = func(proxyRequest *httputil.ProxyRequest) {
in := proxyRequest.In
client := &http.Client{}
request := &http.Request{}
request.ContentLength = c.Request.ContentLength
request.Method = c.Request.Method
request.URL = remote
ctx, cancel := context.WithCancel(context.Background())
proxyRequest.Out = proxyRequest.Out.WithContext(ctx)
out := proxyRequest.Out
// read request body
inBody, err = io.ReadAll(in.Body)
if err != nil {
c.AbortWithError(502, errors.New("reverse proxy middleware failed to read request body "+err.Error()))
return
}
// record chat message from user
record.Body = string(inBody)
requestBody, requestBodyOK := ParseRequestBody(inBody)
// record if parse success
if requestBodyOK == nil {
record.Model = requestBody.Model
}
// set timeout, default is 60 second
timeout := 60 * time.Second
if requestBodyOK == nil && requestBody.Stream {
timeout = 5 * time.Second
}
if len(inBody) > 1024*128 {
timeout = 20 * time.Second
}
if upstream.Timeout > 0 {
// convert upstream.Timeout(second) to nanosecond
timeout = time.Duration(upstream.Timeout) * time.Second
}
// timeout out request
go func() {
time.Sleep(timeout)
if !haveResponse {
log.Println("Timeout upstream", upstream.Endpoint)
errCtx = errors.New("timeout")
if shouldResponse {
c.AbortWithError(502, errCtx)
}
cancel()
}
}()
out.Body = io.NopCloser(bytes.NewReader(inBody))
out.Host = remote.Host
out.URL.Scheme = remote.Scheme
out.URL.Host = remote.Host
out.URL.Path = in.URL.Path
out.Header = http.Header{}
out.Header.Set("Host", remote.Host)
if upstream.SK == "asis" {
out.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
} else {
out.Header.Set("Authorization", "Bearer "+upstream.SK)
}
out.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
}
var buf bytes.Buffer
var contentType string
proxy.ModifyResponse = func(r *http.Response) error {
haveResponse = true
record.Status = r.StatusCode
if !shouldResponse && r.StatusCode != 200 {
log.Println("upstream return not 200 and should not response", r.StatusCode)
return errors.New("upstream return not 200 and should not response")
}
r.Header.Del("Access-Control-Allow-Origin")
r.Header.Del("Access-Control-Allow-Methods")
r.Header.Del("Access-Control-Allow-Headers")
r.Header.Set("Access-Control-Allow-Origin", "*")
r.Header.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
r.Header.Set("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
if r.StatusCode != 200 {
body, err := io.ReadAll(r.Body)
if err != nil {
record.Response = "failed to read response from upstream " + err.Error()
return errors.New(record.Response)
}
record.Response = fmt.Sprintf("openai-api-route upstream return '%s' with '%s'", r.Status, string(body))
record.Status = r.StatusCode
return fmt.Errorf(record.Response)
}
// count success
r.Body = io.NopCloser(io.TeeReader(r.Body, &buf))
contentType = r.Header.Get("content-type")
return nil
}
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
haveResponse = true
log.Println("Error", err, upstream.SK, upstream.Endpoint)
errCtx = err
// abort to error handle
if shouldResponse {
c.AbortWithError(502, err)
}
log.Println("response is", r.Response)
if record.Status == 0 {
record.Status = 502
}
if record.Response == "" {
record.Response = err.Error()
}
if r.Response != nil {
record.Status = r.Response.StatusCode
}
}
proxy.ServeHTTP(c.Writer, c.Request)
// return context error
if errCtx != nil {
// fix inrequest body
c.Request.Body = io.NopCloser(bytes.NewReader(inBody))
return errCtx
}
resp, err := io.ReadAll(io.NopCloser(&buf))
if err != nil {
record.Response = "failed to read response from upstream " + err.Error()
log.Println(record.Response)
// process header
if upstream.KeepHeader {
request.Header = c.Request.Header
} else {
// record response
// stream mode
if strings.HasPrefix(contentType, "text/event-stream") {
for _, line := range strings.Split(string(resp), "\n") {
chunk := StreamModeChunk{}
line = strings.TrimPrefix(line, "data:")
line = strings.TrimSpace(line)
if line == "" {
continue
}
err := json.Unmarshal([]byte(line), &chunk)
if err != nil {
log.Println(err)
continue
}
if len(chunk.Choices) == 0 {
continue
}
record.Response += chunk.Choices[0].Delta.Content
}
} else if strings.HasPrefix(contentType, "application/json") {
var fetchResp FetchModeResponse
err := json.Unmarshal(resp, &fetchResp)
if err != nil {
log.Println("Error parsing fetch response:", err)
return nil
}
if !strings.HasPrefix(fetchResp.Model, "gpt-") {
log.Println("Not GPT model, skip recording response:", fetchResp.Model)
return nil
}
if len(fetchResp.Choices) == 0 {
log.Println("Error: fetch response choice length is 0")
return nil
}
record.Response = fetchResp.Choices[0].Message.Content
} else {
log.Println("Unknown content type", contentType)
}
request.Header = http.Header{}
}
if len(record.Body) > 1024*512 {
record.Body = ""
// process header authorization
if upstream.SK == "asis" {
request.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
} else {
request.Header.Set("Authorization", "Bearer "+upstream.SK)
}
request.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
request.Header.Set("Host", remote.Host)
request.Header.Set("Content-Length", c.Request.Header.Get("Content-Length"))
request.Body = c.Request.Body
resp, err := client.Do(request)
if err != nil {
body := []byte{}
if resp != nil && resp.Body != nil {
body, _ = io.ReadAll(resp.Body)
}
return errors.New(err.Error() + " " + string(body))
}
defer resp.Body.Close()
record.Status = resp.StatusCode
if resp.StatusCode != 200 {
body, _ := io.ReadAll(resp.Body)
record.Status = resp.StatusCode
errRet := fmt.Errorf("[error]: openai-api-route upstream return '%s' with '%s'", resp.Status, string(body))
log.Println(errRet)
return errRet
}
// copy response header
for k, v := range resp.Header {
c.Header(k, v[0])
}
sendCORSHeaders(c)
respBodyBuffer := bytes.NewBuffer(make([]byte, 0, 4*1024))
respBodyTeeReader := io.TeeReader(resp.Body, respBodyBuffer)
record.ResponseTime = time.Since(record.CreatedAt)
io.Copy(c.Writer, respBodyTeeReader)
record.ElapsedTime = time.Since(record.CreatedAt)
// parse and record response
if strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") {
var fetchResp FetchModeResponse
err := json.NewDecoder(respBodyBuffer).Decode(&fetchResp)
if err == nil {
if len(fetchResp.Choices) > 0 {
record.Response = fetchResp.Choices[0].Message.Content
}
}
} else if strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") {
lines := bytes.Split(respBodyBuffer.Bytes(), []byte("\n"))
for _, line := range lines {
line = bytes.TrimSpace(line)
line = bytes.TrimPrefix(line, []byte("data:"))
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
chunk := StreamModeChunk{}
err = json.Unmarshal(line, &chunk)
if err != nil {
log.Println("[proxy.parseChunkError]:", err)
break
}
if len(chunk.Choices) == 0 {
continue
}
record.Response += chunk.Choices[0].Delta.Content
}
} else if strings.HasPrefix(resp.Header.Get("Content-Type"), "text") {
body, _ := io.ReadAll(respBodyBuffer)
record.Response = string(body)
} else {
log.Println("[proxy.record]: Unknown content type", resp.Header.Get("Content-Type"))
}
return nil

View File

@@ -1,26 +1,25 @@
package main
import (
"encoding/json"
"log"
"strings"
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
type Record struct {
ID int64 `gorm:"primaryKey,autoIncrement"`
CreatedAt time.Time
IP string
Body string `gorm:"serializer:json"`
Model string
Response string
ElapsedTime time.Duration
Status int
UpstreamID uint
Authorization string
ID int64 `gorm:"primaryKey,autoIncrement"`
Hostname string
UpstreamEndpoint string
UpstreamSK string
CreatedAt time.Time
IP string
Body string
Model string
Response string
ResponseTime time.Duration
ElapsedTime time.Duration
Status int
Authorization string // the autorization header send by client
UserAgent string
Headers string
}
type StreamModeChunk struct {
@@ -52,61 +51,3 @@ type FetchModeUsage struct {
CompletionTokens int64 `json:"completion_tokens"`
TotalTokens int64 `json:"total_tokens"`
}
// recordAssistantResponse extracts the assistant's text from an upstream
// response body (either SSE "data:" chunks or a plain JSON completion) and
// stores it, together with elapsedTime, on the Record identified by trackID.
// Unparseable, empty-choice, or non-"gpt-" responses are logged and skipped.
func recordAssistantResponse(contentType string, db *gorm.DB, trackID uuid.UUID, body []byte, elapsedTime time.Duration) {
	result := ""
	// stream mode
	if strings.HasPrefix(contentType, "text/event-stream") {
		resp := string(body)
		// each SSE line looks like "data: {json}"; blank or malformed lines are skipped
		for _, line := range strings.Split(resp, "\n") {
			chunk := StreamModeChunk{}
			line = strings.TrimPrefix(line, "data:")
			line = strings.TrimSpace(line)
			if line == "" {
				continue
			}
			err := json.Unmarshal([]byte(line), &chunk)
			if err != nil {
				log.Println(err)
				continue
			}
			if len(chunk.Choices) == 0 {
				continue
			}
			// concatenate the incremental deltas into the full reply
			result += chunk.Choices[0].Delta.Content
		}
	} else if strings.HasPrefix(contentType, "application/json") {
		var fetchResp FetchModeResponse
		err := json.Unmarshal(body, &fetchResp)
		if err != nil {
			log.Println("Error parsing fetch response:", err)
			return
		}
		// only responses from gpt-* models are recorded in fetch mode
		if !strings.HasPrefix(fetchResp.Model, "gpt-") {
			log.Println("Not GPT model, skip recording response:", fetchResp.Model)
			return
		}
		if len(fetchResp.Choices) == 0 {
			log.Println("Error: fetch response choice length is 0")
			return
		}
		result = fetchResp.Choices[0].Message.Content
	} else {
		log.Println("Unknown content type", contentType)
		return
	}
	log.Println("Record result:", result)
	record := Record{}
	// NOTE(review): gorm's Find does not report "no rows" as an error, so a
	// missing trackID leaves record zero-valued and Save then writes a mostly
	// empty row; also trackID is a UUID matched against "id" — confirm the
	// Record primary-key type actually matches.
	if db.Find(&record, "id = ?", trackID).Error != nil {
		log.Println("Error find request record with trackID:", trackID)
		return
	}
	record.Response = result
	record.ElapsedTime = elapsedTime
	if db.Save(&record).Error != nil {
		log.Println("Error to save record:", record)
		return
	}
}

24
recovery.go Normal file
View File

@@ -0,0 +1,24 @@
package main
import (
"errors"
"log"
"net/http"
"net/http/httputil"
"github.com/gin-gonic/gin"
)
// ServeHTTP forwards the request through the given reverse proxy while
// converting any panic raised during proxying into an ordinary error, so the
// caller can fall back to another upstream instead of crashing the process.
func ServeHTTP(proxy *httputil.ReverseProxy, w gin.ResponseWriter, r *http.Request) (served error) {
	defer func() {
		recovered := recover()
		if recovered == nil {
			return
		}
		log.Println("[serve.panic]: ", recovered)
		served = errors.New("[serve.panic]: Panic recover in reverse proxy serve HTTP")
	}()
	proxy.ServeHTTP(w, r)
	return nil
}

369
replicate.go Normal file
View File

@@ -0,0 +1,369 @@
package main
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
)
// replicate_model_url_template is the Replicate predictions endpoint; the
// single %s placeholder is filled with the model name taken from the incoming
// OpenAI-style request (presumably an "{owner}/{name}" slug — confirm).
var replicate_model_url_template = "https://api.replicate.com/v1/models/%s/predictions"
// processReplicateRequest runs the Replicate proxy logic and, when this
// upstream is the one that must answer the client, emits CORS headers and
// converts any failure into a 502 abort. The underlying error is always
// propagated to the caller so the retry loop can react to it.
func processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
	result := _processReplicateRequest(c, upstream, record, shouldResponse)
	if !shouldResponse {
		return result
	}
	sendCORSHeaders(c)
	if result != nil {
		c.AbortWithError(502, result)
	}
	return result
}
// _processReplicateRequest bridges an OpenAI-style chat request onto the
// Replicate predictions API: it parses the incoming JSON body, enforces the
// upstream's allow/deny model lists, renders the chat messages into a single
// "<s> [INST] … [/INST] … </s>" prompt, posts the prediction, and then either
// relays SSE chunks to the client (stream mode) or polls the prediction until
// it finishes (fetch mode), emitting an OpenAI-shaped response either way.
// On every early error the already-consumed request body is restored onto
// c.Request.Body so a later upstream in the retry loop can re-read it.
func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
	// read request body
	inBody, err := io.ReadAll(c.Request.Body)
	if err != nil {
		return errors.New("[processReplicateRequest]: failed to read request body " + err.Error())
	}
	// record request body
	// parse request body
	inRequest := &OpenAIChatRequest{}
	err = json.Unmarshal(inBody, inRequest)
	if err != nil {
		// restore the body for the next candidate upstream
		c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
		return errors.New("[processReplicateRequest]: failed to parse request body " + err.Error())
	}
	record.Model = inRequest.Model
	// check allow model: an empty Allow list permits every model
	if len(upstream.Allow) > 0 {
		isAllow := false
		for _, model := range upstream.Allow {
			if model == inRequest.Model {
				isAllow = true
				break
			}
		}
		if !isAllow {
			c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
			return errors.New("[processReplicateRequest]: model not allow")
		}
	}
	// check block model
	if len(upstream.Deny) > 0 {
		for _, model := range upstream.Deny {
			if model == inRequest.Model {
				c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
				return errors.New("[processReplicateRequest]: model deny")
			}
		}
	}
	// set url
	model_url := fmt.Sprintf(replicate_model_url_template, inRequest.Model)
	log.Println("[processReplicateRequest]: model_url:", model_url)
	// create request with default value
	outRequest := &ReplicateModelRequest{
		Stream: false,
		Input: ReplicateModelRequestInput{
			Prompt:           "",
			MaxNewTokens:     1024,
			Temperature:      0.6,
			Top_p:            0.9,
			Top_k:            50,
			FrequencyPenalty: 0.0,
			PresencePenalty:  0.0,
			// "{prompt}" passes our fully rendered prompt through unchanged
			PromptTemplate: "{prompt}",
		},
	}
	// copy value from input request
	outRequest.Stream = inRequest.Stream
	outRequest.Input.Temperature = inRequest.Temperature
	outRequest.Input.FrequencyPenalty = inRequest.FrequencyPenalty
	outRequest.Input.PresencePenalty = inRequest.PresencePenalty
	// render prompt: fold the chat transcript into instruction-format turns.
	// system text is prepended to the next user message; each assistant
	// message closes a "<s> [INST] user [/INST] assistant </s>" turn.
	systemMessage := ""
	userMessage := ""
	assistantMessage := ""
	for _, message := range inRequest.Messages {
		if message.Role == "system" {
			if systemMessage != "" {
				systemMessage += "\n"
			}
			systemMessage += message.Content
			continue
		}
		if message.Role == "user" {
			if userMessage != "" {
				userMessage += "\n"
			}
			userMessage += message.Content
			if systemMessage != "" {
				userMessage = systemMessage + "\n" + userMessage
				systemMessage = ""
			}
			continue
		}
		if message.Role == "assistant" {
			if assistantMessage != "" {
				assistantMessage += "\n"
			}
			assistantMessage += message.Content
			if outRequest.Input.Prompt != "" {
				outRequest.Input.Prompt += "\n"
			}
			if userMessage != "" {
				outRequest.Input.Prompt += fmt.Sprintf("<s> [INST] %s [/INST] %s </s>", userMessage, assistantMessage)
			} else {
				outRequest.Input.Prompt += fmt.Sprintf("<s> %s </s>", assistantMessage)
			}
			userMessage = ""
			assistantMessage = ""
		}
		// unknown role
		// NOTE(review): the assistant branch above does not `continue`, so
		// assistant messages also fall through and trigger this warning —
		// looks unintended; confirm.
		log.Println("[processReplicateRequest]: Warning: unknown role", message.Role)
	}
	// final user message: open an [INST] turn awaiting the model's reply
	if userMessage != "" {
		outRequest.Input.Prompt += fmt.Sprintf("<s> [INST] %s [/INST] ", userMessage)
		userMessage = ""
	}
	// final assistant message
	if assistantMessage != "" {
		outRequest.Input.Prompt += fmt.Sprintf("<s> %s </s>", assistantMessage)
	}
	log.Println("[processReplicateRequest]: outRequest.Input.Prompt:", outRequest.Input.Prompt)
	// send request
	outBody, err := json.Marshal(outRequest)
	if err != nil {
		c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
		return errors.New("[processReplicateRequest]: failed to marshal request body " + err.Error())
	}
	// http add headers
	// NOTE(review): err from http.NewRequest is not checked before req is
	// used — a malformed URL would nil-panic here; confirm and guard.
	req, err := http.NewRequest("POST", model_url, bytes.NewBuffer(outBody))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Token "+upstream.SK)
	// send
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
		return errors.New("[processReplicateRequest]: failed to post request " + err.Error())
	}
	// read response body
	// NOTE(review): resp.Body is never closed in this function (here and in
	// both branches below) — connection leak; confirm and add Close calls.
	outBody, err = io.ReadAll(resp.Body)
	if err != nil {
		c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
		return errors.New("[processReplicateRequest]: failed to read response body " + err.Error())
	}
	// parse response body
	outResponse := &ReplicateModelResponse{}
	err = json.Unmarshal(outBody, outResponse)
	if err != nil {
		c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
		return errors.New("[processReplicateRequest]: failed to parse response body " + err.Error())
	}
	if outResponse.Stream {
		// get result
		// NOTE(review): log label says URLS.Get but the value printed (and
		// fetched) is URLS.Stream.
		log.Println("[processReplicateRequest]: outResponse.URLS.Get:", outResponse.URLS.Stream)
		req, err := http.NewRequest("GET", outResponse.URLS.Stream, nil)
		if err != nil {
			c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
			return errors.New("[processReplicateRequest]: failed to create get request " + err.Error())
		}
		req.Header.Set("Authorization", "Token "+upstream.SK)
		req.Header.Set("Accept", "text/event-stream")
		// send
		resp, err := http.DefaultClient.Do(req)
		if err != nil {
			c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
			return errors.New("[processReplicateRequest]: failed to get request " + err.Error())
		}
		// get result by chunk: accumulate raw bytes in `buffer` and peel off
		// one SSE event (terminated by a blank line, i.e. "\n\n") at a time.
		var buffer string = ""
		var indexCount int64 = 0
		for {
			if !strings.Contains(buffer, "\n\n") {
				// receive chunk
				chunk := make([]byte, 1024)
				length, err := resp.Body.Read(chunk)
				if err == io.EOF {
					break
				}
				if length == 0 {
					break
				}
				if err != nil {
					c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
					return errors.New("[processReplicateRequest]: failed to read response body " + err.Error())
				}
				// add chunk to buffer; Trim strips the NUL padding of the
				// fixed 1024-byte read buffer (length is ignored — assumes
				// the payload itself never contains NUL bytes)
				chunk = bytes.Trim(chunk, "\x00")
				buffer += string(chunk)
				continue
			}
			// cut the first chunk by find index
			index := strings.Index(buffer, "\n\n")
			chunk := buffer[:index]
			buffer = buffer[index+2:]
			// trim line
			chunk = strings.Trim(chunk, "\n")
			// ignore hi: single-line events (presumably keep-alives) carry no
			// event/id/data triple and are skipped
			if !strings.Contains(chunk, "\n") {
				continue
			}
			// parse chunk to ReplicateModelResultChunk object
			chunkObj := &ReplicateModelResultChunk{}
			lines := strings.Split(chunk, "\n")
			// first line is event
			chunkObj.Event = strings.TrimSpace(lines[0])
			chunkObj.Event = strings.TrimPrefix(chunkObj.Event, "event: ")
			// second line is id
			chunkObj.ID = strings.TrimSpace(lines[1])
			chunkObj.ID = strings.TrimPrefix(chunkObj.ID, "id: ")
			// keep only the part before the first ':' of the id — presumably
			// Replicate ids look like "<prediction>:<seq>"; confirm
			chunkObj.ID = strings.SplitN(chunkObj.ID, ":", 2)[0]
			// third line is data
			chunkObj.Data = lines[2]
			chunkObj.Data = strings.TrimPrefix(chunkObj.Data, "data: ")
			record.Response += chunkObj.Data
			// done — note the done event's data was already appended above
			if chunkObj.Event == "done" {
				break
			}
			sendCORSHeaders(c)
			// create OpenAI response chunk
			c.SSEvent("", &OpenAIChatResponseChunk{
				ID:    "",
				Model: outResponse.Model,
				Choices: []OpenAIChatResponseChunkChoice{
					{
						Index: indexCount,
						Delta: OpenAIChatMessage{
							Role:    "assistant",
							Content: chunkObj.Data,
						},
					},
				},
			})
			c.Writer.Flush()
			indexCount += 1
		}
		// terminal chunk with an empty delta and finish_reason "stop"
		sendCORSHeaders(c)
		c.SSEvent("", &OpenAIChatResponseChunk{
			ID:    "",
			Model: outResponse.Model,
			Choices: []OpenAIChatResponseChunkChoice{
				{
					Index: indexCount,
					Delta: OpenAIChatMessage{
						Role:    "assistant",
						Content: "",
					},
					FinishReason: "stop",
				},
			},
		})
		c.Writer.Flush()
		indexCount += 1
		record.Status = 200
		return nil
	} else {
		var result *ReplicateModelResultGet
		// poll the prediction every 3s until it leaves processing/starting;
		// NOTE(review): no retry cap or overall deadline — a stuck prediction
		// loops forever; confirm whether an upper bound is wanted.
		for {
			// get result
			log.Println("[processReplicateRequest]: outResponse.URLS.Get:", outResponse.URLS.Get)
			req, err := http.NewRequest("GET", outResponse.URLS.Get, nil)
			if err != nil {
				c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
				return errors.New("[processReplicateRequest]: failed to create get request " + err.Error())
			}
			req.Header.Set("Authorization", "Token "+upstream.SK)
			// send
			resp, err := http.DefaultClient.Do(req)
			if err != nil {
				c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
				return errors.New("[processReplicateRequest]: failed to get request " + err.Error())
			}
			// get result
			resultBody, err := io.ReadAll(resp.Body)
			if err != nil {
				c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
				return errors.New("[processReplicateRequest]: failed to read response body " + err.Error())
			}
			// parse response body
			result = &ReplicateModelResultGet{}
			err = json.Unmarshal(resultBody, result)
			if err != nil {
				c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
				return errors.New("[processReplicateRequest]: failed to parse response body " + err.Error())
			}
			if result.Status == "processing" || result.Status == "starting" {
				time.Sleep(3 * time.Second)
				continue
			}
			break
		}
		// build openai response from the finished prediction
		openAIResult := &OpenAIChatResponse{
			ID:      result.ID,
			Model:   result.Model,
			Choices: []OpenAIChatResponseChoice{},
			Usage: OpenAIChatResponseUsage{
				TotalTokens:  result.Metrics.InputTokenCount + result.Metrics.OutputTokenCount,
				PromptTokens: result.Metrics.InputTokenCount,
			},
		}
		openAIResult.Choices = append(openAIResult.Choices, OpenAIChatResponseChoice{
			Index: 0,
			Message: OpenAIChatMessage{
				Role:    "assistant",
				Content: strings.Join(result.Output, ""),
			},
			FinishReason: "stop",
		})
		record.Response = strings.Join(result.Output, "")
		record.Status = 200
		// gin return
		sendCORSHeaders(c)
		c.JSON(200, openAIResult)
	}
	return nil
}

View File

@@ -1,13 +1,204 @@
package main
import (
"gorm.io/gorm"
"log"
"net/url"
"os"
"gopkg.in/yaml.v3"
)
// one openai upstream contain a pair of key and endpoint
type OPENAI_UPSTREAM struct {
gorm.Model
SK string `gorm:"index:idx_sk_endpoint,unique"` // key
Endpoint string `gorm:"index:idx_sk_endpoint,unique"` // endpoint
Timeout int64 // timeout in seconds
// Config is the top-level YAML configuration for the proxy server.
type Config struct {
	Address       string            `yaml:"address"`        // HTTP listen address; defaults to ":8888"
	Hostname      string            `yaml:"hostname"`       // identifier for this instance (recorded per request)
	DBType        string            `yaml:"dbtype"`         // database driver; defaults to "sqlite"
	DBAddr        string            `yaml:"dbaddr"`         // database DSN/path; defaults to "./db.sqlite"
	Authorization string            `yaml:"authorization"`  // global client token, inherited by upstreams without one
	Timeout       int64             `yaml:"timeout"`        // per-request timeout in seconds; defaults to 120
	StreamTimeout int64             `yaml:"stream_timeout"` // streaming timeout in seconds; defaults to 10
	Upstreams     []OPENAI_UPSTREAM `yaml:"upstreams"`      // candidate upstreams tried in order
}

// OPENAI_UPSTREAM describes one upstream endpoint a request can be proxied to.
type OPENAI_UPSTREAM struct {
	SK            string   `yaml:"sk"`             // upstream secret key; "asis" forwards the client's Authorization header
	Endpoint      string   `yaml:"endpoint"`       // upstream base URL
	Timeout       int64    `yaml:"timeout"`        // seconds; inherits Config.Timeout when 0
	StreamTimeout int64    `yaml:"stream_timeout"` // seconds; inherits Config.StreamTimeout when 0
	Allow         []string `yaml:"allow"`          // model allow-list; empty means allow all
	Deny          []string `yaml:"deny"`           // model deny-list
	Type          string   `yaml:"type"`           // "openai" (default) or "replicate"
	KeepHeader    bool     `yaml:"keep_header"`    // forward the client's request headers unchanged
	Authorization string   `yaml:"authorization"`  // per-upstream client token; inherits the global one unless Noauth
	Noauth        bool     `yaml:"noauth"`         // skip client authentication for this upstream
	URL           *url.URL // parsed Endpoint, populated by readConfig
}
// readConfig loads the YAML configuration at filepath, fills in defaults for
// unset global options, and normalizes every upstream entry (parsed endpoint
// URL, default type, inherited authorization and timeouts). Any unreadable or
// invalid configuration terminates the process via log.Fatalf.
func readConfig(filepath string) Config {
	var cfg Config
	// read yaml file
	raw, readErr := os.ReadFile(filepath)
	if readErr != nil {
		log.Fatalf("Error reading YAML file: %s", readErr)
	}
	// Unmarshal the YAML into the upstreams slice
	if parseErr := yaml.Unmarshal(raw, &cfg); parseErr != nil {
		log.Fatalf("Error unmarshaling YAML: %s", parseErr)
	}
	// fill global defaults for anything the file left unset
	if cfg.Address == "" {
		log.Println("Address not set, use default value: :8888")
		cfg.Address = ":8888"
	}
	if cfg.DBType == "" {
		log.Println("DBType not set, use default value: sqlite")
		cfg.DBType = "sqlite"
	}
	if cfg.DBAddr == "" {
		log.Println("DBAddr not set, use default value: ./db.sqlite")
		cfg.DBAddr = "./db.sqlite"
	}
	if cfg.Timeout == 0 {
		log.Println("Timeout not set, use default value: 120")
		cfg.Timeout = 120
	}
	if cfg.StreamTimeout == 0 {
		log.Println("StreamTimeout not set, use default value: 10")
		cfg.StreamTimeout = 10
	}
	// normalize each upstream in place
	for i := range cfg.Upstreams {
		up := &cfg.Upstreams[i]
		// parse upstream endpoint URL
		parsed, urlErr := url.Parse(up.Endpoint)
		if urlErr != nil {
			log.Fatalf("Can't parse upstream endpoint URL '%s': %s", up.Endpoint, urlErr)
		}
		up.URL = parsed
		if up.Type == "" {
			up.Type = "openai"
		}
		if up.Type != "openai" && up.Type != "replicate" {
			log.Fatalf("Unsupported upstream type '%s'", up.Type)
		}
		// apply authorization from global config if not set
		if up.Authorization == "" && !up.Noauth {
			up.Authorization = cfg.Authorization
		}
		if up.Timeout == 0 {
			up.Timeout = cfg.Timeout
		}
		if up.StreamTimeout == 0 {
			up.StreamTimeout = cfg.StreamTimeout
		}
	}
	return cfg
}
type OpenAIChatRequest struct {
FrequencyPenalty float64 `json:"frequency_penalty"`
PresencePenalty float64 `json:"presence_penalty"`
MaxTokens int64 `json:"max_tokens"`
Model string `json:"model"`
Stream bool `json:"stream"`
Temperature float64 `json:"temperature"`
Messages []OpenAIChatRequestMessage
}
type OpenAIChatRequestMessage struct {
Content string `json:"content"`
Role string `json:"role"`
}
// ReplicateModelRequest is the body POSTed to the Replicate predictions API.
type ReplicateModelRequest struct {
	Stream bool                       `json:"stream"` // request SSE streaming of the prediction output
	Input  ReplicateModelRequestInput `json:"input"`  // model input parameters
}

// ReplicateModelRequestInput carries the rendered prompt and sampling
// parameters for a Replicate prediction.
type ReplicateModelRequestInput struct {
	Prompt           string  `json:"prompt"`            // fully rendered instruction-format prompt
	MaxNewTokens     int64   `json:"max_new_tokens"`    // generation cap; default 1024 in this proxy
	Temperature      float64 `json:"temperature"`
	Top_p            float64 `json:"top_p"`
	Top_k            int64   `json:"top_k"`
	PresencePenalty  float64 `json:"presence_penalty"`
	FrequencyPenalty float64 `json:"frequency_penalty"`
	PromptTemplate   string  `json:"prompt_template"` // "{prompt}" passes Prompt through unchanged
}

// ReplicateModelResponse is the immediate reply to a prediction creation,
// carrying the polling/streaming URLs.
type ReplicateModelResponse struct {
	Model   string                     `json:"model"`
	Version string                     `json:"version"`
	Stream  bool                       `json:"stream"`
	Error   string                     `json:"error"`
	URLS    ReplicateModelResponseURLS `json:"urls"`
}

// ReplicateModelResponseURLS are the follow-up endpoints for a prediction.
type ReplicateModelResponseURLS struct {
	Cancel string `json:"cancel"` // cancel the prediction
	Get    string `json:"get"`    // poll the prediction state (fetch mode)
	Stream string `json:"stream"` // SSE stream of output (stream mode)
}

// ReplicateModelResultGet is the polled state of a prediction.
type ReplicateModelResultGet struct {
	ID      string                      `json:"id"`
	Model   string                      `json:"model"`
	Version string                      `json:"version"`
	Output  []string                    `json:"output"` // output fragments; joined with "" to form the reply
	Error   string                      `json:"error"`
	Metrics ReplicateModelResultMetrics `json:"metrics"`
	Status  string                      `json:"status"` // e.g. "starting", "processing"; polled until neither
}

// ReplicateModelResultMetrics holds token accounting for a prediction.
type ReplicateModelResultMetrics struct {
	InputTokenCount  int64 `json:"input_token_count"`
	OutputTokenCount int64 `json:"output_token_count"`
}

// OpenAIChatResponse is a non-streaming OpenAI-shaped chat completion
// returned to the client after a fetch-mode prediction finishes.
type OpenAIChatResponse struct {
	ID      string                     `json:"id"`
	Object  string                     `json:"object"`
	Created int64                      `json:"created"`
	Model   string                     `json:"model"`
	Choices []OpenAIChatResponseChoice `json:"choices"`
	Usage   OpenAIChatResponseUsage    `json:"usage"`
}

// OpenAIChatResponseUsage mirrors OpenAI token-usage accounting.
type OpenAIChatResponseUsage struct {
	PromptTokens     int64 `json:"prompt_tokens"`
	CompletionTokens int64 `json:"completion_tokens"`
	TotalTokens      int64 `json:"total_tokens"`
}

// OpenAIChatResponseChoice is one completion choice in a non-streaming reply.
type OpenAIChatResponseChoice struct {
	Index        int64             `json:"index"`
	Message      OpenAIChatMessage `json:"message"`
	FinishReason string            `json:"finish_reason"`
}

// OpenAIChatMessage is a role + content pair used in both full messages and
// streaming deltas.
type OpenAIChatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// ReplicateModelResultChunk is one parsed SSE event (event/id/data lines)
// from a Replicate stream.
type ReplicateModelResultChunk struct {
	Event string `json:"event"`
	ID    string `json:"id"`
	Data  string `json:"data"`
}

// OpenAIChatResponseChunk is one OpenAI-shaped SSE chunk relayed to the
// client in stream mode.
type OpenAIChatResponseChunk struct {
	ID      string                          `json:"id"`
	Object  string                          `json:"object"`
	Created int64                           `json:"created"`
	Model   string                          `json:"model"`
	Choices []OpenAIChatResponseChunkChoice `json:"choices"`
}

// OpenAIChatResponseChunkChoice is one delta choice inside a streaming chunk.
type OpenAIChatResponseChunkChoice struct {
	Index        int64             `json:"index"`
	Delta        OpenAIChatMessage `json:"delta"`
	FinishReason string            `json:"finish_reason"`
}