Compare commits
40 Commits
7d93332e51
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
495f32610b
|
|||
|
45bba95f5d
|
|||
|
75ff8fbc2e
|
|||
| 66758e0008 | |||
|
40fc2067a5
|
|||
|
1a56101ca8
|
|||
| e373e3ac63 | |||
| 24a2e609f8 | |||
| e442303847 | |||
| 8b95fbb5da | |||
| 34aa4babc4 | |||
| e6ff1f5ca4 | |||
|
6b6f245e45
|
|||
| 995eea9d67 | |||
|
db7f0eb316
|
|||
|
990628b455
|
|||
|
e8b89fc41a
|
|||
|
46ee30ced7
|
|||
|
f2e32340e3
|
|||
|
ca386f8302
|
|||
|
3385f9af08
|
|||
|
8fa7fa79be
|
|||
|
49169452fe
|
|||
|
33f341026f
|
|||
|
b1a9d6b685
|
|||
|
1fc17daa35
|
|||
|
b1ab9c3d7b
|
|||
|
8672899a58
|
|||
|
3a59433f66
|
|||
|
873548a7d0
|
|||
|
2a2d907b0d
|
|||
|
2bbe98e694
|
|||
|
9fdbf259c0
|
|||
|
97926087bb
|
|||
|
fc5a8d55fa
|
|||
|
b1e3a97aad
|
|||
|
04a2e4c12d
|
|||
|
b8ebbed5d6
|
|||
|
412aefdacc
|
|||
|
2c3532f12f
|
@@ -1,3 +1,4 @@
|
|||||||
openai-api-route
|
openai-api-route
|
||||||
db.sqlite
|
db.sqlite
|
||||||
/config.yaml
|
/config.yaml
|
||||||
|
/.*
|
||||||
|
|||||||
52
.gitlab-ci.yml
Normal file
52
.gitlab-ci.yml
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
# To contribute improvements to CI/CD templates, please follow the Development guide at:
|
||||||
|
# https://docs.gitlab.com/ee/development/cicd/templates.html
|
||||||
|
# This specific template is located at:
|
||||||
|
# https://gitlab.com/gitlab-org/gitlab/-/blob/master/lib/gitlab/ci/templates/Docker.gitlab-ci.yml
|
||||||
|
|
||||||
|
# Build a Docker image with CI/CD and push to the GitLab registry.
|
||||||
|
# Docker-in-Docker documentation: https://docs.gitlab.com/ee/ci/docker/using_docker_build.html
|
||||||
|
#
|
||||||
|
# This template uses one generic job with conditional builds
|
||||||
|
# for the default branch and all other (MR) branches.
|
||||||
|
|
||||||
|
docker-build:
|
||||||
|
# Use the official docker image.
|
||||||
|
image: docker:cli
|
||||||
|
stage: build
|
||||||
|
services:
|
||||||
|
- docker:dind
|
||||||
|
variables:
|
||||||
|
CI_REGISTRY: registry.waykey.net:7999
|
||||||
|
CI_REGISTRY_IMAGE: $CI_REGISTRY/spiderman/datamining/openai-api-route
|
||||||
|
DOCKER_IMAGE_NAME: $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_SLUG
|
||||||
|
before_script:
|
||||||
|
- docker login -u "$CI_REGISTRY_USER" -p "$CI_REGISTRY_PASSWORD" $CI_REGISTRY
|
||||||
|
# All branches are tagged with $DOCKER_IMAGE_NAME (defaults to commit ref slug)
|
||||||
|
# Default branch is also tagged with `latest`
|
||||||
|
script:
|
||||||
|
- docker build --pull -t "$DOCKER_IMAGE_NAME" .
|
||||||
|
- docker push "$DOCKER_IMAGE_NAME"
|
||||||
|
- |
|
||||||
|
if [[ "$CI_COMMIT_BRANCH" == "$CI_DEFAULT_BRANCH" ]]; then
|
||||||
|
docker tag "$DOCKER_IMAGE_NAME" "$CI_REGISTRY_IMAGE:latest"
|
||||||
|
docker push "$CI_REGISTRY_IMAGE:latest"
|
||||||
|
fi
|
||||||
|
# Run this job in a branch where a Dockerfile exists
|
||||||
|
rules:
|
||||||
|
- if: $CI_COMMIT_BRANCH
|
||||||
|
exists:
|
||||||
|
- Dockerfile
|
||||||
|
|
||||||
|
deploy:
|
||||||
|
environment: production
|
||||||
|
image: kroniak/ssh-client
|
||||||
|
stage: deploy
|
||||||
|
before_script:
|
||||||
|
- chmod 600 $CI_SSH_PRIVATE_KEY
|
||||||
|
script:
|
||||||
|
- ssh -o StrictHostKeyChecking=no -i $CI_SSH_PRIVATE_KEY root@192.168.1.13 "cd /mnt/data/srv/openai-api-route && podman-compose pull && podman-compose down && podman-compose up -d"
|
||||||
|
rules:
|
||||||
|
- if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH
|
||||||
|
exists:
|
||||||
|
- Dockerfile
|
||||||
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
FROM golang:1.21 as builder
|
FROM docker.io/golang:1.21 as builder
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
|
|||||||
200
README.md
200
README.md
@@ -7,15 +7,134 @@
|
|||||||
- 自定义 Authorization 验证头
|
- 自定义 Authorization 验证头
|
||||||
- 支持所有类型的接口 (`/v1/*`)
|
- 支持所有类型的接口 (`/v1/*`)
|
||||||
- 提供 Prometheus Metrics 统计接口 (`/v1/metrics`)
|
- 提供 Prometheus Metrics 统计接口 (`/v1/metrics`)
|
||||||
- 按照定义顺序请求 OpenAI 上游
|
- 按照定义顺序请求 OpenAI 上游,出错或超时自动按顺序尝试下一个
|
||||||
- 识别 ChatCompletions Stream 请求,针对 Stream 请求使用 5 秒超时。具体超时策略请参阅 [超时策略](#超时策略) 一节
|
- 识别 ChatCompletions Stream 请求,针对 Stream 请求使用更短的超时。具体超时策略请参阅 [超时策略](#超时策略) 一节
|
||||||
- 记录完整的请求内容、使用的上游、IP 地址、响应时间以及 GPT 回复文本
|
- 有选择地记录请求内容、请求头、使用的上游、IP 地址、响应时间以及响应等内容。具体记录策略请参阅 [记录策略](#记录策略) 一节
|
||||||
- 请求出错时发送 飞书 或 Matrix 消息通知
|
- 请求出错时发送 飞书 或 Matrix 平台的消息通知
|
||||||
|
- 支持 Replicate 平台上的 mistral 模型(beta)
|
||||||
|
|
||||||
本文档详细介绍了如何使用负载均衡和能力 API 的方法和端点。
|
本文档详细介绍了如何使用负载均衡和能力 API 的方法和端点。
|
||||||
|
|
||||||
|
## 配置文件
|
||||||
|
|
||||||
|
默认情况下程序会使用当前目录下的 `config.yaml` 文件,您可以通过使用 `-config your-config.yaml` 参数指定配置文件路径。
|
||||||
|
|
||||||
|
以下是一个配置文件示例,你可以在 `config.sample.yaml` 文件中找到同样的内容
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
authorization: woshimima
|
||||||
|
|
||||||
|
# 默认超时时间,默认 120 秒,流式请求是 10 秒
|
||||||
|
timeout: 120
|
||||||
|
stream_timeout: 10
|
||||||
|
|
||||||
|
# 使用 sqlite 作为数据库储存请求记录
|
||||||
|
dbtype: sqlite
|
||||||
|
dbaddr: ./db.sqlite
|
||||||
|
|
||||||
|
# 使用 postgres 作为数据库储存请求记录
|
||||||
|
# dbtype: postgres
|
||||||
|
# dbaddr: "host=127.0.0.1 port=5432 user=postgres dbname=openai_api_route sslmode=disable password=woshimima"
|
||||||
|
|
||||||
|
upstreams:
|
||||||
|
- sk: hahaha
|
||||||
|
endpoint: "https://localhost:8888/v1"
|
||||||
|
allow:
|
||||||
|
# whisper 等非 JSON API 识别不到 model,则使用 URL 路径作为模型名称
|
||||||
|
- /v1/audio/transcriptions
|
||||||
|
|
||||||
|
- sk: "secret_key_1"
|
||||||
|
endpoint: "https://api.openai.com/v2"
|
||||||
|
timeout: 120 # 请求超时时间,默认120秒
|
||||||
|
stream_timeout: 10 # 如果识别到 stream: true, 则使用该超时时间
|
||||||
|
allow: # 可选的模型白名单
|
||||||
|
- gpt-3.5-turbo
|
||||||
|
- gpt-3.5-turbo-0613
|
||||||
|
|
||||||
|
# 您可以设置很多个上游,程序将依次按顺序尝试
|
||||||
|
- sk: "secret_key_2"
|
||||||
|
endpoint: "https://api.openai.com/v1"
|
||||||
|
timeout: 30
|
||||||
|
deny:
|
||||||
|
- gpt-4
|
||||||
|
|
||||||
|
- sk: "key_for_replicate"
|
||||||
|
type: replicate
|
||||||
|
allow:
|
||||||
|
- mistralai/mixtral-8x7b-instruct-v0.1
|
||||||
|
```
|
||||||
|
|
||||||
|
### 配置多个验证头
|
||||||
|
|
||||||
|
您可以使用英文逗号 `,` 分割多个验证头。每个验证头都是有效的,程序会记录每个请求使用的验证头
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
authorization: woshimima,iampassword
|
||||||
|
```
|
||||||
|
|
||||||
|
您也可以为上游单独设置验证头
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
authorization: woshimima,iampassword
|
||||||
|
upstreams:
|
||||||
|
- sk: key
|
||||||
|
authorization: woshimima
|
||||||
|
```
|
||||||
|
|
||||||
|
如此,只有携带 `woshimima` 验证头的用户可以使用该上游。
|
||||||
|
|
||||||
|
### 复杂配置示例
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
|
||||||
|
# 默认验证头
|
||||||
|
authorization: woshimima
|
||||||
|
|
||||||
|
upstreams:
|
||||||
|
|
||||||
|
# 允许所有人使用的文字转语音
|
||||||
|
- sk: xxx
|
||||||
|
endpoint: http://localhost:5000/v1
|
||||||
|
noauth: true
|
||||||
|
allow:
|
||||||
|
- /v1/audio/transcriptions
|
||||||
|
|
||||||
|
# guest 专用的 gpt-3.5-turbo-0125 模型
|
||||||
|
- sk:
|
||||||
|
endpoint: https://api.xxx.local/v1
|
||||||
|
authorization: guest
|
||||||
|
allow:
|
||||||
|
- gpt-3.5-turbo-0125
|
||||||
|
```
|
||||||
|
|
||||||
## 部署方法
|
## 部署方法
|
||||||
|
|
||||||
|
有两种推荐的部署方法:
|
||||||
|
|
||||||
|
1. 使用预先构建好的容器 `docker.io/heimoshuiyu/openai-api-route:latest`
|
||||||
|
2. 自行编译
|
||||||
|
|
||||||
|
### 使用容器运行
|
||||||
|
|
||||||
|
> 注意,如果您使用 sqlite 数据库,您可能还需要修改配置文件以将 SQLite 数据库文件放置在数据卷中。
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker run -d --name openai-api-route -v /path/to/config.yaml:/config.yaml docker.io/heimoshuiyu/openai-api-route:latest
|
||||||
|
```
|
||||||
|
|
||||||
|
使用 Docker Compose
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
version: '3'
|
||||||
|
services:
|
||||||
|
openai-api-route:
|
||||||
|
image: docker.io/heimoshuiyu/openai-api-route:latest
|
||||||
|
ports:
|
||||||
|
- 8888:8888
|
||||||
|
volumes:
|
||||||
|
- ./config.yaml:/config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
### 编译
|
### 编译
|
||||||
|
|
||||||
以下是编译和运行该负载均衡 API 的步骤:
|
以下是编译和运行该负载均衡 API 的步骤:
|
||||||
@@ -40,74 +159,11 @@
|
|||||||
./openai-api-route
|
./openai-api-route
|
||||||
```
|
```
|
||||||
|
|
||||||
默认情况下,API 将会在本地的 8888 端口进行监听。
|
## 模型允许与屏蔽列表
|
||||||
|
|
||||||
如果您希望使用不同的监听地址,可以使用 `-addr` 参数来指定,例如:
|
如果对某个上游设置了 allow 或 deny 列表,则负载均衡只允许或禁用用户使用这些模型。负载均衡程序会先判断白名单,再判断黑名单。
|
||||||
|
|
||||||
```
|
如果你混合使用 OpenAI 和 Replicate 平台的模型,你可能需要分别为 OpenAI 和 Replicate 上游设置他们各自的允许列表,否则用户请求 OpenAI 的模型时可能会发送到 Replicate 平台
|
||||||
./openai-api-route -addr 0.0.0.0:8080
|
|
||||||
```
|
|
||||||
|
|
||||||
这将会将监听地址设置为 0.0.0.0:8080。
|
|
||||||
|
|
||||||
6. 如果数据库不存在,系统会自动创建一个名为 `db.sqlite` 的数据库文件。
|
|
||||||
|
|
||||||
如果您希望使用不同的数据库地址,可以使用 `-database` 参数来指定,例如:
|
|
||||||
|
|
||||||
```
|
|
||||||
./openai-api-route -database /path/to/database.db
|
|
||||||
```
|
|
||||||
|
|
||||||
这将会将数据库地址设置为 `/path/to/database.db`。
|
|
||||||
|
|
||||||
7. 现在,您已经成功编译并运行了负载均衡和能力 API。您可以根据需要添加上游、管理上游,并使用 API 进行相关操作。
|
|
||||||
|
|
||||||
### 运行
|
|
||||||
|
|
||||||
以下是运行命令的用法:
|
|
||||||
|
|
||||||
```
|
|
||||||
Usage of ./openai-api-route:
|
|
||||||
-addr string
|
|
||||||
监听地址(默认为 ":8888")
|
|
||||||
-upstreams string
|
|
||||||
上游配置文件(默认为 "./upstreams.yaml")
|
|
||||||
-dbtype
|
|
||||||
数据库类型 (sqlite 或 postgres,默认为 sqlite)
|
|
||||||
-database string
|
|
||||||
数据库地址(默认为 "./db.sqlite")
|
|
||||||
如果数据库为 postgres ,则此值应 PostgreSQL DSN 格式
|
|
||||||
例如 "host=127.0.0.1 port=5432 user=postgres dbname=openai_api_route sslmode=disable password=woshimima"
|
|
||||||
-list
|
|
||||||
列出所有上游
|
|
||||||
-noauth
|
|
||||||
不检查传入的授权头
|
|
||||||
```
|
|
||||||
|
|
||||||
以下是一个 `./upstreams.yaml` 文件配置示例
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
authorization: woshimima
|
|
||||||
|
|
||||||
# 使用 sqlite 作为数据库储存请求记录
|
|
||||||
dbtype: sqlite
|
|
||||||
dbaddr: ./db.sqlite
|
|
||||||
|
|
||||||
# 使用 postgres 作为数据库储存请求记录
|
|
||||||
# dbtype: postgres
|
|
||||||
# dbaddr: "host=127.0.0.1 port=5432 user=postgres dbname=openai_api_route sslmode=disable password=woshimima"
|
|
||||||
|
|
||||||
upstreams:
|
|
||||||
- sk: "secret_key_1"
|
|
||||||
endpoint: "https://api.openai.com/v2"
|
|
||||||
- sk: "secret_key_2"
|
|
||||||
endpoint: "https://api.openai.com/v1"
|
|
||||||
timeout: 30
|
|
||||||
```
|
|
||||||
|
|
||||||
请注意,程序会根据情况修改 timeout 的值
|
|
||||||
|
|
||||||
您可以直接运行 `./openai-api-route` 命令,如果数据库不存在,系统会自动创建。
|
|
||||||
|
|
||||||
## 超时策略
|
## 超时策略
|
||||||
|
|
||||||
@@ -128,7 +184,3 @@ upstreams:
|
|||||||
1. **默认超时时间**:如果没有特殊条件,服务将使用默认的超时时间,即 60 秒。
|
1. **默认超时时间**:如果没有特殊条件,服务将使用默认的超时时间,即 60 秒。
|
||||||
|
|
||||||
2. **流式请求**:如果请求体被识别为流式(`requestBody.Stream` 为 `true`),并且请求体检查(`requestBodyOK`)没有发现问题,超时时间将被设置为 5 秒。这适用于那些预期会快速响应的流式请求。
|
2. **流式请求**:如果请求体被识别为流式(`requestBody.Stream` 为 `true`),并且请求体检查(`requestBodyOK`)没有发现问题,超时时间将被设置为 5 秒。这适用于那些预期会快速响应的流式请求。
|
||||||
|
|
||||||
3. **大请求体**:如果请求体的大小超过 128KB(即 `len(inBody) > 1024*128`),超时时间将被设置为 20 秒。这考虑到了处理大型数据可能需要更长的时间。
|
|
||||||
|
|
||||||
4. **上游超时配置**:如果上游服务器在配置中指定了超时时间(`upstream.Timeout` 大于 0),服务将使用该值作为超时时间。这个值是以秒为单位的。
|
|
||||||
|
|||||||
28
auth.go
28
auth.go
@@ -2,32 +2,14 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"log"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func handleAuth(c *gin.Context) error {
|
func checkAuth(authorization string, config string) error {
|
||||||
var err error
|
for _, auth := range strings.Split(config, ",") {
|
||||||
|
if authorization == strings.Trim(auth, " ") {
|
||||||
authorization := c.Request.Header.Get("Authorization")
|
return nil
|
||||||
if !strings.HasPrefix(authorization, "Bearer") {
|
|
||||||
err = errors.New("authorization header should start with 'Bearer'")
|
|
||||||
c.AbortWithError(403, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
authorization = strings.Trim(authorization[len("Bearer"):], " ")
|
|
||||||
log.Println("Received authorization", authorization)
|
|
||||||
|
|
||||||
for _, auth := range strings.Split(config.Authorization, ",") {
|
|
||||||
if authorization != strings.Trim(auth, " ") {
|
|
||||||
err = errors.New("wrong authorization header")
|
|
||||||
c.AbortWithError(403, err)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return errors.New("wrong authorization header")
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,13 @@ dbaddr: ./db.sqlite
|
|||||||
upstreams:
|
upstreams:
|
||||||
- sk: "secret_key_1"
|
- sk: "secret_key_1"
|
||||||
endpoint: "https://api.openai.com/v2"
|
endpoint: "https://api.openai.com/v2"
|
||||||
|
allow: ["gpt-3.5-turbo"] # 可选的模型白名单
|
||||||
- sk: "secret_key_2"
|
- sk: "secret_key_2"
|
||||||
endpoint: "https://api.openai.com/v1"
|
endpoint: "https://api.openai.com/v1"
|
||||||
timeout: 30
|
timeout: 30
|
||||||
|
allow: ["gpt-3.5-turbo"] # 可选的模型白名单
|
||||||
|
deny: ["gpt-4"] # 可选的模型黑名单
|
||||||
|
# 若白名单和黑名单同时设置,先判断白名单,再判断黑名单
|
||||||
|
- sk: "key_for_replicate"
|
||||||
|
type: replicate
|
||||||
|
allow: ["mistralai/mixtral-8x7b-instruct-v0.1"]
|
||||||
|
|||||||
20
cors.go
Normal file
20
cors.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func sendCORSHeaders(c *gin.Context) {
|
||||||
|
log.Println("sendCORSHeaders")
|
||||||
|
if c.Writer.Header().Get("Access-Control-Allow-Origin") == "" {
|
||||||
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
|
}
|
||||||
|
if c.Writer.Header().Get("Access-Control-Allow-Methods") == "" {
|
||||||
|
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
|
||||||
|
}
|
||||||
|
if c.Writer.Header().Get("Access-Control-Allow-Headers") == "" {
|
||||||
|
c.Header("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
|
||||||
|
}
|
||||||
|
}
|
||||||
173
main.go
173
main.go
@@ -1,15 +1,21 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/penglongli/gin-metrics/ginmetrics"
|
"github.com/penglongli/gin-metrics/ginmetrics"
|
||||||
"gorm.io/driver/postgres"
|
"gorm.io/driver/postgres"
|
||||||
|
|
||||||
"gorm.io/driver/sqlite"
|
"gorm.io/driver/sqlite"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
@@ -23,11 +29,11 @@ func main() {
|
|||||||
noauth := flag.Bool("noauth", false, "Do not check incoming authorization header")
|
noauth := flag.Bool("noauth", false, "Do not check incoming authorization header")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
log.Println("Service starting")
|
log.Println("[main]: Service starting")
|
||||||
|
|
||||||
// load all upstreams
|
// load all upstreams
|
||||||
config = readConfig(*configFile)
|
config = readConfig(*configFile)
|
||||||
log.Println("Load upstreams number:", len(config.Upstreams))
|
log.Println("[main]: Load upstreams number:", len(config.Upstreams))
|
||||||
|
|
||||||
// connect to database
|
// connect to database
|
||||||
var db *gorm.DB
|
var db *gorm.DB
|
||||||
@@ -38,17 +44,23 @@ func main() {
|
|||||||
PrepareStmt: true,
|
PrepareStmt: true,
|
||||||
SkipDefaultTransaction: true,
|
SkipDefaultTransaction: true,
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("[main]: Error to connect sqlite database: %s", err)
|
||||||
|
}
|
||||||
case "postgres":
|
case "postgres":
|
||||||
db, err = gorm.Open(postgres.Open(config.DBAddr), &gorm.Config{
|
db, err = gorm.Open(postgres.Open(config.DBAddr), &gorm.Config{
|
||||||
PrepareStmt: true,
|
PrepareStmt: true,
|
||||||
SkipDefaultTransaction: true,
|
SkipDefaultTransaction: true,
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("[main]: Error to connect postgres database: %s", err)
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
log.Fatalf("Unsupported database type: '%s'", config.DBType)
|
log.Fatalf("[main]: Unsupported database type: '%s'", config.DBType)
|
||||||
}
|
}
|
||||||
|
|
||||||
db.AutoMigrate(&Record{})
|
db.AutoMigrate(&Record{})
|
||||||
log.Println("Auto migrate database done")
|
log.Println("[main]: Auto migrate database done")
|
||||||
|
|
||||||
if *listMode {
|
if *listMode {
|
||||||
fmt.Println("SK\tEndpoint")
|
fmt.Println("SK\tEndpoint")
|
||||||
@@ -66,6 +78,9 @@ func main() {
|
|||||||
m.SetMetricPath("/v1/metrics")
|
m.SetMetricPath("/v1/metrics")
|
||||||
m.Use(engine)
|
m.Use(engine)
|
||||||
|
|
||||||
|
// CORS middleware
|
||||||
|
// engine.Use(corsMiddleware())
|
||||||
|
|
||||||
// error handle middleware
|
// error handle middleware
|
||||||
engine.Use(func(c *gin.Context) {
|
engine.Use(func(c *gin.Context) {
|
||||||
c.Next()
|
c.Next()
|
||||||
@@ -74,68 +89,154 @@ func main() {
|
|||||||
}
|
}
|
||||||
errText := strings.Join(c.Errors.Errors(), "\n")
|
errText := strings.Join(c.Errors.Errors(), "\n")
|
||||||
c.JSON(-1, gin.H{
|
c.JSON(-1, gin.H{
|
||||||
"error": errText,
|
"openai-api-route error": errText,
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
// CORS handler
|
// CORS handler
|
||||||
engine.OPTIONS("/v1/*any", func(ctx *gin.Context) {
|
engine.OPTIONS("/v1/*any", func(ctx *gin.Context) {
|
||||||
header := ctx.Writer.Header()
|
// set CORS headers
|
||||||
header.Set("Access-Control-Allow-Origin", "*")
|
ctx.Header("Access-Control-Allow-Origin", "*")
|
||||||
header.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
|
ctx.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
|
||||||
header.Set("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
|
ctx.Header("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
|
||||||
ctx.AbortWithStatus(200)
|
ctx.AbortWithStatus(200)
|
||||||
})
|
})
|
||||||
|
|
||||||
engine.POST("/v1/*any", func(c *gin.Context) {
|
engine.POST("/v1/*any", func(c *gin.Context) {
|
||||||
|
var err error
|
||||||
|
hostname, _ := os.Hostname()
|
||||||
|
if config.Hostname != "" {
|
||||||
|
hostname = config.Hostname
|
||||||
|
}
|
||||||
record := Record{
|
record := Record{
|
||||||
IP: c.ClientIP(),
|
IP: c.ClientIP(),
|
||||||
|
Hostname: hostname,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
Authorization: c.Request.Header.Get("Authorization"),
|
Authorization: c.Request.Header.Get("Authorization"),
|
||||||
UserAgent: c.Request.Header.Get("User-Agent"),
|
UserAgent: c.Request.Header.Get("User-Agent"),
|
||||||
}
|
Model: c.Request.URL.Path,
|
||||||
defer func() {
|
|
||||||
if err := recover(); err != nil {
|
|
||||||
log.Println("Error:", err)
|
|
||||||
c.AbortWithError(500, fmt.Errorf("%s", err))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// check authorization header
|
|
||||||
if !*noauth {
|
|
||||||
if handleAuth(c) != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for index, upstream := range config.Upstreams {
|
authorization := c.Request.Header.Get("Authorization")
|
||||||
if upstream.Endpoint == "" || upstream.SK == "" {
|
if strings.HasPrefix(authorization, "Bearer") {
|
||||||
c.AbortWithError(500, fmt.Errorf("invaild upstream '%s' '%s'", upstream.SK, upstream.Endpoint))
|
authorization = strings.Trim(authorization[len("Bearer"):], " ")
|
||||||
|
} else {
|
||||||
|
authorization = strings.Trim(authorization, " ")
|
||||||
|
log.Println("[auth] Warning: authorization header should start with 'Bearer'")
|
||||||
|
}
|
||||||
|
log.Println("Received authorization '" + authorization + "'")
|
||||||
|
|
||||||
|
availUpstreams := make([]OPENAI_UPSTREAM, 0)
|
||||||
|
for _, upstream := range config.Upstreams {
|
||||||
|
if upstream.SK == "" {
|
||||||
|
sendCORSHeaders(c)
|
||||||
|
c.AbortWithError(500, fmt.Errorf("[processRequest.begin]: invaild SK (secret key) %s", upstream.SK))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !*noauth && !upstream.Noauth {
|
||||||
|
if checkAuth(authorization, upstream.Authorization) != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
availUpstreams = append(availUpstreams, upstream)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(availUpstreams) == 0 {
|
||||||
|
sendCORSHeaders(c)
|
||||||
|
c.AbortWithError(500, fmt.Errorf("[processRequest.begin]: no available upstream for your token"))
|
||||||
|
}
|
||||||
|
log.Println("[processRequest.begin]: availUpstreams", len(availUpstreams))
|
||||||
|
|
||||||
|
bufIO := bytes.NewBuffer(make([]byte, 0, 1024))
|
||||||
|
wrapedBody := false
|
||||||
|
|
||||||
|
for index, _upstream := range availUpstreams {
|
||||||
|
|
||||||
|
// copy
|
||||||
|
upstream := _upstream
|
||||||
|
record.UpstreamEndpoint = upstream.Endpoint
|
||||||
|
record.UpstreamSK = upstream.SK
|
||||||
|
|
||||||
shouldResponse := index == len(config.Upstreams)-1
|
shouldResponse := index == len(config.Upstreams)-1
|
||||||
|
|
||||||
if len(config.Upstreams) == 1 {
|
if len(availUpstreams) == 1 {
|
||||||
|
// [todo] copy problem
|
||||||
upstream.Timeout = 120
|
upstream.Timeout = 120
|
||||||
}
|
}
|
||||||
|
|
||||||
err = processRequest(c, &upstream, &record, shouldResponse)
|
// buffer for incoming request
|
||||||
if err != nil {
|
if !wrapedBody {
|
||||||
log.Println("Error from upstream", upstream.Endpoint, "should retry", err)
|
log.Println("[processRequest.begin]: wrap request body")
|
||||||
continue
|
c.Request.Body = io.NopCloser(io.TeeReader(c.Request.Body, bufIO))
|
||||||
|
wrapedBody = true
|
||||||
|
} else {
|
||||||
|
log.Println("[processRequest.begin]: reuse request body")
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewReader(bufIO.Bytes()))
|
||||||
}
|
}
|
||||||
|
|
||||||
break
|
if upstream.Type == "replicate" {
|
||||||
|
err = processReplicateRequest(c, &upstream, &record, shouldResponse)
|
||||||
|
} else if upstream.Type == "openai" {
|
||||||
|
err = processRequest(c, &upstream, &record, shouldResponse)
|
||||||
|
} else {
|
||||||
|
err = fmt.Errorf("[processRequest.begin]: unsupported upstream type '%s'", upstream.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
log.Println("[processRequest.done]: Success from upstream", upstream.Endpoint)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == http.ErrAbortHandler {
|
||||||
|
abortErr := "[processRequest.done]: AbortHandler, client's connection lost?, no upstream will try, stop here"
|
||||||
|
log.Println(abortErr)
|
||||||
|
record.Response += abortErr
|
||||||
|
record.Status = 500
|
||||||
|
break
|
||||||
|
}
|
||||||
|
log.Println("[processRequest.done]: Error from upstream", upstream.Endpoint, "should retry", err, "should response:", shouldResponse)
|
||||||
|
|
||||||
|
// error process, break
|
||||||
|
if shouldResponse {
|
||||||
|
c.Header("Content-Type", "application/json")
|
||||||
|
sendCORSHeaders(c)
|
||||||
|
c.AbortWithError(500, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Println("Record result:", record.Status, record.Response)
|
// parse and record request body
|
||||||
record.ElapsedTime = time.Now().Sub(record.CreatedAt)
|
requestBodyBytes := bufIO.Bytes()
|
||||||
if db.Create(&record).Error != nil {
|
if len(requestBodyBytes) < 1024*1024 && (strings.HasPrefix(c.Request.Header.Get("Content-Type"), "application/json") ||
|
||||||
log.Println("Error to save record:", record)
|
strings.HasPrefix(c.Request.Header.Get("Content-Type"), "text/")) {
|
||||||
|
record.Body = string(requestBodyBytes)
|
||||||
}
|
}
|
||||||
|
requestBody, err := ParseRequestBody(requestBodyBytes)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("[processRequest.done]: Error to parse request body:", err)
|
||||||
|
} else {
|
||||||
|
record.Model = requestBody.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("[final]: Record result:", record.Status, record.Response)
|
||||||
|
record.ElapsedTime = time.Since(record.CreatedAt)
|
||||||
|
|
||||||
|
// async record request
|
||||||
|
go func() {
|
||||||
|
// encode the request headers into record.Headers as a JSON string
|
||||||
|
headers, _ := json.Marshal(c.Request.Header)
|
||||||
|
record.Headers = string(headers)
|
||||||
|
|
||||||
|
// truncate request if too long
|
||||||
|
log.Println("[async.record]: body length:", len(record.Body))
|
||||||
|
if db.Create(&record).Error != nil {
|
||||||
|
log.Println("[async.record]: Error to save record:", record)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if record.Status != 200 {
|
if record.Status != 200 {
|
||||||
errMessage := fmt.Sprintf("IP: %s request %s error %d with %s", record.IP, record.Model, record.Status, record.Response)
|
errMessage := fmt.Sprintf("[result.error]: IP: %s request %s error %d with %s", record.IP, record.Model, record.Status, record.Response)
|
||||||
go sendFeishuMessage(errMessage)
|
go sendFeishuMessage(errMessage)
|
||||||
go sendMatrixMessage(errMessage)
|
go sendMatrixMessage(errMessage)
|
||||||
}
|
}
|
||||||
|
|||||||
282
process.go
282
process.go
@@ -8,216 +8,118 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"golang.org/x/net/context"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
|
func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
|
||||||
var errCtx error
|
|
||||||
|
|
||||||
record.UpstreamEndpoint = upstream.Endpoint
|
|
||||||
record.UpstreamSK = upstream.SK
|
|
||||||
record.Response = ""
|
|
||||||
// [TODO] record request body
|
|
||||||
|
|
||||||
// reverse proxy
|
// reverse proxy
|
||||||
remote, err := url.Parse(upstream.Endpoint)
|
remote, err := url.Parse(upstream.Endpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithError(500, errors.New("can't parse reverse proxy remote URL"))
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
haveResponse := false
|
path := strings.TrimPrefix(c.Request.URL.Path, "/v1")
|
||||||
|
remote.Path = upstream.URL.Path + path
|
||||||
|
log.Println("[proxy.begin]:", remote)
|
||||||
|
log.Println("[proxy.begin]: shouldResposne:", shouldResponse)
|
||||||
|
|
||||||
proxy := httputil.NewSingleHostReverseProxy(remote)
|
client := &http.Client{}
|
||||||
proxy.Director = nil
|
request := &http.Request{}
|
||||||
var inBody []byte
|
request.ContentLength = c.Request.ContentLength
|
||||||
proxy.Rewrite = func(proxyRequest *httputil.ProxyRequest) {
|
request.Method = c.Request.Method
|
||||||
in := proxyRequest.In
|
request.URL = remote
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
// process header
|
||||||
proxyRequest.Out = proxyRequest.Out.WithContext(ctx)
|
if upstream.KeepHeader {
|
||||||
|
request.Header = c.Request.Header
|
||||||
out := proxyRequest.Out
|
|
||||||
|
|
||||||
// read request body
|
|
||||||
inBody, err = io.ReadAll(in.Body)
|
|
||||||
if err != nil {
|
|
||||||
c.AbortWithError(502, errors.New("reverse proxy middleware failed to read request body "+err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// record chat message from user
|
|
||||||
record.Body = string(inBody)
|
|
||||||
requestBody, requestBodyOK := ParseRequestBody(inBody)
|
|
||||||
// record if parse success
|
|
||||||
if requestBodyOK == nil {
|
|
||||||
record.Model = requestBody.Model
|
|
||||||
}
|
|
||||||
|
|
||||||
// set timeout, default is 60 second
|
|
||||||
timeout := 60 * time.Second
|
|
||||||
if requestBodyOK == nil && requestBody.Stream {
|
|
||||||
timeout = 5 * time.Second
|
|
||||||
}
|
|
||||||
if len(inBody) > 1024*128 {
|
|
||||||
timeout = 20 * time.Second
|
|
||||||
}
|
|
||||||
if upstream.Timeout > 0 {
|
|
||||||
// convert upstream.Timeout(second) to nanosecond
|
|
||||||
timeout = time.Duration(upstream.Timeout) * time.Second
|
|
||||||
}
|
|
||||||
|
|
||||||
// timeout out request
|
|
||||||
go func() {
|
|
||||||
time.Sleep(timeout)
|
|
||||||
if !haveResponse {
|
|
||||||
log.Println("Timeout upstream", upstream.Endpoint)
|
|
||||||
errCtx = errors.New("timeout")
|
|
||||||
if shouldResponse {
|
|
||||||
c.AbortWithError(502, errCtx)
|
|
||||||
}
|
|
||||||
cancel()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
out.Body = io.NopCloser(bytes.NewReader(inBody))
|
|
||||||
|
|
||||||
out.Host = remote.Host
|
|
||||||
out.URL.Scheme = remote.Scheme
|
|
||||||
out.URL.Host = remote.Host
|
|
||||||
out.URL.Path = in.URL.Path
|
|
||||||
out.Header = http.Header{}
|
|
||||||
out.Header.Set("Host", remote.Host)
|
|
||||||
if upstream.SK == "asis" {
|
|
||||||
out.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
|
||||||
} else {
|
|
||||||
out.Header.Set("Authorization", "Bearer "+upstream.SK)
|
|
||||||
}
|
|
||||||
out.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
|
||||||
}
|
|
||||||
var buf bytes.Buffer
|
|
||||||
var contentType string
|
|
||||||
proxy.ModifyResponse = func(r *http.Response) error {
|
|
||||||
haveResponse = true
|
|
||||||
record.Status = r.StatusCode
|
|
||||||
if !shouldResponse && r.StatusCode != 200 {
|
|
||||||
log.Println("upstream return not 200 and should not response", r.StatusCode)
|
|
||||||
return errors.New("upstream return not 200 and should not response")
|
|
||||||
}
|
|
||||||
r.Header.Del("Access-Control-Allow-Origin")
|
|
||||||
r.Header.Del("Access-Control-Allow-Methods")
|
|
||||||
r.Header.Del("Access-Control-Allow-Headers")
|
|
||||||
r.Header.Set("Access-Control-Allow-Origin", "*")
|
|
||||||
r.Header.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
|
|
||||||
r.Header.Set("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
|
|
||||||
|
|
||||||
if r.StatusCode != 200 {
|
|
||||||
body, err := io.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
record.Response = "failed to read response from upstream " + err.Error()
|
|
||||||
return errors.New(record.Response)
|
|
||||||
}
|
|
||||||
record.Response = fmt.Sprintf("openai-api-route upstream return '%s' with '%s'", r.Status, string(body))
|
|
||||||
record.Status = r.StatusCode
|
|
||||||
return fmt.Errorf(record.Response)
|
|
||||||
}
|
|
||||||
// count success
|
|
||||||
r.Body = io.NopCloser(io.TeeReader(r.Body, &buf))
|
|
||||||
contentType = r.Header.Get("content-type")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
|
||||||
haveResponse = true
|
|
||||||
record.ResponseTime = time.Now().Sub(record.CreatedAt)
|
|
||||||
log.Println("Error", err, upstream.SK, upstream.Endpoint)
|
|
||||||
|
|
||||||
errCtx = err
|
|
||||||
|
|
||||||
// abort to error handle
|
|
||||||
if shouldResponse {
|
|
||||||
c.AbortWithError(502, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Println("response is", r.Response)
|
|
||||||
|
|
||||||
if record.Status == 0 {
|
|
||||||
record.Status = 502
|
|
||||||
}
|
|
||||||
if record.Response == "" {
|
|
||||||
record.Response = err.Error()
|
|
||||||
}
|
|
||||||
if r.Response != nil {
|
|
||||||
record.Status = r.Response.StatusCode
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
proxy.ServeHTTP(c.Writer, c.Request)
|
|
||||||
|
|
||||||
// return context error
|
|
||||||
if errCtx != nil {
|
|
||||||
// fix inrequest body
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(inBody))
|
|
||||||
return errCtx
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := io.ReadAll(io.NopCloser(&buf))
|
|
||||||
if err != nil {
|
|
||||||
record.Response = "failed to read response from upstream " + err.Error()
|
|
||||||
log.Println(record.Response)
|
|
||||||
} else {
|
} else {
|
||||||
|
request.Header = http.Header{}
|
||||||
// record response
|
|
||||||
// stream mode
|
|
||||||
if strings.HasPrefix(contentType, "text/event-stream") {
|
|
||||||
for _, line := range strings.Split(string(resp), "\n") {
|
|
||||||
chunk := StreamModeChunk{}
|
|
||||||
line = strings.TrimPrefix(line, "data:")
|
|
||||||
line = strings.TrimSpace(line)
|
|
||||||
if line == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
err := json.Unmarshal([]byte(line), &chunk)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(chunk.Choices) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
record.Response += chunk.Choices[0].Delta.Content
|
|
||||||
}
|
|
||||||
} else if strings.HasPrefix(contentType, "application/json") {
|
|
||||||
var fetchResp FetchModeResponse
|
|
||||||
err := json.Unmarshal(resp, &fetchResp)
|
|
||||||
if err != nil {
|
|
||||||
log.Println("Error parsing fetch response:", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if !strings.HasPrefix(fetchResp.Model, "gpt-") {
|
|
||||||
log.Println("Not GPT model, skip recording response:", fetchResp.Model)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if len(fetchResp.Choices) == 0 {
|
|
||||||
log.Println("Error: fetch response choice length is 0")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
record.Response = fetchResp.Choices[0].Message.Content
|
|
||||||
} else {
|
|
||||||
log.Println("Unknown content type", contentType)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(record.Body) > 1024*512 {
|
// process header authorization
|
||||||
record.Body = ""
|
if upstream.SK == "asis" {
|
||||||
|
request.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
||||||
|
} else {
|
||||||
|
request.Header.Set("Authorization", "Bearer "+upstream.SK)
|
||||||
|
}
|
||||||
|
request.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||||
|
request.Header.Set("Host", remote.Host)
|
||||||
|
request.Header.Set("Content-Length", c.Request.Header.Get("Content-Length"))
|
||||||
|
|
||||||
|
request.Body = c.Request.Body
|
||||||
|
|
||||||
|
resp, err := client.Do(request)
|
||||||
|
if err != nil {
|
||||||
|
body := []byte{}
|
||||||
|
if resp != nil && resp.Body != nil {
|
||||||
|
body, _ = io.ReadAll(resp.Body)
|
||||||
|
}
|
||||||
|
return errors.New(err.Error() + " " + string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
record.Status = resp.StatusCode
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
record.Status = resp.StatusCode
|
||||||
|
errRet := fmt.Errorf("[error]: openai-api-route upstream return '%s' with '%s'", resp.Status, string(body))
|
||||||
|
log.Println(errRet)
|
||||||
|
return errRet
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy response header
|
||||||
|
for k, v := range resp.Header {
|
||||||
|
c.Header(k, v[0])
|
||||||
|
}
|
||||||
|
sendCORSHeaders(c)
|
||||||
|
|
||||||
|
respBodyBuffer := bytes.NewBuffer(make([]byte, 0, 4*1024))
|
||||||
|
respBodyTeeReader := io.TeeReader(resp.Body, respBodyBuffer)
|
||||||
|
record.ResponseTime = time.Since(record.CreatedAt)
|
||||||
|
io.Copy(c.Writer, respBodyTeeReader)
|
||||||
|
record.ElapsedTime = time.Since(record.CreatedAt)
|
||||||
|
|
||||||
|
// parse and record response
|
||||||
|
if strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") {
|
||||||
|
var fetchResp FetchModeResponse
|
||||||
|
err := json.NewDecoder(respBodyBuffer).Decode(&fetchResp)
|
||||||
|
if err == nil {
|
||||||
|
if len(fetchResp.Choices) > 0 {
|
||||||
|
record.Response = fetchResp.Choices[0].Message.Content
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") {
|
||||||
|
lines := bytes.Split(respBodyBuffer.Bytes(), []byte("\n"))
|
||||||
|
for _, line := range lines {
|
||||||
|
line = bytes.TrimSpace(line)
|
||||||
|
line = bytes.TrimPrefix(line, []byte("data:"))
|
||||||
|
line = bytes.TrimSpace(line)
|
||||||
|
if len(line) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
chunk := StreamModeChunk{}
|
||||||
|
err = json.Unmarshal(line, &chunk)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("[proxy.parseChunkError]:", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if len(chunk.Choices) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
record.Response += chunk.Choices[0].Delta.Content
|
||||||
|
}
|
||||||
|
} else if strings.HasPrefix(resp.Header.Get("Content-Type"), "text") {
|
||||||
|
body, _ := io.ReadAll(respBodyBuffer)
|
||||||
|
record.Response = string(body)
|
||||||
|
} else {
|
||||||
|
log.Println("[proxy.record]: Unknown content type", resp.Header.Get("Content-Type"))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
type Record struct {
|
type Record struct {
|
||||||
ID int64 `gorm:"primaryKey,autoIncrement"`
|
ID int64 `gorm:"primaryKey,autoIncrement"`
|
||||||
|
Hostname string
|
||||||
UpstreamEndpoint string
|
UpstreamEndpoint string
|
||||||
UpstreamSK string
|
UpstreamSK string
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
@@ -18,6 +19,7 @@ type Record struct {
|
|||||||
Status int
|
Status int
|
||||||
Authorization string // the autorization header send by client
|
Authorization string // the autorization header send by client
|
||||||
UserAgent string
|
UserAgent string
|
||||||
|
Headers string
|
||||||
}
|
}
|
||||||
|
|
||||||
type StreamModeChunk struct {
|
type StreamModeChunk struct {
|
||||||
|
|||||||
24
recovery.go
Normal file
24
recovery.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ServeHTTP(proxy *httputil.ReverseProxy, w gin.ResponseWriter, r *http.Request) (errReturn error) {
|
||||||
|
|
||||||
|
// recovery
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
log.Println("[serve.panic]: ", err)
|
||||||
|
errReturn = errors.New("[serve.panic]: Panic recover in reverse proxy serve HTTP")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(w, r)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
369
replicate.go
Normal file
369
replicate.go
Normal file
@@ -0,0 +1,369 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// replicate_model_url_template is the format string for a model's predictions
// URL on the Replicate API; its single %s verb receives the model name.
var replicate_model_url_template = "https://api.replicate.com/v1/models/%s/predictions"
|
||||||
|
|
||||||
|
func processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
|
||||||
|
err := _processReplicateRequest(c, upstream, record, shouldResponse)
|
||||||
|
if shouldResponse {
|
||||||
|
sendCORSHeaders(c)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithError(502, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
|
||||||
|
// read request body
|
||||||
|
inBody, err := io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("[processReplicateRequest]: failed to read request body " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// record request body
|
||||||
|
|
||||||
|
// parse request body
|
||||||
|
inRequest := &OpenAIChatRequest{}
|
||||||
|
err = json.Unmarshal(inBody, inRequest)
|
||||||
|
if err != nil {
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||||
|
return errors.New("[processReplicateRequest]: failed to parse request body " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
record.Model = inRequest.Model
|
||||||
|
|
||||||
|
// check allow model
|
||||||
|
if len(upstream.Allow) > 0 {
|
||||||
|
isAllow := false
|
||||||
|
for _, model := range upstream.Allow {
|
||||||
|
if model == inRequest.Model {
|
||||||
|
isAllow = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !isAllow {
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||||
|
return errors.New("[processReplicateRequest]: model not allow")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// check block model
|
||||||
|
if len(upstream.Deny) > 0 {
|
||||||
|
for _, model := range upstream.Deny {
|
||||||
|
if model == inRequest.Model {
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||||
|
return errors.New("[processReplicateRequest]: model deny")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// set url
|
||||||
|
model_url := fmt.Sprintf(replicate_model_url_template, inRequest.Model)
|
||||||
|
log.Println("[processReplicateRequest]: model_url:", model_url)
|
||||||
|
|
||||||
|
// create request with default value
|
||||||
|
outRequest := &ReplicateModelRequest{
|
||||||
|
Stream: false,
|
||||||
|
Input: ReplicateModelRequestInput{
|
||||||
|
Prompt: "",
|
||||||
|
MaxNewTokens: 1024,
|
||||||
|
Temperature: 0.6,
|
||||||
|
Top_p: 0.9,
|
||||||
|
Top_k: 50,
|
||||||
|
FrequencyPenalty: 0.0,
|
||||||
|
PresencePenalty: 0.0,
|
||||||
|
PromptTemplate: "{prompt}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy value from input request
|
||||||
|
outRequest.Stream = inRequest.Stream
|
||||||
|
outRequest.Input.Temperature = inRequest.Temperature
|
||||||
|
outRequest.Input.FrequencyPenalty = inRequest.FrequencyPenalty
|
||||||
|
outRequest.Input.PresencePenalty = inRequest.PresencePenalty
|
||||||
|
|
||||||
|
// render prompt
|
||||||
|
systemMessage := ""
|
||||||
|
userMessage := ""
|
||||||
|
assistantMessage := ""
|
||||||
|
for _, message := range inRequest.Messages {
|
||||||
|
if message.Role == "system" {
|
||||||
|
if systemMessage != "" {
|
||||||
|
systemMessage += "\n"
|
||||||
|
}
|
||||||
|
systemMessage += message.Content
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if message.Role == "user" {
|
||||||
|
if userMessage != "" {
|
||||||
|
userMessage += "\n"
|
||||||
|
}
|
||||||
|
userMessage += message.Content
|
||||||
|
if systemMessage != "" {
|
||||||
|
userMessage = systemMessage + "\n" + userMessage
|
||||||
|
systemMessage = ""
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if message.Role == "assistant" {
|
||||||
|
if assistantMessage != "" {
|
||||||
|
assistantMessage += "\n"
|
||||||
|
}
|
||||||
|
assistantMessage += message.Content
|
||||||
|
|
||||||
|
if outRequest.Input.Prompt != "" {
|
||||||
|
outRequest.Input.Prompt += "\n"
|
||||||
|
}
|
||||||
|
if userMessage != "" {
|
||||||
|
outRequest.Input.Prompt += fmt.Sprintf("<s> [INST] %s [/INST] %s </s>", userMessage, assistantMessage)
|
||||||
|
} else {
|
||||||
|
outRequest.Input.Prompt += fmt.Sprintf("<s> %s </s>", assistantMessage)
|
||||||
|
}
|
||||||
|
userMessage = ""
|
||||||
|
assistantMessage = ""
|
||||||
|
}
|
||||||
|
// unknown role
|
||||||
|
log.Println("[processReplicateRequest]: Warning: unknown role", message.Role)
|
||||||
|
}
|
||||||
|
// final user message
|
||||||
|
if userMessage != "" {
|
||||||
|
outRequest.Input.Prompt += fmt.Sprintf("<s> [INST] %s [/INST] ", userMessage)
|
||||||
|
userMessage = ""
|
||||||
|
}
|
||||||
|
// final assistant message
|
||||||
|
if assistantMessage != "" {
|
||||||
|
outRequest.Input.Prompt += fmt.Sprintf("<s> %s </s>", assistantMessage)
|
||||||
|
}
|
||||||
|
log.Println("[processReplicateRequest]: outRequest.Input.Prompt:", outRequest.Input.Prompt)
|
||||||
|
|
||||||
|
// send request
|
||||||
|
outBody, err := json.Marshal(outRequest)
|
||||||
|
if err != nil {
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||||
|
return errors.New("[processReplicateRequest]: failed to marshal request body " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// http add headers
|
||||||
|
req, err := http.NewRequest("POST", model_url, bytes.NewBuffer(outBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Token "+upstream.SK)
|
||||||
|
// send
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||||
|
return errors.New("[processReplicateRequest]: failed to post request " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// read response body
|
||||||
|
outBody, err = io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||||
|
return errors.New("[processReplicateRequest]: failed to read response body " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse reponse body
|
||||||
|
outResponse := &ReplicateModelResponse{}
|
||||||
|
err = json.Unmarshal(outBody, outResponse)
|
||||||
|
if err != nil {
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||||
|
return errors.New("[processReplicateRequest]: failed to parse response body " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if outResponse.Stream {
|
||||||
|
// get result
|
||||||
|
log.Println("[processReplicateRequest]: outResponse.URLS.Get:", outResponse.URLS.Stream)
|
||||||
|
req, err := http.NewRequest("GET", outResponse.URLS.Stream, nil)
|
||||||
|
if err != nil {
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||||
|
return errors.New("[processReplicateRequest]: failed to create get request " + err.Error())
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Token "+upstream.SK)
|
||||||
|
req.Header.Set("Accept", "text/event-stream")
|
||||||
|
// send
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||||
|
return errors.New("[processReplicateRequest]: failed to get request " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// get result by chunk
|
||||||
|
var buffer string = ""
|
||||||
|
var indexCount int64 = 0
|
||||||
|
for {
|
||||||
|
if !strings.Contains(buffer, "\n\n") {
|
||||||
|
// receive chunk
|
||||||
|
chunk := make([]byte, 1024)
|
||||||
|
length, err := resp.Body.Read(chunk)
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if length == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||||
|
return errors.New("[processReplicateRequest]: failed to read response body " + err.Error())
|
||||||
|
}
|
||||||
|
// add chunk to buffer
|
||||||
|
chunk = bytes.Trim(chunk, "\x00")
|
||||||
|
buffer += string(chunk)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// cut the first chunk by find index
|
||||||
|
index := strings.Index(buffer, "\n\n")
|
||||||
|
chunk := buffer[:index]
|
||||||
|
buffer = buffer[index+2:]
|
||||||
|
|
||||||
|
// trim line
|
||||||
|
chunk = strings.Trim(chunk, "\n")
|
||||||
|
|
||||||
|
// ignore hi
|
||||||
|
if !strings.Contains(chunk, "\n") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse chunk to ReplicateModelResultChunk object
|
||||||
|
chunkObj := &ReplicateModelResultChunk{}
|
||||||
|
lines := strings.Split(chunk, "\n")
|
||||||
|
// first line is event
|
||||||
|
chunkObj.Event = strings.TrimSpace(lines[0])
|
||||||
|
chunkObj.Event = strings.TrimPrefix(chunkObj.Event, "event: ")
|
||||||
|
// second line is id
|
||||||
|
chunkObj.ID = strings.TrimSpace(lines[1])
|
||||||
|
chunkObj.ID = strings.TrimPrefix(chunkObj.ID, "id: ")
|
||||||
|
chunkObj.ID = strings.SplitN(chunkObj.ID, ":", 2)[0]
|
||||||
|
// third line is data
|
||||||
|
chunkObj.Data = lines[2]
|
||||||
|
chunkObj.Data = strings.TrimPrefix(chunkObj.Data, "data: ")
|
||||||
|
|
||||||
|
record.Response += chunkObj.Data
|
||||||
|
|
||||||
|
// done
|
||||||
|
if chunkObj.Event == "done" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
sendCORSHeaders(c)
|
||||||
|
|
||||||
|
// create OpenAI response chunk
|
||||||
|
c.SSEvent("", &OpenAIChatResponseChunk{
|
||||||
|
ID: "",
|
||||||
|
Model: outResponse.Model,
|
||||||
|
Choices: []OpenAIChatResponseChunkChoice{
|
||||||
|
{
|
||||||
|
Index: indexCount,
|
||||||
|
Delta: OpenAIChatMessage{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: chunkObj.Data,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
c.Writer.Flush()
|
||||||
|
indexCount += 1
|
||||||
|
}
|
||||||
|
sendCORSHeaders(c)
|
||||||
|
c.SSEvent("", &OpenAIChatResponseChunk{
|
||||||
|
ID: "",
|
||||||
|
Model: outResponse.Model,
|
||||||
|
Choices: []OpenAIChatResponseChunkChoice{
|
||||||
|
{
|
||||||
|
Index: indexCount,
|
||||||
|
Delta: OpenAIChatMessage{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: "",
|
||||||
|
},
|
||||||
|
FinishReason: "stop",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
c.Writer.Flush()
|
||||||
|
indexCount += 1
|
||||||
|
record.Status = 200
|
||||||
|
return nil
|
||||||
|
|
||||||
|
} else {
|
||||||
|
var result *ReplicateModelResultGet
|
||||||
|
|
||||||
|
for {
|
||||||
|
// get result
|
||||||
|
log.Println("[processReplicateRequest]: outResponse.URLS.Get:", outResponse.URLS.Get)
|
||||||
|
req, err := http.NewRequest("GET", outResponse.URLS.Get, nil)
|
||||||
|
if err != nil {
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||||
|
return errors.New("[processReplicateRequest]: failed to create get request " + err.Error())
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Token "+upstream.SK)
|
||||||
|
// send
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||||
|
return errors.New("[processReplicateRequest]: failed to get request " + err.Error())
|
||||||
|
}
|
||||||
|
// get result
|
||||||
|
resultBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||||
|
return errors.New("[processReplicateRequest]: failed to read response body " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse reponse body
|
||||||
|
result = &ReplicateModelResultGet{}
|
||||||
|
err = json.Unmarshal(resultBody, result)
|
||||||
|
if err != nil {
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||||
|
return errors.New("[processReplicateRequest]: failed to parse response body " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Status == "processing" || result.Status == "starting" {
|
||||||
|
time.Sleep(3 * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// build openai resposne
|
||||||
|
openAIResult := &OpenAIChatResponse{
|
||||||
|
ID: result.ID,
|
||||||
|
Model: result.Model,
|
||||||
|
Choices: []OpenAIChatResponseChoice{},
|
||||||
|
Usage: OpenAIChatResponseUsage{
|
||||||
|
TotalTokens: result.Metrics.InputTokenCount + result.Metrics.OutputTokenCount,
|
||||||
|
PromptTokens: result.Metrics.InputTokenCount,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
openAIResult.Choices = append(openAIResult.Choices, OpenAIChatResponseChoice{
|
||||||
|
Index: 0,
|
||||||
|
Message: OpenAIChatMessage{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: strings.Join(result.Output, ""),
|
||||||
|
},
|
||||||
|
FinishReason: "stop",
|
||||||
|
})
|
||||||
|
|
||||||
|
record.Response = strings.Join(result.Output, "")
|
||||||
|
record.Status = 200
|
||||||
|
|
||||||
|
// gin return
|
||||||
|
sendCORSHeaders(c)
|
||||||
|
c.JSON(200, openAIResult)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
157
structure.go
157
structure.go
@@ -2,6 +2,7 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"log"
|
"log"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
@@ -9,15 +10,26 @@ import (
|
|||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Address string `yaml:"address"`
|
Address string `yaml:"address"`
|
||||||
|
Hostname string `yaml:"hostname"`
|
||||||
DBType string `yaml:"dbtype"`
|
DBType string `yaml:"dbtype"`
|
||||||
DBAddr string `yaml:"dbaddr"`
|
DBAddr string `yaml:"dbaddr"`
|
||||||
Authorization string `yaml:"authorization"`
|
Authorization string `yaml:"authorization"`
|
||||||
|
Timeout int64 `yaml:"timeout"`
|
||||||
|
StreamTimeout int64 `yaml:"stream_timeout"`
|
||||||
Upstreams []OPENAI_UPSTREAM `yaml:"upstreams"`
|
Upstreams []OPENAI_UPSTREAM `yaml:"upstreams"`
|
||||||
}
|
}
|
||||||
// OPENAI_UPSTREAM describes one upstream backend from the configuration file.
//
// Type selects the backend kind; readConfig defaults it to "openai" and
// accepts only "openai" and "replicate". Allow and Deny are model-name lists.
// URL is not read from the configuration: readConfig fills it by parsing
// Endpoint, so it is tagged yaml:"-" to keep a stray "url" key in the file
// from being decoded into it.
type OPENAI_UPSTREAM struct {
	SK            string   `yaml:"sk"`
	Endpoint      string   `yaml:"endpoint"`
	Timeout       int64    `yaml:"timeout"`
	StreamTimeout int64    `yaml:"stream_timeout"`
	Allow         []string `yaml:"allow"`
	Deny          []string `yaml:"deny"`
	Type          string   `yaml:"type"`
	KeepHeader    bool     `yaml:"keep_header"`
	Authorization string   `yaml:"authorization"`
	Noauth        bool     `yaml:"noauth"`
	URL           *url.URL `yaml:"-"`
}
|
||||||
|
|
||||||
func readConfig(filepath string) Config {
|
func readConfig(filepath string) Config {
|
||||||
@@ -48,6 +60,145 @@ func readConfig(filepath string) Config {
|
|||||||
log.Println("DBAddr not set, use default value: ./db.sqlite")
|
log.Println("DBAddr not set, use default value: ./db.sqlite")
|
||||||
config.DBAddr = "./db.sqlite"
|
config.DBAddr = "./db.sqlite"
|
||||||
}
|
}
|
||||||
|
if config.Timeout == 0 {
|
||||||
|
log.Println("Timeout not set, use default value: 120")
|
||||||
|
config.Timeout = 120
|
||||||
|
}
|
||||||
|
if config.StreamTimeout == 0 {
|
||||||
|
log.Println("StreamTimeout not set, use default value: 10")
|
||||||
|
config.StreamTimeout = 10
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, upstream := range config.Upstreams {
|
||||||
|
// parse upstream endpoint URL
|
||||||
|
endpoint, err := url.Parse(upstream.Endpoint)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Can't parse upstream endpoint URL '%s': %s", upstream.Endpoint, err)
|
||||||
|
}
|
||||||
|
config.Upstreams[i].URL = endpoint
|
||||||
|
if config.Upstreams[i].Type == "" {
|
||||||
|
config.Upstreams[i].Type = "openai"
|
||||||
|
}
|
||||||
|
if (config.Upstreams[i].Type != "openai") && (config.Upstreams[i].Type != "replicate") {
|
||||||
|
log.Fatalf("Unsupported upstream type '%s'", config.Upstreams[i].Type)
|
||||||
|
}
|
||||||
|
// apply authorization from global config if not set
|
||||||
|
if config.Upstreams[i].Authorization == "" && !config.Upstreams[i].Noauth {
|
||||||
|
config.Upstreams[i].Authorization = config.Authorization
|
||||||
|
}
|
||||||
|
if config.Upstreams[i].Timeout == 0 {
|
||||||
|
config.Upstreams[i].Timeout = config.Timeout
|
||||||
|
}
|
||||||
|
if config.Upstreams[i].StreamTimeout == 0 {
|
||||||
|
config.Upstreams[i].StreamTimeout = config.StreamTimeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OpenAIChatRequest is the subset of an OpenAI chat completion request body
// that this service decodes.
type OpenAIChatRequest struct {
	FrequencyPenalty float64                    `json:"frequency_penalty"`
	PresencePenalty  float64                    `json:"presence_penalty"`
	MaxTokens        int64                      `json:"max_tokens"`
	Model            string                     `json:"model"`
	Stream           bool                       `json:"stream"`
	Temperature      float64                    `json:"temperature"`
	Messages         []OpenAIChatRequestMessage `json:"messages"`
}

// OpenAIChatRequestMessage is one entry of OpenAIChatRequest.Messages.
type OpenAIChatRequestMessage struct {
	Content string `json:"content"`
	Role    string `json:"role"`
}
|
||||||
|
|
||||||
|
// ReplicateModelRequest is the JSON body posted to a Replicate model's
// predictions URL.
type ReplicateModelRequest struct {
	Stream bool                       `json:"stream"`
	Input  ReplicateModelRequestInput `json:"input"`
}

// ReplicateModelRequestInput carries the prompt and the sampling parameters
// sent under the "input" key of a ReplicateModelRequest.
type ReplicateModelRequestInput struct {
	Prompt           string  `json:"prompt"`
	MaxNewTokens     int64   `json:"max_new_tokens"`
	Temperature      float64 `json:"temperature"`
	Top_p            float64 `json:"top_p"`
	Top_k            int64   `json:"top_k"`
	PresencePenalty  float64 `json:"presence_penalty"`
	FrequencyPenalty float64 `json:"frequency_penalty"`
	PromptTemplate   string  `json:"prompt_template"`
}
|
||||||
|
|
||||||
|
// ReplicateModelResponse is the decoded reply to creating a prediction.
type ReplicateModelResponse struct {
	Model   string                     `json:"model"`
	Version string                     `json:"version"`
	Stream  bool                       `json:"stream"`
	Error   string                     `json:"error"`
	URLS    ReplicateModelResponseURLS `json:"urls"`
}

// ReplicateModelResponseURLS holds the follow-up URLs listed under the "urls"
// key of a ReplicateModelResponse.
type ReplicateModelResponseURLS struct {
	Cancel string `json:"cancel"`
	Get    string `json:"get"`
	Stream string `json:"stream"`
}
|
||||||
|
|
||||||
|
// ReplicateModelResultGet is the decoded prediction state fetched from the
// prediction's "get" URL. Output holds the generated text in pieces.
type ReplicateModelResultGet struct {
	ID      string                      `json:"id"`
	Model   string                      `json:"model"`
	Version string                      `json:"version"`
	Output  []string                    `json:"output"`
	Error   string                      `json:"error"`
	Metrics ReplicateModelResultMetrics `json:"metrics"`
	Status  string                      `json:"status"`
}

// ReplicateModelResultMetrics holds the token counts reported with a
// prediction result.
type ReplicateModelResultMetrics struct {
	InputTokenCount  int64 `json:"input_token_count"`
	OutputTokenCount int64 `json:"output_token_count"`
}
|
||||||
|
|
||||||
|
// OpenAIChatResponse is the OpenAI-style chat completion document returned to
// the client for a non-streaming request.
type OpenAIChatResponse struct {
	ID      string                     `json:"id"`
	Object  string                     `json:"object"`
	Created int64                      `json:"created"`
	Model   string                     `json:"model"`
	Choices []OpenAIChatResponseChoice `json:"choices"`
	Usage   OpenAIChatResponseUsage    `json:"usage"`
}

// OpenAIChatResponseUsage holds the token accounting of an OpenAIChatResponse.
type OpenAIChatResponseUsage struct {
	PromptTokens     int64 `json:"prompt_tokens"`
	CompletionTokens int64 `json:"completion_tokens"`
	TotalTokens      int64 `json:"total_tokens"`
}

// OpenAIChatResponseChoice is one completion inside an OpenAIChatResponse.
type OpenAIChatResponseChoice struct {
	Index        int64             `json:"index"`
	Message      OpenAIChatMessage `json:"message"`
	FinishReason string            `json:"finish_reason"`
}

// OpenAIChatMessage is a role/content pair used in responses and stream deltas.
type OpenAIChatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
|
||||||
|
|
||||||
|
// ReplicateModelResultChunk is one parsed event of a Replicate prediction
// stream: the event name, its id and its data payload.
type ReplicateModelResultChunk struct {
	Event string `json:"event"`
	ID    string `json:"id"`
	Data  string `json:"data"`
}
|
||||||
|
|
||||||
|
type OpenAIChatResponseChunk struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Choices []OpenAIChatResponseChunkChoice `json:"choices"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAIChatResponseChunkChoice struct {
|
||||||
|
Index int64 `json:"index"`
|
||||||
|
Delta OpenAIChatMessage `json:"delta"`
|
||||||
|
FinishReason string `json:"finish_reason"`
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user