Compare commits
40 Commits
7d93332e51
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
495f32610b
|
|||
|
45bba95f5d
|
|||
|
75ff8fbc2e
|
|||
| 66758e0008 | |||
|
40fc2067a5
|
|||
|
1a56101ca8
|
|||
| e373e3ac63 | |||
| 24a2e609f8 | |||
| e442303847 | |||
| 8b95fbb5da | |||
| 34aa4babc4 | |||
| e6ff1f5ca4 | |||
|
6b6f245e45
|
|||
| 995eea9d67 | |||
|
db7f0eb316
|
|||
|
990628b455
|
|||
|
e8b89fc41a
|
|||
|
46ee30ced7
|
|||
|
f2e32340e3
|
|||
|
ca386f8302
|
|||
|
3385f9af08
|
|||
|
8fa7fa79be
|
|||
|
49169452fe
|
|||
|
33f341026f
|
|||
|
b1a9d6b685
|
|||
|
1fc17daa35
|
|||
|
b1ab9c3d7b
|
|||
|
8672899a58
|
|||
|
3a59433f66
|
|||
|
873548a7d0
|
|||
|
2a2d907b0d
|
|||
|
2bbe98e694
|
|||
|
9fdbf259c0
|
|||
|
97926087bb
|
|||
|
fc5a8d55fa
|
|||
|
b1e3a97aad
|
|||
|
04a2e4c12d
|
|||
|
b8ebbed5d6
|
|||
|
412aefdacc
|
|||
|
2c3532f12f
|
@@ -1,3 +1,4 @@
|
||||
openai-api-route
|
||||
db.sqlite
|
||||
/config.yaml
|
||||
/.*
|
||||
|
||||
52
.gitlab-ci.yml
Normal file
52
.gitlab-ci.yml
Normal file
@@ -0,0 +1,52 @@
|
||||
# To contribute improvements to CI/CD templates, please follow the Development guide at:
|
||||
# https://docs.gitlab.com/ee/development/cicd/templates.html
|
||||
# This specific template is located at:
|
||||
# https://gitlab.com/gitlab-org/gitlab/-/blob/master/lib/gitlab/ci/templates/Docker.gitlab-ci.yml
|
||||
|
||||
# Build a Docker image with CI/CD and push to the GitLab registry.
|
||||
# Docker-in-Docker documentation: https://docs.gitlab.com/ee/ci/docker/using_docker_build.html
|
||||
#
|
||||
# This template uses one generic job with conditional builds
|
||||
# for the default branch and all other (MR) branches.
|
||||
|
||||
docker-build:
|
||||
# Use the official docker image.
|
||||
image: docker:cli
|
||||
stage: build
|
||||
services:
|
||||
- docker:dind
|
||||
variables:
|
||||
CI_REGISTRY: registry.waykey.net:7999
|
||||
CI_REGISTRY_IMAGE: $CI_REGISTRY/spiderman/datamining/openai-api-route
|
||||
DOCKER_IMAGE_NAME: $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_SLUG
|
||||
before_script:
|
||||
- docker login -u "$CI_REGISTRY_USER" -p "$CI_REGISTRY_PASSWORD" $CI_REGISTRY
|
||||
# All branches are tagged with $DOCKER_IMAGE_NAME (defaults to commit ref slug)
|
||||
# Default branch is also tagged with `latest`
|
||||
script:
|
||||
- docker build --pull -t "$DOCKER_IMAGE_NAME" .
|
||||
- docker push "$DOCKER_IMAGE_NAME"
|
||||
- |
|
||||
if [[ "$CI_COMMIT_BRANCH" == "$CI_DEFAULT_BRANCH" ]]; then
|
||||
docker tag "$DOCKER_IMAGE_NAME" "$CI_REGISTRY_IMAGE:latest"
|
||||
docker push "$CI_REGISTRY_IMAGE:latest"
|
||||
fi
|
||||
# Run this job in a branch where a Dockerfile exists
|
||||
rules:
|
||||
- if: $CI_COMMIT_BRANCH
|
||||
exists:
|
||||
- Dockerfile
|
||||
|
||||
deploy:
|
||||
environment: production
|
||||
image: kroniak/ssh-client
|
||||
stage: deploy
|
||||
before_script:
|
||||
- chmod 600 $CI_SSH_PRIVATE_KEY
|
||||
script:
|
||||
- ssh -o StrictHostKeyChecking=no -i $CI_SSH_PRIVATE_KEY root@192.168.1.13 "cd /mnt/data/srv/openai-api-route && podman-compose pull && podman-compose down && podman-compose up -d"
|
||||
rules:
|
||||
- if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH
|
||||
exists:
|
||||
- Dockerfile
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM golang:1.21 as builder
|
||||
FROM docker.io/golang:1.21 as builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
202
README.md
202
README.md
@@ -7,15 +7,134 @@
|
||||
- 自定义 Authorization 验证头
|
||||
- 支持所有类型的接口 (`/v1/*`)
|
||||
- 提供 Prometheus Metrics 统计接口 (`/v1/metrics`)
|
||||
- 按照定义顺序请求 OpenAI 上游
|
||||
- 识别 ChatCompletions Stream 请求,针对 Stream 请求使用 5 秒超时。具体超时策略请参阅 [超时策略](#超时策略) 一节
|
||||
- 记录完整的请求内容、使用的上游、IP 地址、响应时间以及 GPT 回复文本
|
||||
- 请求出错时发送 飞书 或 Matrix 消息通知
|
||||
- 按照定义顺序请求 OpenAI 上游,出错或超时自动按顺序尝试下一个
|
||||
- 识别 ChatCompletions Stream 请求,针对 Stream 请求使用更短的超时。具体超时策略请参阅 [超时策略](#超时策略) 一节
|
||||
- 有选择地记录请求内容、请求头、使用的上游、IP 地址、响应时间以及响应等内容。具体记录策略请参阅 [记录策略](#记录策略) 一节
|
||||
- 请求出错时发送 飞书 或 Matrix 平台的消息通知
|
||||
- 支持 Replicate 平台上的 mistral 模型(beta)
|
||||
|
||||
本文档详细介绍了如何使用负载均衡和能力 API 的方法和端点。
|
||||
|
||||
## 配置文件
|
||||
|
||||
默认情况下程序会使用当前目录下的 `config.yaml` 文件,您可以通过使用 `-config your-config.yaml` 参数指定配置文件路径。
|
||||
|
||||
以下是一个配置文件示例,你可以在 `config.sample.yaml` 文件中找到同样的内容
|
||||
|
||||
```yaml
|
||||
authorization: woshimima
|
||||
|
||||
# 默认超时时间,默认 120 秒,流式请求是 10 秒
|
||||
timeout: 120
|
||||
stream_timeout: 10
|
||||
|
||||
# 使用 sqlite 作为数据库储存请求记录
|
||||
dbtype: sqlite
|
||||
dbaddr: ./db.sqlite
|
||||
|
||||
# 使用 postgres 作为数据库储存请求记录
|
||||
# dbtype: postgres
|
||||
# dbaddr: "host=127.0.0.1 port=5432 user=postgres dbname=openai_api_route sslmode=disable password=woshimima"
|
||||
|
||||
upstreams:
|
||||
- sk: hahaha
|
||||
endpoint: "https://localhost:8888/v1"
|
||||
allow:
|
||||
# whisper 等非 JSON API 识别不到 model,则使用 URL 路径作为模型名称
|
||||
- /v1/audio/transcriptions
|
||||
|
||||
- sk: "secret_key_1"
|
||||
endpoint: "https://api.openai.com/v2"
|
||||
timeout: 120 # 请求超时时间,默认120秒
|
||||
stream_timeout: 10 # 如果识别到 stream: true, 则使用该超时时间
|
||||
allow: # 可选的模型白名单
|
||||
- gpt-3.5-trubo
|
||||
- gpt-3.5-trubo-0613
|
||||
|
||||
# 您可以设置很多个上游,程序将依次按顺序尝试
|
||||
- sk: "secret_key_2"
|
||||
endpoint: "https://api.openai.com/v1"
|
||||
timeout: 30
|
||||
deny:
|
||||
- gpt-4
|
||||
|
||||
- sk: "key_for_replicate"
|
||||
type: replicate
|
||||
allow:
|
||||
- mistralai/mixtral-8x7b-instruct-v0.1
|
||||
```
|
||||
|
||||
### 配置多个验证头
|
||||
|
||||
您可以使用英文逗号 `,` 分割多个验证头。每个验证头都是有效的,程序会记录每个请求使用的验证头
|
||||
|
||||
```yaml
|
||||
authorization: woshimima,iampassword
|
||||
```
|
||||
|
||||
您也可以为上游单独设置验证头
|
||||
|
||||
```yaml
|
||||
authorization: woshimima,iampassword
|
||||
upstreams:
|
||||
- sk: key
|
||||
authorization: woshimima
|
||||
```
|
||||
|
||||
如此,只有携带 `woshimima` 验证头的用户可以使用该上游。
|
||||
|
||||
### 复杂配置示例
|
||||
|
||||
```yaml
|
||||
|
||||
# 默认验证头
|
||||
authorization: woshimima
|
||||
|
||||
upstreams:
|
||||
|
||||
# 允许所有人使用的文字转语音
|
||||
- sk: xxx
|
||||
endpoint: http://localhost:5000/v1
|
||||
noauth: true
|
||||
allow:
|
||||
- /v1/audio/transcriptions
|
||||
|
||||
# guest 专用的 gpt-3.5-turbo-0125 模型
|
||||
- sk:
|
||||
endpoint: https://api.xxx.local/v1
|
||||
authorization: guest
|
||||
allow:
|
||||
- gpt-3.5-turbo-0125
|
||||
```
|
||||
|
||||
## 部署方法
|
||||
|
||||
有两种推荐的部署方法:
|
||||
|
||||
1. 使用预先构建好的容器 `docker.io/heimoshuiyu/openai-api-route:latest`
|
||||
2. 自行编译
|
||||
|
||||
### 使用容器运行
|
||||
|
||||
> 注意,如果您使用 sqlite 数据库,您可能还需要修改配置文件以将 SQLite 数据库文件放置在数据卷中。
|
||||
|
||||
```bash
|
||||
docker run -d --name openai-api-route -v /path/to/config.yaml:/config.yaml docker.io/heimoshuiyu/openai-api-route:latest
|
||||
```
|
||||
|
||||
使用 Docker Compose
|
||||
|
||||
```yaml
|
||||
version: '3'
|
||||
services:
|
||||
openai-api-route:
|
||||
image: docker.io/heimoshuiyu/openai-api-route:latest
|
||||
ports:
|
||||
- 8888:8888
|
||||
volumes:
|
||||
- ./config.yaml:/config.yaml
|
||||
```
|
||||
|
||||
### 编译
|
||||
|
||||
以下是编译和运行该负载均衡 API 的步骤:
|
||||
@@ -40,74 +159,11 @@
|
||||
./openai-api-route
|
||||
```
|
||||
|
||||
默认情况下,API 将会在本地的 8888 端口进行监听。
|
||||
## 模型允许与屏蔽列表
|
||||
|
||||
如果您希望使用不同的监听地址,可以使用 `-addr` 参数来指定,例如:
|
||||
如果对某个上游设置了 allow 或 deny 列表,则负载均衡只允许或禁用用户使用这些模型。负载均衡程序会先判断白名单,再判断黑名单。
|
||||
|
||||
```
|
||||
./openai-api-route -addr 0.0.0.0:8080
|
||||
```
|
||||
|
||||
这将会将监听地址设置为 0.0.0.0:8080。
|
||||
|
||||
6. 如果数据库不存在,系统会自动创建一个名为 `db.sqlite` 的数据库文件。
|
||||
|
||||
如果您希望使用不同的数据库地址,可以使用 `-database` 参数来指定,例如:
|
||||
|
||||
```
|
||||
./openai-api-route -database /path/to/database.db
|
||||
```
|
||||
|
||||
这将会将数据库地址设置为 `/path/to/database.db`。
|
||||
|
||||
7. 现在,您已经成功编译并运行了负载均衡和能力 API。您可以根据需要添加上游、管理上游,并使用 API 进行相关操作。
|
||||
|
||||
### 运行
|
||||
|
||||
以下是运行命令的用法:
|
||||
|
||||
```
|
||||
Usage of ./openai-api-route:
|
||||
-addr string
|
||||
监听地址(默认为 ":8888")
|
||||
-upstreams string
|
||||
上游配置文件(默认为 "./upstreams.yaml")
|
||||
-dbtype
|
||||
数据库类型 (sqlite 或 postgres,默认为 sqlite)
|
||||
-database string
|
||||
数据库地址(默认为 "./db.sqlite")
|
||||
如果数据库为 postgres ,则此值应 PostgreSQL DSN 格式
|
||||
例如 "host=127.0.0.1 port=5432 user=postgres dbname=openai_api_route sslmode=disable password=woshimima"
|
||||
-list
|
||||
列出所有上游
|
||||
-noauth
|
||||
不检查传入的授权头
|
||||
```
|
||||
|
||||
以下是一个 `./upstreams.yaml` 文件配置示例
|
||||
|
||||
```yaml
|
||||
authorization: woshimima
|
||||
|
||||
# 使用 sqlite 作为数据库储存请求记录
|
||||
dbtype: sqlite
|
||||
dbaddr: ./db.sqlite
|
||||
|
||||
# 使用 postgres 作为数据库储存请求记录
|
||||
# dbtype: postgres
|
||||
# dbaddr: "host=127.0.0.1 port=5432 user=postgres dbname=openai_api_route sslmode=disable password=woshimima"
|
||||
|
||||
upstreams:
|
||||
- sk: "secret_key_1"
|
||||
endpoint: "https://api.openai.com/v2"
|
||||
- sk: "secret_key_2"
|
||||
endpoint: "https://api.openai.com/v1"
|
||||
timeout: 30
|
||||
```
|
||||
|
||||
请注意,程序会根据情况修改 timeout 的值
|
||||
|
||||
您可以直接运行 `./openai-api-route` 命令,如果数据库不存在,系统会自动创建。
|
||||
如果你混合使用 OpenAI 和 Replicate 平台的模型,你可能需要分别为 OpenAI 和 Replicate 上游设置他们各自的允许列表,否则用户请求 OpenAI 的模型时可能会发送到 Replicate 平台
|
||||
|
||||
## 超时策略
|
||||
|
||||
@@ -127,8 +183,4 @@ upstreams:
|
||||
|
||||
1. **默认超时时间**:如果没有特殊条件,服务将使用默认的超时时间,即 60 秒。
|
||||
|
||||
2. **流式请求**:如果请求体被识别为流式(`requestBody.Stream` 为 `true`),并且请求体检查(`requestBodyOK`)没有发现问题,超时时间将被设置为 5 秒。这适用于那些预期会快速响应的流式请求。
|
||||
|
||||
3. **大请求体**:如果请求体的大小超过 128KB(即 `len(inBody) > 1024*128`),超时时间将被设置为 20 秒。这考虑到了处理大型数据可能需要更长的时间。
|
||||
|
||||
4. **上游超时配置**:如果上游服务器在配置中指定了超时时间(`upstream.Timeout` 大于 0),服务将使用该值作为超时时间。这个值是以秒为单位的。
|
||||
2. **流式请求**:如果请求体被识别为流式(`requestBody.Stream` 为 `true`),并且请求体检查(`requestBodyOK`)没有发现问题,超时时间将被设置为 5 秒。这适用于那些预期会快速响应的流式请求。
|
||||
28
auth.go
28
auth.go
@@ -2,32 +2,14 @@ package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func handleAuth(c *gin.Context) error {
|
||||
var err error
|
||||
|
||||
authorization := c.Request.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(authorization, "Bearer") {
|
||||
err = errors.New("authorization header should start with 'Bearer'")
|
||||
c.AbortWithError(403, err)
|
||||
return err
|
||||
}
|
||||
|
||||
authorization = strings.Trim(authorization[len("Bearer"):], " ")
|
||||
log.Println("Received authorization", authorization)
|
||||
|
||||
for _, auth := range strings.Split(config.Authorization, ",") {
|
||||
if authorization != strings.Trim(auth, " ") {
|
||||
err = errors.New("wrong authorization header")
|
||||
c.AbortWithError(403, err)
|
||||
return err
|
||||
func checkAuth(authorization string, config string) error {
|
||||
for _, auth := range strings.Split(config, ",") {
|
||||
if authorization == strings.Trim(auth, " ") {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return errors.New("wrong authorization header")
|
||||
}
|
||||
|
||||
@@ -11,6 +11,13 @@ dbaddr: ./db.sqlite
|
||||
upstreams:
|
||||
- sk: "secret_key_1"
|
||||
endpoint: "https://api.openai.com/v2"
|
||||
allow: ["gpt-3.5-trubo"] # 可选的模型白名单
|
||||
- sk: "secret_key_2"
|
||||
endpoint: "https://api.openai.com/v1"
|
||||
timeout: 30
|
||||
timeout: 30
|
||||
allow: ["gpt-3.5-trubo"] # 可选的模型白名单
|
||||
deny: ["gpt-4"] # 可选的模型黑名单
|
||||
# 若白名单和黑名单同时设置,先判断白名单,再判断黑名单
|
||||
- sk: "key_for_replicate"
|
||||
type: replicate
|
||||
allow: ["mistralai/mixtral-8x7b-instruct-v0.1"]
|
||||
|
||||
20
cors.go
Normal file
20
cors.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func sendCORSHeaders(c *gin.Context) {
|
||||
log.Println("sendCORSHeaders")
|
||||
if c.Writer.Header().Get("Access-Control-Allow-Origin") == "" {
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
if c.Writer.Header().Get("Access-Control-Allow-Methods") == "" {
|
||||
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
|
||||
}
|
||||
if c.Writer.Header().Get("Access-Control-Allow-Headers") == "" {
|
||||
c.Header("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
|
||||
}
|
||||
}
|
||||
173
main.go
173
main.go
@@ -1,15 +1,21 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/penglongli/gin-metrics/ginmetrics"
|
||||
"gorm.io/driver/postgres"
|
||||
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -23,11 +29,11 @@ func main() {
|
||||
noauth := flag.Bool("noauth", false, "Do not check incoming authorization header")
|
||||
flag.Parse()
|
||||
|
||||
log.Println("Service starting")
|
||||
log.Println("[main]: Service starting")
|
||||
|
||||
// load all upstreams
|
||||
config = readConfig(*configFile)
|
||||
log.Println("Load upstreams number:", len(config.Upstreams))
|
||||
log.Println("[main]: Load upstreams number:", len(config.Upstreams))
|
||||
|
||||
// connect to database
|
||||
var db *gorm.DB
|
||||
@@ -38,17 +44,23 @@ func main() {
|
||||
PrepareStmt: true,
|
||||
SkipDefaultTransaction: true,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("[main]: Error to connect sqlite database: %s", err)
|
||||
}
|
||||
case "postgres":
|
||||
db, err = gorm.Open(postgres.Open(config.DBAddr), &gorm.Config{
|
||||
PrepareStmt: true,
|
||||
SkipDefaultTransaction: true,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("[main]: Error to connect postgres database: %s", err)
|
||||
}
|
||||
default:
|
||||
log.Fatalf("Unsupported database type: '%s'", config.DBType)
|
||||
log.Fatalf("[main]: Unsupported database type: '%s'", config.DBType)
|
||||
}
|
||||
|
||||
db.AutoMigrate(&Record{})
|
||||
log.Println("Auto migrate database done")
|
||||
log.Println("[main]: Auto migrate database done")
|
||||
|
||||
if *listMode {
|
||||
fmt.Println("SK\tEndpoint")
|
||||
@@ -66,6 +78,9 @@ func main() {
|
||||
m.SetMetricPath("/v1/metrics")
|
||||
m.Use(engine)
|
||||
|
||||
// CORS middleware
|
||||
// engine.Use(corsMiddleware())
|
||||
|
||||
// error handle middleware
|
||||
engine.Use(func(c *gin.Context) {
|
||||
c.Next()
|
||||
@@ -74,68 +89,154 @@ func main() {
|
||||
}
|
||||
errText := strings.Join(c.Errors.Errors(), "\n")
|
||||
c.JSON(-1, gin.H{
|
||||
"error": errText,
|
||||
"openai-api-route error": errText,
|
||||
})
|
||||
})
|
||||
|
||||
// CORS handler
|
||||
engine.OPTIONS("/v1/*any", func(ctx *gin.Context) {
|
||||
header := ctx.Writer.Header()
|
||||
header.Set("Access-Control-Allow-Origin", "*")
|
||||
header.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
|
||||
header.Set("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
|
||||
// set cros header
|
||||
ctx.Header("Access-Control-Allow-Origin", "*")
|
||||
ctx.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
|
||||
ctx.Header("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
|
||||
ctx.AbortWithStatus(200)
|
||||
})
|
||||
|
||||
engine.POST("/v1/*any", func(c *gin.Context) {
|
||||
var err error
|
||||
hostname, _ := os.Hostname()
|
||||
if config.Hostname != "" {
|
||||
hostname = config.Hostname
|
||||
}
|
||||
record := Record{
|
||||
IP: c.ClientIP(),
|
||||
Hostname: hostname,
|
||||
CreatedAt: time.Now(),
|
||||
Authorization: c.Request.Header.Get("Authorization"),
|
||||
UserAgent: c.Request.Header.Get("User-Agent"),
|
||||
}
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
log.Println("Error:", err)
|
||||
c.AbortWithError(500, fmt.Errorf("%s", err))
|
||||
}
|
||||
}()
|
||||
|
||||
// check authorization header
|
||||
if !*noauth {
|
||||
if handleAuth(c) != nil {
|
||||
return
|
||||
}
|
||||
Model: c.Request.URL.Path,
|
||||
}
|
||||
|
||||
for index, upstream := range config.Upstreams {
|
||||
if upstream.Endpoint == "" || upstream.SK == "" {
|
||||
c.AbortWithError(500, fmt.Errorf("invaild upstream '%s' '%s'", upstream.SK, upstream.Endpoint))
|
||||
authorization := c.Request.Header.Get("Authorization")
|
||||
if strings.HasPrefix(authorization, "Bearer") {
|
||||
authorization = strings.Trim(authorization[len("Bearer"):], " ")
|
||||
} else {
|
||||
authorization = strings.Trim(authorization, " ")
|
||||
log.Println("[auth] Warning: authorization header should start with 'Bearer'")
|
||||
}
|
||||
log.Println("Received authorization '" + authorization + "'")
|
||||
|
||||
availUpstreams := make([]OPENAI_UPSTREAM, 0)
|
||||
for _, upstream := range config.Upstreams {
|
||||
if upstream.SK == "" {
|
||||
sendCORSHeaders(c)
|
||||
c.AbortWithError(500, fmt.Errorf("[processRequest.begin]: invaild SK (secret key) %s", upstream.SK))
|
||||
continue
|
||||
}
|
||||
|
||||
if !*noauth && !upstream.Noauth {
|
||||
if checkAuth(authorization, upstream.Authorization) != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
availUpstreams = append(availUpstreams, upstream)
|
||||
}
|
||||
|
||||
if len(availUpstreams) == 0 {
|
||||
sendCORSHeaders(c)
|
||||
c.AbortWithError(500, fmt.Errorf("[processRequest.begin]: no available upstream for your token"))
|
||||
}
|
||||
log.Println("[processRequest.begin]: availUpstreams", len(availUpstreams))
|
||||
|
||||
bufIO := bytes.NewBuffer(make([]byte, 0, 1024))
|
||||
wrapedBody := false
|
||||
|
||||
for index, _upstream := range availUpstreams {
|
||||
|
||||
// copy
|
||||
upstream := _upstream
|
||||
record.UpstreamEndpoint = upstream.Endpoint
|
||||
record.UpstreamSK = upstream.SK
|
||||
|
||||
shouldResponse := index == len(config.Upstreams)-1
|
||||
|
||||
if len(config.Upstreams) == 1 {
|
||||
if len(availUpstreams) == 1 {
|
||||
// [todo] copy problem
|
||||
upstream.Timeout = 120
|
||||
}
|
||||
|
||||
err = processRequest(c, &upstream, &record, shouldResponse)
|
||||
if err != nil {
|
||||
log.Println("Error from upstream", upstream.Endpoint, "should retry", err)
|
||||
continue
|
||||
// buffer for incoming request
|
||||
if !wrapedBody {
|
||||
log.Println("[processRequest.begin]: wrap request body")
|
||||
c.Request.Body = io.NopCloser(io.TeeReader(c.Request.Body, bufIO))
|
||||
wrapedBody = true
|
||||
} else {
|
||||
log.Println("[processRequest.begin]: reuse request body")
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bufIO.Bytes()))
|
||||
}
|
||||
|
||||
break
|
||||
if upstream.Type == "replicate" {
|
||||
err = processReplicateRequest(c, &upstream, &record, shouldResponse)
|
||||
} else if upstream.Type == "openai" {
|
||||
err = processRequest(c, &upstream, &record, shouldResponse)
|
||||
} else {
|
||||
err = fmt.Errorf("[processRequest.begin]: unsupported upstream type '%s'", upstream.Type)
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
log.Println("[processRequest.done]: Success from upstream", upstream.Endpoint)
|
||||
break
|
||||
}
|
||||
|
||||
if err == http.ErrAbortHandler {
|
||||
abortErr := "[processRequest.done]: AbortHandler, client's connection lost?, no upstream will try, stop here"
|
||||
log.Println(abortErr)
|
||||
record.Response += abortErr
|
||||
record.Status = 500
|
||||
break
|
||||
}
|
||||
log.Println("[processRequest.done]: Error from upstream", upstream.Endpoint, "should retry", err, "should response:", shouldResponse)
|
||||
|
||||
// error process, break
|
||||
if shouldResponse {
|
||||
c.Header("Content-Type", "application/json")
|
||||
sendCORSHeaders(c)
|
||||
c.AbortWithError(500, err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Println("Record result:", record.Status, record.Response)
|
||||
record.ElapsedTime = time.Now().Sub(record.CreatedAt)
|
||||
if db.Create(&record).Error != nil {
|
||||
log.Println("Error to save record:", record)
|
||||
// parse and record request body
|
||||
requestBodyBytes := bufIO.Bytes()
|
||||
if len(requestBodyBytes) < 1024*1024 && (strings.HasPrefix(c.Request.Header.Get("Content-Type"), "application/json") ||
|
||||
strings.HasPrefix(c.Request.Header.Get("Content-Type"), "text/")) {
|
||||
record.Body = string(requestBodyBytes)
|
||||
}
|
||||
requestBody, err := ParseRequestBody(requestBodyBytes)
|
||||
if err != nil {
|
||||
log.Println("[processRequest.done]: Error to parse request body:", err)
|
||||
} else {
|
||||
record.Model = requestBody.Model
|
||||
}
|
||||
|
||||
log.Println("[final]: Record result:", record.Status, record.Response)
|
||||
record.ElapsedTime = time.Since(record.CreatedAt)
|
||||
|
||||
// async record request
|
||||
go func() {
|
||||
// encoder headers to record.Headers in json string
|
||||
headers, _ := json.Marshal(c.Request.Header)
|
||||
record.Headers = string(headers)
|
||||
|
||||
// turncate request if too long
|
||||
log.Println("[async.record]: body length:", len(record.Body))
|
||||
if db.Create(&record).Error != nil {
|
||||
log.Println("[async.record]: Error to save record:", record)
|
||||
}
|
||||
}()
|
||||
|
||||
if record.Status != 200 {
|
||||
errMessage := fmt.Sprintf("IP: %s request %s error %d with %s", record.IP, record.Model, record.Status, record.Response)
|
||||
errMessage := fmt.Sprintf("[result.error]: IP: %s request %s error %d with %s", record.IP, record.Model, record.Status, record.Response)
|
||||
go sendFeishuMessage(errMessage)
|
||||
go sendMatrixMessage(errMessage)
|
||||
}
|
||||
|
||||
282
process.go
282
process.go
@@ -8,216 +8,118 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
|
||||
var errCtx error
|
||||
|
||||
record.UpstreamEndpoint = upstream.Endpoint
|
||||
record.UpstreamSK = upstream.SK
|
||||
record.Response = ""
|
||||
// [TODO] record request body
|
||||
|
||||
// reverse proxy
|
||||
remote, err := url.Parse(upstream.Endpoint)
|
||||
if err != nil {
|
||||
c.AbortWithError(500, errors.New("can't parse reverse proxy remote URL"))
|
||||
return err
|
||||
}
|
||||
|
||||
haveResponse := false
|
||||
path := strings.TrimPrefix(c.Request.URL.Path, "/v1")
|
||||
remote.Path = upstream.URL.Path + path
|
||||
log.Println("[proxy.begin]:", remote)
|
||||
log.Println("[proxy.begin]: shouldResposne:", shouldResponse)
|
||||
|
||||
proxy := httputil.NewSingleHostReverseProxy(remote)
|
||||
proxy.Director = nil
|
||||
var inBody []byte
|
||||
proxy.Rewrite = func(proxyRequest *httputil.ProxyRequest) {
|
||||
in := proxyRequest.In
|
||||
client := &http.Client{}
|
||||
request := &http.Request{}
|
||||
request.ContentLength = c.Request.ContentLength
|
||||
request.Method = c.Request.Method
|
||||
request.URL = remote
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
proxyRequest.Out = proxyRequest.Out.WithContext(ctx)
|
||||
|
||||
out := proxyRequest.Out
|
||||
|
||||
// read request body
|
||||
inBody, err = io.ReadAll(in.Body)
|
||||
if err != nil {
|
||||
c.AbortWithError(502, errors.New("reverse proxy middleware failed to read request body "+err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// record chat message from user
|
||||
record.Body = string(inBody)
|
||||
requestBody, requestBodyOK := ParseRequestBody(inBody)
|
||||
// record if parse success
|
||||
if requestBodyOK == nil {
|
||||
record.Model = requestBody.Model
|
||||
}
|
||||
|
||||
// set timeout, default is 60 second
|
||||
timeout := 60 * time.Second
|
||||
if requestBodyOK == nil && requestBody.Stream {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
if len(inBody) > 1024*128 {
|
||||
timeout = 20 * time.Second
|
||||
}
|
||||
if upstream.Timeout > 0 {
|
||||
// convert upstream.Timeout(second) to nanosecond
|
||||
timeout = time.Duration(upstream.Timeout) * time.Second
|
||||
}
|
||||
|
||||
// timeout out request
|
||||
go func() {
|
||||
time.Sleep(timeout)
|
||||
if !haveResponse {
|
||||
log.Println("Timeout upstream", upstream.Endpoint)
|
||||
errCtx = errors.New("timeout")
|
||||
if shouldResponse {
|
||||
c.AbortWithError(502, errCtx)
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
out.Body = io.NopCloser(bytes.NewReader(inBody))
|
||||
|
||||
out.Host = remote.Host
|
||||
out.URL.Scheme = remote.Scheme
|
||||
out.URL.Host = remote.Host
|
||||
out.URL.Path = in.URL.Path
|
||||
out.Header = http.Header{}
|
||||
out.Header.Set("Host", remote.Host)
|
||||
if upstream.SK == "asis" {
|
||||
out.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
||||
} else {
|
||||
out.Header.Set("Authorization", "Bearer "+upstream.SK)
|
||||
}
|
||||
out.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
var contentType string
|
||||
proxy.ModifyResponse = func(r *http.Response) error {
|
||||
haveResponse = true
|
||||
record.Status = r.StatusCode
|
||||
if !shouldResponse && r.StatusCode != 200 {
|
||||
log.Println("upstream return not 200 and should not response", r.StatusCode)
|
||||
return errors.New("upstream return not 200 and should not response")
|
||||
}
|
||||
r.Header.Del("Access-Control-Allow-Origin")
|
||||
r.Header.Del("Access-Control-Allow-Methods")
|
||||
r.Header.Del("Access-Control-Allow-Headers")
|
||||
r.Header.Set("Access-Control-Allow-Origin", "*")
|
||||
r.Header.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
|
||||
r.Header.Set("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
|
||||
|
||||
if r.StatusCode != 200 {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
record.Response = "failed to read response from upstream " + err.Error()
|
||||
return errors.New(record.Response)
|
||||
}
|
||||
record.Response = fmt.Sprintf("openai-api-route upstream return '%s' with '%s'", r.Status, string(body))
|
||||
record.Status = r.StatusCode
|
||||
return fmt.Errorf(record.Response)
|
||||
}
|
||||
// count success
|
||||
r.Body = io.NopCloser(io.TeeReader(r.Body, &buf))
|
||||
contentType = r.Header.Get("content-type")
|
||||
return nil
|
||||
}
|
||||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
haveResponse = true
|
||||
record.ResponseTime = time.Now().Sub(record.CreatedAt)
|
||||
log.Println("Error", err, upstream.SK, upstream.Endpoint)
|
||||
|
||||
errCtx = err
|
||||
|
||||
// abort to error handle
|
||||
if shouldResponse {
|
||||
c.AbortWithError(502, err)
|
||||
}
|
||||
|
||||
log.Println("response is", r.Response)
|
||||
|
||||
if record.Status == 0 {
|
||||
record.Status = 502
|
||||
}
|
||||
if record.Response == "" {
|
||||
record.Response = err.Error()
|
||||
}
|
||||
if r.Response != nil {
|
||||
record.Status = r.Response.StatusCode
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
proxy.ServeHTTP(c.Writer, c.Request)
|
||||
|
||||
// return context error
|
||||
if errCtx != nil {
|
||||
// fix inrequest body
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(inBody))
|
||||
return errCtx
|
||||
}
|
||||
|
||||
resp, err := io.ReadAll(io.NopCloser(&buf))
|
||||
if err != nil {
|
||||
record.Response = "failed to read response from upstream " + err.Error()
|
||||
log.Println(record.Response)
|
||||
// process header
|
||||
if upstream.KeepHeader {
|
||||
request.Header = c.Request.Header
|
||||
} else {
|
||||
|
||||
// record response
|
||||
// stream mode
|
||||
if strings.HasPrefix(contentType, "text/event-stream") {
|
||||
for _, line := range strings.Split(string(resp), "\n") {
|
||||
chunk := StreamModeChunk{}
|
||||
line = strings.TrimPrefix(line, "data:")
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
err := json.Unmarshal([]byte(line), &chunk)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
continue
|
||||
}
|
||||
|
||||
if len(chunk.Choices) == 0 {
|
||||
continue
|
||||
}
|
||||
record.Response += chunk.Choices[0].Delta.Content
|
||||
}
|
||||
} else if strings.HasPrefix(contentType, "application/json") {
|
||||
var fetchResp FetchModeResponse
|
||||
err := json.Unmarshal(resp, &fetchResp)
|
||||
if err != nil {
|
||||
log.Println("Error parsing fetch response:", err)
|
||||
return nil
|
||||
}
|
||||
if !strings.HasPrefix(fetchResp.Model, "gpt-") {
|
||||
log.Println("Not GPT model, skip recording response:", fetchResp.Model)
|
||||
return nil
|
||||
}
|
||||
if len(fetchResp.Choices) == 0 {
|
||||
log.Println("Error: fetch response choice length is 0")
|
||||
return nil
|
||||
}
|
||||
record.Response = fetchResp.Choices[0].Message.Content
|
||||
} else {
|
||||
log.Println("Unknown content type", contentType)
|
||||
}
|
||||
request.Header = http.Header{}
|
||||
}
|
||||
|
||||
if len(record.Body) > 1024*512 {
|
||||
record.Body = ""
|
||||
// process header authorization
|
||||
if upstream.SK == "asis" {
|
||||
request.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
||||
} else {
|
||||
request.Header.Set("Authorization", "Bearer "+upstream.SK)
|
||||
}
|
||||
request.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
request.Header.Set("Host", remote.Host)
|
||||
request.Header.Set("Content-Length", c.Request.Header.Get("Content-Length"))
|
||||
|
||||
request.Body = c.Request.Body
|
||||
|
||||
resp, err := client.Do(request)
|
||||
if err != nil {
|
||||
body := []byte{}
|
||||
if resp != nil && resp.Body != nil {
|
||||
body, _ = io.ReadAll(resp.Body)
|
||||
}
|
||||
return errors.New(err.Error() + " " + string(body))
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
record.Status = resp.StatusCode
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
record.Status = resp.StatusCode
|
||||
errRet := fmt.Errorf("[error]: openai-api-route upstream return '%s' with '%s'", resp.Status, string(body))
|
||||
log.Println(errRet)
|
||||
return errRet
|
||||
}
|
||||
|
||||
// copy response header
|
||||
for k, v := range resp.Header {
|
||||
c.Header(k, v[0])
|
||||
}
|
||||
sendCORSHeaders(c)
|
||||
|
||||
respBodyBuffer := bytes.NewBuffer(make([]byte, 0, 4*1024))
|
||||
respBodyTeeReader := io.TeeReader(resp.Body, respBodyBuffer)
|
||||
record.ResponseTime = time.Since(record.CreatedAt)
|
||||
io.Copy(c.Writer, respBodyTeeReader)
|
||||
record.ElapsedTime = time.Since(record.CreatedAt)
|
||||
|
||||
// parse and record response
|
||||
if strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") {
|
||||
var fetchResp FetchModeResponse
|
||||
err := json.NewDecoder(respBodyBuffer).Decode(&fetchResp)
|
||||
if err == nil {
|
||||
if len(fetchResp.Choices) > 0 {
|
||||
record.Response = fetchResp.Choices[0].Message.Content
|
||||
}
|
||||
}
|
||||
} else if strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") {
|
||||
lines := bytes.Split(respBodyBuffer.Bytes(), []byte("\n"))
|
||||
for _, line := range lines {
|
||||
line = bytes.TrimSpace(line)
|
||||
line = bytes.TrimPrefix(line, []byte("data:"))
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
chunk := StreamModeChunk{}
|
||||
err = json.Unmarshal(line, &chunk)
|
||||
if err != nil {
|
||||
log.Println("[proxy.parseChunkError]:", err)
|
||||
break
|
||||
}
|
||||
if len(chunk.Choices) == 0 {
|
||||
continue
|
||||
}
|
||||
record.Response += chunk.Choices[0].Delta.Content
|
||||
}
|
||||
} else if strings.HasPrefix(resp.Header.Get("Content-Type"), "text") {
|
||||
body, _ := io.ReadAll(respBodyBuffer)
|
||||
record.Response = string(body)
|
||||
} else {
|
||||
log.Println("[proxy.record]: Unknown content type", resp.Header.Get("Content-Type"))
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
type Record struct {
|
||||
ID int64 `gorm:"primaryKey,autoIncrement"`
|
||||
Hostname string
|
||||
UpstreamEndpoint string
|
||||
UpstreamSK string
|
||||
CreatedAt time.Time
|
||||
@@ -18,6 +19,7 @@ type Record struct {
|
||||
Status int
|
||||
Authorization string // the autorization header send by client
|
||||
UserAgent string
|
||||
Headers string
|
||||
}
|
||||
|
||||
type StreamModeChunk struct {
|
||||
|
||||
24
recovery.go
Normal file
24
recovery.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func ServeHTTP(proxy *httputil.ReverseProxy, w gin.ResponseWriter, r *http.Request) (errReturn error) {
|
||||
|
||||
// recovery
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
log.Println("[serve.panic]: ", err)
|
||||
errReturn = errors.New("[serve.panic]: Panic recover in reverse proxy serve HTTP")
|
||||
}
|
||||
}()
|
||||
|
||||
proxy.ServeHTTP(w, r)
|
||||
return nil
|
||||
}
|
||||
369
replicate.go
Normal file
369
replicate.go
Normal file
@@ -0,0 +1,369 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// replicate_model_url_template is the Replicate predictions endpoint; the %s
// placeholder receives the model identifier taken from the incoming request.
var replicate_model_url_template = "https://api.replicate.com/v1/models/%s/predictions"
|
||||
|
||||
func processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
|
||||
err := _processReplicateRequest(c, upstream, record, shouldResponse)
|
||||
if shouldResponse {
|
||||
sendCORSHeaders(c)
|
||||
if err != nil {
|
||||
c.AbortWithError(502, err)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
|
||||
// read request body
|
||||
inBody, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
return errors.New("[processReplicateRequest]: failed to read request body " + err.Error())
|
||||
}
|
||||
|
||||
// record request body
|
||||
|
||||
// parse request body
|
||||
inRequest := &OpenAIChatRequest{}
|
||||
err = json.Unmarshal(inBody, inRequest)
|
||||
if err != nil {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||
return errors.New("[processReplicateRequest]: failed to parse request body " + err.Error())
|
||||
}
|
||||
|
||||
record.Model = inRequest.Model
|
||||
|
||||
// check allow model
|
||||
if len(upstream.Allow) > 0 {
|
||||
isAllow := false
|
||||
for _, model := range upstream.Allow {
|
||||
if model == inRequest.Model {
|
||||
isAllow = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isAllow {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||
return errors.New("[processReplicateRequest]: model not allow")
|
||||
}
|
||||
}
|
||||
// check block model
|
||||
if len(upstream.Deny) > 0 {
|
||||
for _, model := range upstream.Deny {
|
||||
if model == inRequest.Model {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||
return errors.New("[processReplicateRequest]: model deny")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// set url
|
||||
model_url := fmt.Sprintf(replicate_model_url_template, inRequest.Model)
|
||||
log.Println("[processReplicateRequest]: model_url:", model_url)
|
||||
|
||||
// create request with default value
|
||||
outRequest := &ReplicateModelRequest{
|
||||
Stream: false,
|
||||
Input: ReplicateModelRequestInput{
|
||||
Prompt: "",
|
||||
MaxNewTokens: 1024,
|
||||
Temperature: 0.6,
|
||||
Top_p: 0.9,
|
||||
Top_k: 50,
|
||||
FrequencyPenalty: 0.0,
|
||||
PresencePenalty: 0.0,
|
||||
PromptTemplate: "{prompt}",
|
||||
},
|
||||
}
|
||||
|
||||
// copy value from input request
|
||||
outRequest.Stream = inRequest.Stream
|
||||
outRequest.Input.Temperature = inRequest.Temperature
|
||||
outRequest.Input.FrequencyPenalty = inRequest.FrequencyPenalty
|
||||
outRequest.Input.PresencePenalty = inRequest.PresencePenalty
|
||||
|
||||
// render prompt
|
||||
systemMessage := ""
|
||||
userMessage := ""
|
||||
assistantMessage := ""
|
||||
for _, message := range inRequest.Messages {
|
||||
if message.Role == "system" {
|
||||
if systemMessage != "" {
|
||||
systemMessage += "\n"
|
||||
}
|
||||
systemMessage += message.Content
|
||||
continue
|
||||
}
|
||||
if message.Role == "user" {
|
||||
if userMessage != "" {
|
||||
userMessage += "\n"
|
||||
}
|
||||
userMessage += message.Content
|
||||
if systemMessage != "" {
|
||||
userMessage = systemMessage + "\n" + userMessage
|
||||
systemMessage = ""
|
||||
}
|
||||
continue
|
||||
}
|
||||
if message.Role == "assistant" {
|
||||
if assistantMessage != "" {
|
||||
assistantMessage += "\n"
|
||||
}
|
||||
assistantMessage += message.Content
|
||||
|
||||
if outRequest.Input.Prompt != "" {
|
||||
outRequest.Input.Prompt += "\n"
|
||||
}
|
||||
if userMessage != "" {
|
||||
outRequest.Input.Prompt += fmt.Sprintf("<s> [INST] %s [/INST] %s </s>", userMessage, assistantMessage)
|
||||
} else {
|
||||
outRequest.Input.Prompt += fmt.Sprintf("<s> %s </s>", assistantMessage)
|
||||
}
|
||||
userMessage = ""
|
||||
assistantMessage = ""
|
||||
}
|
||||
// unknown role
|
||||
log.Println("[processReplicateRequest]: Warning: unknown role", message.Role)
|
||||
}
|
||||
// final user message
|
||||
if userMessage != "" {
|
||||
outRequest.Input.Prompt += fmt.Sprintf("<s> [INST] %s [/INST] ", userMessage)
|
||||
userMessage = ""
|
||||
}
|
||||
// final assistant message
|
||||
if assistantMessage != "" {
|
||||
outRequest.Input.Prompt += fmt.Sprintf("<s> %s </s>", assistantMessage)
|
||||
}
|
||||
log.Println("[processReplicateRequest]: outRequest.Input.Prompt:", outRequest.Input.Prompt)
|
||||
|
||||
// send request
|
||||
outBody, err := json.Marshal(outRequest)
|
||||
if err != nil {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||
return errors.New("[processReplicateRequest]: failed to marshal request body " + err.Error())
|
||||
}
|
||||
|
||||
// http add headers
|
||||
req, err := http.NewRequest("POST", model_url, bytes.NewBuffer(outBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Token "+upstream.SK)
|
||||
// send
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||
return errors.New("[processReplicateRequest]: failed to post request " + err.Error())
|
||||
}
|
||||
|
||||
// read response body
|
||||
outBody, err = io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||
return errors.New("[processReplicateRequest]: failed to read response body " + err.Error())
|
||||
}
|
||||
|
||||
// parse reponse body
|
||||
outResponse := &ReplicateModelResponse{}
|
||||
err = json.Unmarshal(outBody, outResponse)
|
||||
if err != nil {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||
return errors.New("[processReplicateRequest]: failed to parse response body " + err.Error())
|
||||
}
|
||||
|
||||
if outResponse.Stream {
|
||||
// get result
|
||||
log.Println("[processReplicateRequest]: outResponse.URLS.Get:", outResponse.URLS.Stream)
|
||||
req, err := http.NewRequest("GET", outResponse.URLS.Stream, nil)
|
||||
if err != nil {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||
return errors.New("[processReplicateRequest]: failed to create get request " + err.Error())
|
||||
}
|
||||
req.Header.Set("Authorization", "Token "+upstream.SK)
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
// send
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||
return errors.New("[processReplicateRequest]: failed to get request " + err.Error())
|
||||
}
|
||||
|
||||
// get result by chunk
|
||||
var buffer string = ""
|
||||
var indexCount int64 = 0
|
||||
for {
|
||||
if !strings.Contains(buffer, "\n\n") {
|
||||
// receive chunk
|
||||
chunk := make([]byte, 1024)
|
||||
length, err := resp.Body.Read(chunk)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if length == 0 {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||
return errors.New("[processReplicateRequest]: failed to read response body " + err.Error())
|
||||
}
|
||||
// add chunk to buffer
|
||||
chunk = bytes.Trim(chunk, "\x00")
|
||||
buffer += string(chunk)
|
||||
continue
|
||||
}
|
||||
|
||||
// cut the first chunk by find index
|
||||
index := strings.Index(buffer, "\n\n")
|
||||
chunk := buffer[:index]
|
||||
buffer = buffer[index+2:]
|
||||
|
||||
// trim line
|
||||
chunk = strings.Trim(chunk, "\n")
|
||||
|
||||
// ignore hi
|
||||
if !strings.Contains(chunk, "\n") {
|
||||
continue
|
||||
}
|
||||
|
||||
// parse chunk to ReplicateModelResultChunk object
|
||||
chunkObj := &ReplicateModelResultChunk{}
|
||||
lines := strings.Split(chunk, "\n")
|
||||
// first line is event
|
||||
chunkObj.Event = strings.TrimSpace(lines[0])
|
||||
chunkObj.Event = strings.TrimPrefix(chunkObj.Event, "event: ")
|
||||
// second line is id
|
||||
chunkObj.ID = strings.TrimSpace(lines[1])
|
||||
chunkObj.ID = strings.TrimPrefix(chunkObj.ID, "id: ")
|
||||
chunkObj.ID = strings.SplitN(chunkObj.ID, ":", 2)[0]
|
||||
// third line is data
|
||||
chunkObj.Data = lines[2]
|
||||
chunkObj.Data = strings.TrimPrefix(chunkObj.Data, "data: ")
|
||||
|
||||
record.Response += chunkObj.Data
|
||||
|
||||
// done
|
||||
if chunkObj.Event == "done" {
|
||||
break
|
||||
}
|
||||
|
||||
sendCORSHeaders(c)
|
||||
|
||||
// create OpenAI response chunk
|
||||
c.SSEvent("", &OpenAIChatResponseChunk{
|
||||
ID: "",
|
||||
Model: outResponse.Model,
|
||||
Choices: []OpenAIChatResponseChunkChoice{
|
||||
{
|
||||
Index: indexCount,
|
||||
Delta: OpenAIChatMessage{
|
||||
Role: "assistant",
|
||||
Content: chunkObj.Data,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
c.Writer.Flush()
|
||||
indexCount += 1
|
||||
}
|
||||
sendCORSHeaders(c)
|
||||
c.SSEvent("", &OpenAIChatResponseChunk{
|
||||
ID: "",
|
||||
Model: outResponse.Model,
|
||||
Choices: []OpenAIChatResponseChunkChoice{
|
||||
{
|
||||
Index: indexCount,
|
||||
Delta: OpenAIChatMessage{
|
||||
Role: "assistant",
|
||||
Content: "",
|
||||
},
|
||||
FinishReason: "stop",
|
||||
},
|
||||
},
|
||||
})
|
||||
c.Writer.Flush()
|
||||
indexCount += 1
|
||||
record.Status = 200
|
||||
return nil
|
||||
|
||||
} else {
|
||||
var result *ReplicateModelResultGet
|
||||
|
||||
for {
|
||||
// get result
|
||||
log.Println("[processReplicateRequest]: outResponse.URLS.Get:", outResponse.URLS.Get)
|
||||
req, err := http.NewRequest("GET", outResponse.URLS.Get, nil)
|
||||
if err != nil {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||
return errors.New("[processReplicateRequest]: failed to create get request " + err.Error())
|
||||
}
|
||||
req.Header.Set("Authorization", "Token "+upstream.SK)
|
||||
// send
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||
return errors.New("[processReplicateRequest]: failed to get request " + err.Error())
|
||||
}
|
||||
// get result
|
||||
resultBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||
return errors.New("[processReplicateRequest]: failed to read response body " + err.Error())
|
||||
}
|
||||
|
||||
// parse reponse body
|
||||
result = &ReplicateModelResultGet{}
|
||||
err = json.Unmarshal(resultBody, result)
|
||||
if err != nil {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||
return errors.New("[processReplicateRequest]: failed to parse response body " + err.Error())
|
||||
}
|
||||
|
||||
if result.Status == "processing" || result.Status == "starting" {
|
||||
time.Sleep(3 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
// build openai resposne
|
||||
openAIResult := &OpenAIChatResponse{
|
||||
ID: result.ID,
|
||||
Model: result.Model,
|
||||
Choices: []OpenAIChatResponseChoice{},
|
||||
Usage: OpenAIChatResponseUsage{
|
||||
TotalTokens: result.Metrics.InputTokenCount + result.Metrics.OutputTokenCount,
|
||||
PromptTokens: result.Metrics.InputTokenCount,
|
||||
},
|
||||
}
|
||||
openAIResult.Choices = append(openAIResult.Choices, OpenAIChatResponseChoice{
|
||||
Index: 0,
|
||||
Message: OpenAIChatMessage{
|
||||
Role: "assistant",
|
||||
Content: strings.Join(result.Output, ""),
|
||||
},
|
||||
FinishReason: "stop",
|
||||
})
|
||||
|
||||
record.Response = strings.Join(result.Output, "")
|
||||
record.Status = 200
|
||||
|
||||
// gin return
|
||||
sendCORSHeaders(c)
|
||||
c.JSON(200, openAIResult)
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
157
structure.go
157
structure.go
@@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/url"
|
||||
"os"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
@@ -9,15 +10,26 @@ import (
|
||||
|
||||
// Config is the top-level YAML configuration of the proxy service.
type Config struct {
	Address       string            `yaml:"address"`
	Hostname      string            `yaml:"hostname"`
	DBType        string            `yaml:"dbtype"`
	DBAddr        string            `yaml:"dbaddr"`         // defaults to "./db.sqlite" in readConfig
	Authorization string            `yaml:"authorization"`  // fallback auth for upstreams without their own
	Timeout       int64             `yaml:"timeout"`        // defaults to 120 in readConfig — presumably seconds, confirm with usage
	StreamTimeout int64             `yaml:"stream_timeout"` // defaults to 10 in readConfig
	Upstreams     []OPENAI_UPSTREAM `yaml:"upstreams"`
}
|
||||
// OPENAI_UPSTREAM describes one upstream backend the proxy can route to.
// Zero Timeout/StreamTimeout/Authorization values are filled from the global
// Config in readConfig.
type OPENAI_UPSTREAM struct {
	SK            string   `yaml:"sk"` // upstream secret key / API token
	Endpoint      string   `yaml:"endpoint"`
	Timeout       int64    `yaml:"timeout"`
	StreamTimeout int64    `yaml:"stream_timeout"`
	Allow         []string `yaml:"allow"` // model allow-list; empty means every model is allowed
	Deny          []string `yaml:"deny"`  // model deny-list
	Type          string   `yaml:"type"`  // "openai" (default) or "replicate"
	KeepHeader    bool     `yaml:"keep_header"`
	Authorization string   `yaml:"authorization"`
	Noauth        bool     `yaml:"noauth"` // when set, skip inheriting the global Authorization
	URL           *url.URL // parsed from Endpoint in readConfig; not read from YAML
}
|
||||
|
||||
func readConfig(filepath string) Config {
|
||||
@@ -48,6 +60,145 @@ func readConfig(filepath string) Config {
|
||||
log.Println("DBAddr not set, use default value: ./db.sqlite")
|
||||
config.DBAddr = "./db.sqlite"
|
||||
}
|
||||
if config.Timeout == 0 {
|
||||
log.Println("Timeout not set, use default value: 120")
|
||||
config.Timeout = 120
|
||||
}
|
||||
if config.StreamTimeout == 0 {
|
||||
log.Println("StreamTimeout not set, use default value: 10")
|
||||
config.StreamTimeout = 10
|
||||
}
|
||||
|
||||
for i, upstream := range config.Upstreams {
|
||||
// parse upstream endpoint URL
|
||||
endpoint, err := url.Parse(upstream.Endpoint)
|
||||
if err != nil {
|
||||
log.Fatalf("Can't parse upstream endpoint URL '%s': %s", upstream.Endpoint, err)
|
||||
}
|
||||
config.Upstreams[i].URL = endpoint
|
||||
if config.Upstreams[i].Type == "" {
|
||||
config.Upstreams[i].Type = "openai"
|
||||
}
|
||||
if (config.Upstreams[i].Type != "openai") && (config.Upstreams[i].Type != "replicate") {
|
||||
log.Fatalf("Unsupported upstream type '%s'", config.Upstreams[i].Type)
|
||||
}
|
||||
// apply authorization from global config if not set
|
||||
if config.Upstreams[i].Authorization == "" && !config.Upstreams[i].Noauth {
|
||||
config.Upstreams[i].Authorization = config.Authorization
|
||||
}
|
||||
if config.Upstreams[i].Timeout == 0 {
|
||||
config.Upstreams[i].Timeout = config.Timeout
|
||||
}
|
||||
if config.Upstreams[i].StreamTimeout == 0 {
|
||||
config.Upstreams[i].StreamTimeout = config.StreamTimeout
|
||||
}
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
type OpenAIChatRequest struct {
|
||||
FrequencyPenalty float64 `json:"frequency_penalty"`
|
||||
PresencePenalty float64 `json:"presence_penalty"`
|
||||
MaxTokens int64 `json:"max_tokens"`
|
||||
Model string `json:"model"`
|
||||
Stream bool `json:"stream"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
Messages []OpenAIChatRequestMessage
|
||||
}
|
||||
|
||||
type OpenAIChatRequestMessage struct {
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
// ReplicateModelRequest is the JSON body POSTed to Replicate's predictions
// endpoint.
type ReplicateModelRequest struct {
	Stream bool                       `json:"stream"` // request an SSE stream instead of poll-for-result
	Input  ReplicateModelRequestInput `json:"input"`
}

// ReplicateModelRequestInput carries the model inputs for one prediction.
// Defaults (1024 tokens, temperature 0.6, top_p 0.9, top_k 50, "{prompt}"
// template) are filled in by _processReplicateRequest before sending.
type ReplicateModelRequestInput struct {
	Prompt           string  `json:"prompt"`
	MaxNewTokens     int64   `json:"max_new_tokens"`
	Temperature      float64 `json:"temperature"`
	Top_p            float64 `json:"top_p"`
	Top_k            int64   `json:"top_k"`
	PresencePenalty  float64 `json:"presence_penalty"`
	FrequencyPenalty float64 `json:"frequency_penalty"`
	PromptTemplate   string  `json:"prompt_template"`
}
|
||||
|
||||
// ReplicateModelResponse is the immediate reply to creating a prediction;
// URLS carries the follow-up endpoints used to stream or poll the result.
type ReplicateModelResponse struct {
	Model   string                     `json:"model"`
	Version string                     `json:"version"`
	Stream  bool                       `json:"stream"`
	Error   string                     `json:"error"`
	URLS    ReplicateModelResponseURLS `json:"urls"`
}

// ReplicateModelResponseURLS groups the follow-up URLs of a prediction.
type ReplicateModelResponseURLS struct {
	Cancel string `json:"cancel"`
	Get    string `json:"get"`    // polled for the finished result
	Stream string `json:"stream"` // SSE stream of incremental output
}
|
||||
|
||||
// ReplicateModelResultGet is the polled state of a prediction. Status is
// "processing"/"starting" while running (the poll loop sleeps and retries on
// those values); Output holds the generated text pieces, joined by the caller.
type ReplicateModelResultGet struct {
	ID      string                      `json:"id"`
	Model   string                      `json:"model"`
	Version string                      `json:"version"`
	Output  []string                    `json:"output"`
	Error   string                      `json:"error"`
	Metrics ReplicateModelResultMetrics `json:"metrics"`
	Status  string                      `json:"status"`
}

// ReplicateModelResultMetrics reports token accounting for a finished
// prediction; used to populate the OpenAI usage block.
type ReplicateModelResultMetrics struct {
	InputTokenCount  int64 `json:"input_token_count"`
	OutputTokenCount int64 `json:"output_token_count"`
}
|
||||
|
||||
// OpenAIChatResponse is the non-streaming chat-completion response body
// returned to the client in OpenAI's wire format.
type OpenAIChatResponse struct {
	ID      string                     `json:"id"`
	Object  string                     `json:"object"`
	Created int64                      `json:"created"`
	Model   string                     `json:"model"`
	Choices []OpenAIChatResponseChoice `json:"choices"`
	Usage   OpenAIChatResponseUsage    `json:"usage"`
}

// OpenAIChatResponseUsage mirrors OpenAI's token-usage accounting.
type OpenAIChatResponseUsage struct {
	PromptTokens     int64 `json:"prompt_tokens"`
	CompletionTokens int64 `json:"completion_tokens"`
	TotalTokens      int64 `json:"total_tokens"`
}

// OpenAIChatResponseChoice is one completed answer within a response.
type OpenAIChatResponseChoice struct {
	Index        int64             `json:"index"`
	Message      OpenAIChatMessage `json:"message"`
	FinishReason string            `json:"finish_reason"`
}

// OpenAIChatMessage is a single chat message; also reused as the Delta
// payload in streaming chunks.
type OpenAIChatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
|
||||
|
||||
// ReplicateModelResultChunk is one parsed SSE event from Replicate's stream:
// the "event:", "id:" and "data:" lines of a single event block.
type ReplicateModelResultChunk struct {
	Event string `json:"event"` // e.g. "output"; "done" terminates the stream
	ID    string `json:"id"`
	Data  string `json:"data"` // incremental output text
}
|
||||
|
||||
// OpenAIChatResponseChunk is one streaming (SSE) chat-completion chunk in
// OpenAI's wire format, emitted once per upstream stream event.
type OpenAIChatResponseChunk struct {
	ID      string                          `json:"id"`
	Object  string                          `json:"object"`
	Created int64                           `json:"created"`
	Model   string                          `json:"model"`
	Choices []OpenAIChatResponseChunkChoice `json:"choices"`
}

// OpenAIChatResponseChunkChoice carries the incremental Delta of a chunk;
// FinishReason is "stop" on the final chunk, empty otherwise.
type OpenAIChatResponseChunkChoice struct {
	Index        int64             `json:"index"`
	Delta        OpenAIChatMessage `json:"delta"`
	FinishReason string            `json:"finish_reason"`
}
|
||||
|
||||
Reference in New Issue
Block a user