46 Commits

Author SHA1 Message Date
495f32610b refactor: pipe the read and write process
This refactor simplifies the process logic and fixes several bugs and
performance issues.

Bugs fixed:
- CORS headers not being sent in some situations
Performance:
- the upstream request now starts while the client is still uploading content
2024-05-27 14:47:00 +08:00
45bba95f5d Merge remote-tracking branch 'comp/master' 2024-04-09 19:07:45 +08:00
75ff8fbc2e Refactor CORS handling and remove response's CORS headers 2024-04-09 19:07:28 +08:00
66758e0008 Update .gitlab-ci.yml file 2024-04-08 10:29:08 +00:00
40fc2067a5 Merge remote-tracking branch 'comp/master' 2024-04-08 18:22:10 +08:00
1a56101ca8 add dockerignore 2024-04-08 18:21:55 +08:00
e373e3ac63 Update .gitlab-ci.yml file 2024-04-08 07:48:05 +00:00
24a2e609f8 Update .gitlab-ci.yml file 2024-04-08 05:11:47 +00:00
e442303847 Update .gitlab-ci.yml file 2024-04-08 03:52:19 +00:00
8b95fbb5da Update .gitlab-ci.yml file 2024-04-08 03:40:32 +00:00
34aa4babc4 Update .gitlab-ci.yml file 2024-04-08 03:21:24 +00:00
e6ff1f5ca4 Update .gitlab-ci.yml file 2024-04-08 03:10:24 +00:00
6b6f245e45 Update README.md with complex configuration example 2024-04-08 11:04:29 +08:00
995eea9d67 Update README.md 2024-02-18 17:28:16 +08:00
db7f0eb316 timeout 2024-02-18 16:45:37 +08:00
990628b455 bro gooooo 2024-02-17 00:30:27 +08:00
e8b89fc41a record all json request body 2024-02-16 22:42:48 +08:00
46ee30ced7 use path as default model name 2024-02-16 22:31:32 +08:00
f2e32340e3 fix typo 2024-02-16 17:47:40 +08:00
ca386f8302 record whisper response 2024-02-15 16:52:33 +08:00
3385f9af08 fix: replicate mistral prompt 2024-01-23 16:38:45 +08:00
8fa7fa79be less: replicate log 2024-01-23 15:56:48 +08:00
49169452fe fix: read config upstream type default value 2024-01-23 15:52:46 +08:00
33f341026f fix: replicate response format 2024-01-23 15:43:48 +08:00
b1a9d6b685 add: support replicate 2024-01-23 15:20:22 +08:00
1fc17daa35 deny allow list if unknown model 2024-01-21 15:37:26 +08:00
b1ab9c3d7b add allow & deny model list, fix cors on error 2024-01-16 17:11:25 +08:00
8672899a58 change log format 2024-01-16 16:23:27 +08:00
3a59433f66 append error 2024-01-16 12:03:23 +08:00
873548a7d0 fix: body larger than 1024*128 2024-01-16 11:32:06 +08:00
2a2d907b0d fix: podman image build 2024-01-11 16:43:40 +08:00
2bbe98e694 fix duplicated cors headers 2024-01-04 19:03:58 +08:00
9fdbf259c0 truncate request body if too long 2024-01-04 18:41:26 +08:00
97926087bb async record request 2024-01-04 18:37:01 +08:00
fc5a8d55fa fix: process error content type 2023-12-22 15:08:24 +08:00
b1e3a97aad fix: cors and content-type on error 2023-12-22 14:24:11 +08:00
04a2e4c12d use cors middleware 2023-12-22 13:23:41 +08:00
b8ebbed5d6 fix: recognize '/v1' prefix 2023-12-07 00:51:21 +08:00
412aefdacc fix: record resp time 2023-11-30 10:24:00 +08:00
2c3532f12f add hostname 2023-11-30 10:18:01 +08:00
7d93332e51 fix: record body json 2023-11-30 10:06:10 +08:00
0dbd898532 update config yaml 2023-11-29 14:40:20 +08:00
7b74818676 update config to file 2023-11-29 14:31:11 +08:00
0a2a7376f1 update dep 2023-11-28 14:36:14 +08:00
44a966e6f4 clean code 2023-11-28 14:04:42 +08:00
87244d4dc2 update docs 2023-11-28 10:42:06 +08:00
15 changed files with 1042 additions and 383 deletions

View File

@@ -1,3 +1,4 @@
openai-api-route
db.sqlite
/config.yaml
/.*

.gitlab-ci.yml (new file, 52 lines)
View File

@@ -0,0 +1,52 @@
# To contribute improvements to CI/CD templates, please follow the Development guide at:
# https://docs.gitlab.com/ee/development/cicd/templates.html
# This specific template is located at:
# https://gitlab.com/gitlab-org/gitlab/-/blob/master/lib/gitlab/ci/templates/Docker.gitlab-ci.yml
# Build a Docker image with CI/CD and push to the GitLab registry.
# Docker-in-Docker documentation: https://docs.gitlab.com/ee/ci/docker/using_docker_build.html
#
# This template uses one generic job with conditional builds
# for the default branch and all other (MR) branches.
docker-build:
  # Use the official docker image.
  image: docker:cli
  stage: build
  services:
    - docker:dind
  variables:
    CI_REGISTRY: registry.waykey.net:7999
    CI_REGISTRY_IMAGE: $CI_REGISTRY/spiderman/datamining/openai-api-route
    DOCKER_IMAGE_NAME: $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_SLUG
  before_script:
    - docker login -u "$CI_REGISTRY_USER" -p "$CI_REGISTRY_PASSWORD" $CI_REGISTRY
  # All branches are tagged with $DOCKER_IMAGE_NAME (defaults to commit ref slug)
  # Default branch is also tagged with `latest`
  script:
    - docker build --pull -t "$DOCKER_IMAGE_NAME" .
    - docker push "$DOCKER_IMAGE_NAME"
    - |
      if [[ "$CI_COMMIT_BRANCH" == "$CI_DEFAULT_BRANCH" ]]; then
        docker tag "$DOCKER_IMAGE_NAME" "$CI_REGISTRY_IMAGE:latest"
        docker push "$CI_REGISTRY_IMAGE:latest"
      fi
  # Run this job in a branch where a Dockerfile exists
  rules:
    - if: $CI_COMMIT_BRANCH
      exists:
        - Dockerfile

deploy:
  environment: production
  image: kroniak/ssh-client
  stage: deploy
  before_script:
    - chmod 600 $CI_SSH_PRIVATE_KEY
  script:
    - ssh -o StrictHostKeyChecking=no -i $CI_SSH_PRIVATE_KEY root@192.168.1.13 "cd /mnt/data/srv/openai-api-route && podman-compose pull && podman-compose down && podman-compose up -d"
  rules:
    - if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH
      exists:
        - Dockerfile

View File

@@ -1,4 +1,4 @@
FROM golang:1.21 as builder
FROM docker.io/golang:1.21 as builder
WORKDIR /app
@@ -14,4 +14,4 @@ FROM alpine
COPY --from=builder /app/openai-api-route /openai-api-route
ENTRYPOINT ["/openai-api-route"]
ENTRYPOINT ["/openai-api-route"]

README.md (187 changed lines)
View File

@@ -7,15 +7,134 @@
 - Custom Authorization header validation
 - Supports all endpoint types (`/v1/*`)
 - Prometheus Metrics endpoint (`/v1/metrics`)
-- Tries the OpenAI upstreams in the order they are defined
-- Detects ChatCompletions stream requests and uses a 5-second timeout for them; other requests use a 60-second timeout
-- Records the full request content, the upstream used, the IP address, the response time, and the GPT reply text
-- Sends a Feishu or Matrix message notification when a request fails
+- Tries the OpenAI upstreams in the order they are defined, automatically moving to the next one on error or timeout
+- Detects ChatCompletions stream requests and applies a shorter timeout to them; see the [Timeout policy](#timeout-policy) section for details
+- Selectively records the request content, request headers, upstream used, IP address, response time, and response; see the [Recording policy](#recording-policy) section for details
+- Sends a message notification via the Feishu or Matrix platform when a request fails
+- Supports mistral models on the Replicate platform (beta)
This document describes the methods and endpoints of the load-balancing API.
## Configuration file
By default the program uses the `config.yaml` file in the current directory; you can point it at a different file with the `-config your-config.yaml` flag.
Below is a sample configuration file; the same content is available in `config.sample.yaml`:
```yaml
authorization: woshimima
# default timeout, 120 seconds; stream requests default to 10 seconds
timeout: 120
stream_timeout: 10
# use sqlite to store request records
dbtype: sqlite
dbaddr: ./db.sqlite
# or use postgres to store request records
# dbtype: postgres
# dbaddr: "host=127.0.0.1 port=5432 user=postgres dbname=openai_api_route sslmode=disable password=woshimima"
upstreams:
  - sk: hahaha
    endpoint: "https://localhost:8888/v1"
    allow:
      # non-JSON APIs such as whisper carry no model field; the URL path is used as the model name instead
      - /v1/audio/transcriptions
  - sk: "secret_key_1"
    endpoint: "https://api.openai.com/v2"
    timeout: 120 # request timeout, 120 seconds by default
    stream_timeout: 10 # used instead when the body contains stream: true
    allow: # optional model allowlist
      - gpt-3.5-turbo
      - gpt-3.5-turbo-0613
  # you can define many upstreams; the program tries them in order
  - sk: "secret_key_2"
    endpoint: "https://api.openai.com/v1"
    timeout: 30
    deny:
      - gpt-4
  - sk: "key_for_replicate"
    type: replicate
    allow:
      - mistralai/mixtral-8x7b-instruct-v0.1
```
### Configuring multiple authorization values
You can separate multiple authorization values with commas (`,`). Every value is valid, and the program records which value each request used:
```yaml
authorization: woshimima,iampassword
```
You can also set an authorization value for an individual upstream:
```yaml
authorization: woshimima,iampassword
upstreams:
  - sk: key
    authorization: woshimima
```
With this configuration, only users presenting the `woshimima` value can use that upstream.
### Complex configuration example
```yaml
# default authorization value
authorization: woshimima
upstreams:
  # speech-to-text, open to everyone
  - sk: xxx
    endpoint: http://localhost:5000/v1
    noauth: true
    allow:
      - /v1/audio/transcriptions
  # gpt-3.5-turbo-0125, reserved for guest
  - sk:
    endpoint: https://api.xxx.local/v1
    authorization: guest
    allow:
      - gpt-3.5-turbo-0125
```
## Deployment
There are two recommended ways to deploy:
1. Use the pre-built container `docker.io/heimoshuiyu/openai-api-route:latest`
2. Build from source
### Running with a container
> Note: if you use the sqlite database, you may also need to adjust the configuration so the SQLite database file lives on a mounted volume.
```bash
docker run -d --name openai-api-route -v /path/to/config.yaml:/config.yaml docker.io/heimoshuiyu/openai-api-route:latest
```
Or with Docker Compose:
```yaml
version: '3'
services:
  openai-api-route:
    image: docker.io/heimoshuiyu/openai-api-route:latest
    ports:
      - 8888:8888
    volumes:
      - ./config.yaml:/config.yaml
```
### Building from source
The steps to build and run the load-balancing API are:
@@ -40,62 +159,28 @@
 ./openai-api-route
 ```
 By default, the API listens locally on port 8888.
+## Model allowlists and denylists
-If you want to listen on a different address, use the `-addr` flag, for example:
+If an upstream defines an allow or deny list, the load balancer only lets users use (or blocks them from using) those models. The allowlist is checked first, then the denylist.
-```
-./openai-api-route -addr 0.0.0.0:8080
-```
+If you mix models from the OpenAI and Replicate platforms, you will probably want separate allowlists for the OpenAI and Replicate upstreams; otherwise a request for an OpenAI model may be sent to the Replicate platform.
-This sets the listening address to 0.0.0.0:8080.
+## Timeout policy
-6. If the database does not exist, a database file named `db.sqlite` is created automatically.
+When handling upstream requests, the timeout policy is key to keeping the service stable and responsive. The `Upstreams` section of the configuration file defines multiple upstream servers, each with its own `Endpoint` and `SK` (a secret key or special identifier). The service tries each upstream in the configured order until a request succeeds or every upstream has been tried.
-If you want to use a different database address, use the `-database` flag, for example:
+### Single upstream
-```
-./openai-api-route -database /path/to/database.db
-```
+When only one upstream server is defined, its timeout is set to 120 seconds. If the upstream does not respond within 120 seconds, the service aborts the request and may return an error.
-This sets the database address to `/path/to/database.db`.
+### Multiple upstreams
-7. You have now built and are running the load-balancing API. You can add upstreams, manage them, and use the API as needed.
+If multiple upstreams are defined, the service tries each one in the defined order. For each upstream it checks that `Endpoint` and `SK` are valid; if either field is empty, the service returns a 500 error and logs the invalid upstream.
-### Running
+### Timeout details
-Here is how to invoke the program:
+The service applies different timeouts under different conditions. The timeout is the maximum time the service waits for a response from the upstream. The rules are:
-```
-Usage of ./openai-api-route:
-  -addr string
-        listening address (default ":8888")
-  -upstreams string
-        upstream configuration file (default "./upstreams.yaml")
-  -dbtype
-        database type (sqlite or postgres, default sqlite)
-  -database string
-        database address (default "./db.sqlite")
-        if the database is postgres, this should be a PostgreSQL DSN,
-        e.g. "host=127.0.0.1 port=5432 user=postgres dbname=openai_api_route sslmode=disable password=woshimima"
-  -list
-        list all upstreams
-  -noauth
-        do not check the incoming authorization header
-```
+1. **Default timeout**: absent any special condition, the service uses the default timeout of 60 seconds.
-Below is a sample `./upstreams.yaml` configuration:
-```yaml
-authorization: woshimima
-upstreams:
-  - sk: "secret_key_1"
-    endpoint: "https://api.openai.com/v2"
-  - sk: "secret_key_2"
-    endpoint: "https://api.openai.com/v1"
-    timeout: 30
-```
-Note that the program may adjust the timeout value depending on the situation.
-You can simply run `./openai-api-route`; if the database does not exist it is created automatically.
+2. **Stream requests**: if the request body is recognized as streaming (`requestBody.Stream` is `true`) and the body check (`requestBodyOK`) passed, the timeout is set to 5 seconds. This suits stream requests that are expected to respond quickly.
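The timeout selection, as implemented in the proxy code later in this diff, can be sketched like this (a minimal illustration; `pickTimeout` is a made-up name, and the 20-second rule for large bodies comes from the proxy source rather than this README):
```go
package main

import (
	"fmt"
	"time"
)

func pickTimeout(isStream bool, bodyLen int, upstreamTimeout int64) time.Duration {
	timeout := 60 * time.Second // default
	if isStream {
		timeout = 5 * time.Second // stream requests expect a fast first byte
	}
	if bodyLen > 1024*128 {
		timeout = 20 * time.Second // very large uploads get extra headroom
	}
	if upstreamTimeout > 0 {
		timeout = time.Duration(upstreamTimeout) * time.Second // explicit config wins
	}
	return timeout
}

func main() {
	fmt.Println(pickTimeout(true, 512, 0))   // 5s
	fmt.Println(pickTimeout(false, 512, 30)) // 30s
}
```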

auth.go (28 changed lines)
View File

@@ -2,32 +2,14 @@ package main
import (
"errors"
"log"
"strings"
"github.com/gin-gonic/gin"
)
func handleAuth(c *gin.Context) error {
var err error
authorization := c.Request.Header.Get("Authorization")
if !strings.HasPrefix(authorization, "Bearer") {
err = errors.New("authorization header should start with 'Bearer'")
c.AbortWithError(403, err)
return err
}
authorization = strings.Trim(authorization[len("Bearer"):], " ")
log.Println("Received authorization", authorization)
for _, auth := range strings.Split(config.Authorization, ",") {
if authorization != strings.Trim(auth, " ") {
err = errors.New("wrong authorization header")
c.AbortWithError(403, err)
return err
func checkAuth(authorization string, config string) error {
for _, auth := range strings.Split(config, ",") {
if authorization == strings.Trim(auth, " ") {
return nil
}
}
return nil
return errors.New("wrong authorization header")
}
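A quick, self-contained usage sketch of the new `checkAuth`: a client token is accepted when it equals any comma-separated entry in the configured value, with surrounding spaces trimmed.
```go
package main

import (
	"errors"
	"fmt"
	"strings"
)

// Copy of checkAuth from the diff above.
func checkAuth(authorization string, config string) error {
	for _, auth := range strings.Split(config, ",") {
		if authorization == strings.Trim(auth, " ") {
			return nil
		}
	}
	return errors.New("wrong authorization header")
}

func main() {
	fmt.Println(checkAuth("iampassword", "woshimima, iampassword")) // <nil>
	fmt.Println(checkAuth("guest", "woshimima,iampassword"))        // wrong authorization header
}
```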

config.sample.yaml (new file, 23 lines)
View File

@@ -0,0 +1,23 @@
authorization: woshimima
# use sqlite to store request records
dbtype: sqlite
dbaddr: ./db.sqlite
# or use postgres to store request records
# dbtype: postgres
# dbaddr: "host=127.0.0.1 port=5432 user=postgres dbname=openai_api_route sslmode=disable password=woshimima"
upstreams:
  - sk: "secret_key_1"
    endpoint: "https://api.openai.com/v2"
    allow: ["gpt-3.5-turbo"] # optional model allowlist
  - sk: "secret_key_2"
    endpoint: "https://api.openai.com/v1"
    timeout: 30
    allow: ["gpt-3.5-turbo"] # optional model allowlist
    deny: ["gpt-4"] # optional model denylist
    # when both lists are set, the allowlist is checked first, then the denylist
  - sk: "key_for_replicate"
    type: replicate
    allow: ["mistralai/mixtral-8x7b-instruct-v0.1"]

cors.go (new file, 20 lines)
View File

@@ -0,0 +1,20 @@
package main
import (
"log"
"github.com/gin-gonic/gin"
)
func sendCORSHeaders(c *gin.Context) {
log.Println("sendCORSHeaders")
if c.Writer.Header().Get("Access-Control-Allow-Origin") == "" {
c.Header("Access-Control-Allow-Origin", "*")
}
if c.Writer.Header().Get("Access-Control-Allow-Methods") == "" {
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
}
if c.Writer.Header().Get("Access-Control-Allow-Headers") == "" {
c.Header("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
}
}

go.mod (1 changed line)
View File

@@ -4,7 +4,6 @@ go 1.20
require (
github.com/gin-gonic/gin v1.9.1
github.com/google/uuid v1.3.0
github.com/penglongli/gin-metrics v0.1.10
golang.org/x/net v0.10.0
gopkg.in/yaml.v3 v3.0.1

go.sum (2 changed lines)
View File

@@ -148,8 +148,6 @@ github.com/google/pprof v0.0.0-20200229191704-1ebb73c60ed3/go.mod h1:ZgVRPoUq/hf
github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=

main.go (190 changed lines)
View File

@@ -1,15 +1,21 @@
package main
import (
"bytes"
"encoding/json"
"flag"
"fmt"
"io"
"log"
"net/http"
"os"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/penglongli/gin-metrics/ginmetrics"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
@@ -18,40 +24,43 @@ import (
var config Config
func main() {
dbType := flag.String("dbtype", "sqlite", "Database type (sqlite or postgres)")
dbAddr := flag.String("database", "./db.sqlite", "Database address, if dbType is postgres, this is the DSN connection string")
configFile := flag.String("config", "./config.yaml", "Config file")
listenAddr := flag.String("addr", ":8888", "Listening address")
listMode := flag.Bool("list", false, "List all upstream")
noauth := flag.Bool("noauth", false, "Do not check incoming authorization header")
flag.Parse()
log.Println("Service starting")
log.Println("[main]: Service starting")
// load all upstreams
config = readConfig(*configFile)
log.Println("[main]: Load upstreams number:", len(config.Upstreams))
// connect to database
var db *gorm.DB
var err error
switch *dbType {
switch config.DBType {
case "sqlite":
db, err = gorm.Open(sqlite.Open(*dbAddr), &gorm.Config{
db, err = gorm.Open(sqlite.Open(config.DBAddr), &gorm.Config{
PrepareStmt: true,
SkipDefaultTransaction: true,
})
if err != nil {
log.Fatalf("[main]: Error to connect sqlite database: %s", err)
}
case "postgres":
db, err = gorm.Open(postgres.Open(*dbAddr), &gorm.Config{
db, err = gorm.Open(postgres.Open(config.DBAddr), &gorm.Config{
PrepareStmt: true,
SkipDefaultTransaction: true,
})
if err != nil {
log.Fatalf("[main]: Error to connect postgres database: %s", err)
}
default:
log.Fatalf("Unsupported database type: %s", *dbType)
log.Fatalf("[main]: Unsupported database type: '%s'", config.DBType)
}
// load all upstreams
config = readConfig(*configFile)
log.Println("Load upstreams number:", len(config.Upstreams))
db.AutoMigrate(&Record{})
log.Println("Auto migrate database done")
log.Println("[main]: Auto migrate database done")
if *listMode {
fmt.Println("SK\tEndpoint")
@@ -69,6 +78,9 @@ func main() {
m.SetMetricPath("/v1/metrics")
m.Use(engine)
// CORS middleware
// engine.Use(corsMiddleware())
// error handle middleware
engine.Use(func(c *gin.Context) {
c.Next()
@@ -77,72 +89,158 @@ func main() {
}
errText := strings.Join(c.Errors.Errors(), "\n")
c.JSON(-1, gin.H{
"error": errText,
"openai-api-route error": errText,
})
})
// CORS handler
engine.OPTIONS("/v1/*any", func(ctx *gin.Context) {
header := ctx.Writer.Header()
header.Set("Access-Control-Allow-Origin", "*")
header.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
header.Set("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
// set CORS headers
ctx.Header("Access-Control-Allow-Origin", "*")
ctx.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
ctx.Header("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
ctx.AbortWithStatus(200)
})
engine.POST("/v1/*any", func(c *gin.Context) {
var err error
hostname, _ := os.Hostname()
if config.Hostname != "" {
hostname = config.Hostname
}
record := Record{
IP: c.ClientIP(),
Hostname: hostname,
CreatedAt: time.Now(),
Authorization: c.Request.Header.Get("Authorization"),
UserAgent: c.Request.Header.Get("User-Agent"),
}
defer func() {
if err := recover(); err != nil {
log.Println("Error:", err)
c.AbortWithError(500, fmt.Errorf("%s", err))
}
}()
// check authorization header
if !*noauth {
if handleAuth(c) != nil {
return
}
Model: c.Request.URL.Path,
}
for index, upstream := range config.Upstreams {
if upstream.Endpoint == "" || upstream.SK == "" {
c.AbortWithError(500, fmt.Errorf("invaild upstream '%s' '%s'", upstream.SK, upstream.Endpoint))
authorization := c.Request.Header.Get("Authorization")
if strings.HasPrefix(authorization, "Bearer") {
authorization = strings.Trim(authorization[len("Bearer"):], " ")
} else {
authorization = strings.Trim(authorization, " ")
log.Println("[auth] Warning: authorization header should start with 'Bearer'")
}
log.Println("Received authorization '" + authorization + "'")
availUpstreams := make([]OPENAI_UPSTREAM, 0)
for _, upstream := range config.Upstreams {
if upstream.SK == "" {
sendCORSHeaders(c)
c.AbortWithError(500, fmt.Errorf("[processRequest.begin]: invaild SK (secret key) %s", upstream.SK))
continue
}
if !*noauth && !upstream.Noauth {
if checkAuth(authorization, upstream.Authorization) != nil {
continue
}
}
availUpstreams = append(availUpstreams, upstream)
}
if len(availUpstreams) == 0 {
sendCORSHeaders(c)
c.AbortWithError(500, fmt.Errorf("[processRequest.begin]: no available upstream for your token"))
}
log.Println("[processRequest.begin]: availUpstreams", len(availUpstreams))
bufIO := bytes.NewBuffer(make([]byte, 0, 1024))
wrappedBody := false
for index, _upstream := range availUpstreams {
// copy
upstream := _upstream
record.UpstreamEndpoint = upstream.Endpoint
record.UpstreamSK = upstream.SK
shouldResponse := index == len(config.Upstreams)-1
if len(config.Upstreams) == 1 {
if len(availUpstreams) == 1 {
// [todo] copy problem
upstream.Timeout = 120
}
err = processRequest(c, &upstream, &record, shouldResponse)
if err != nil {
log.Println("Error from upstream", upstream.Endpoint, "should retry", err)
continue
// buffer for incoming request
if !wrappedBody {
log.Println("[processRequest.begin]: wrap request body")
c.Request.Body = io.NopCloser(io.TeeReader(c.Request.Body, bufIO))
wrappedBody = true
} else {
log.Println("[processRequest.begin]: reuse request body")
c.Request.Body = io.NopCloser(bytes.NewReader(bufIO.Bytes()))
}
break
if upstream.Type == "replicate" {
err = processReplicateRequest(c, &upstream, &record, shouldResponse)
} else if upstream.Type == "openai" {
err = processRequest(c, &upstream, &record, shouldResponse)
} else {
err = fmt.Errorf("[processRequest.begin]: unsupported upstream type '%s'", upstream.Type)
}
if err == nil {
log.Println("[processRequest.done]: Success from upstream", upstream.Endpoint)
break
}
if err == http.ErrAbortHandler {
abortErr := "[processRequest.done]: AbortHandler, client's connection lost?, no upstream will try, stop here"
log.Println(abortErr)
record.Response += abortErr
record.Status = 500
break
}
log.Println("[processRequest.done]: Error from upstream", upstream.Endpoint, "should retry", err, "should response:", shouldResponse)
// error process, break
if shouldResponse {
c.Header("Content-Type", "application/json")
sendCORSHeaders(c)
c.AbortWithError(500, err)
}
}
log.Println("Record result:", record.Status, record.Response)
record.ElapsedTime = time.Now().Sub(record.CreatedAt)
if db.Create(&record).Error != nil {
log.Println("Error to save record:", record)
// parse and record request body
requestBodyBytes := bufIO.Bytes()
if len(requestBodyBytes) < 1024*1024 && (strings.HasPrefix(c.Request.Header.Get("Content-Type"), "application/json") ||
strings.HasPrefix(c.Request.Header.Get("Content-Type"), "text/")) {
record.Body = string(requestBodyBytes)
}
requestBody, err := ParseRequestBody(requestBodyBytes)
if err != nil {
log.Println("[processRequest.done]: Error to parse request body:", err)
} else {
record.Model = requestBody.Model
}
log.Println("[final]: Record result:", record.Status, record.Response)
record.ElapsedTime = time.Since(record.CreatedAt)
// async record request
go func() {
// encoder headers to record.Headers in json string
headers, _ := json.Marshal(c.Request.Header)
record.Headers = string(headers)
// truncate request if too long
log.Println("[async.record]: body length:", len(record.Body))
if db.Create(&record).Error != nil {
log.Println("[async.record]: Error to save record:", record)
}
}()
if record.Status != 200 {
errMessage := fmt.Sprintf("IP: %s request %s error %d with %s", record.IP, record.Model, record.Status, record.Response)
errMessage := fmt.Sprintf("[result.error]: IP: %s request %s error %d with %s", record.IP, record.Model, record.Status, record.Response)
go sendFeishuMessage(errMessage)
go sendMatrixMessage(errMessage)
}
})
engine.Run(*listenAddr)
engine.Run(config.Address)
}
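The retry loop above tees the client body into a buffer on the first attempt and replays the buffered bytes on later attempts, which is what lets the upstream request start while the client is still uploading. A minimal, self-contained sketch of that trick:
```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"strings"
)

func main() {
	// stand-in for the incoming client body
	client := io.NopCloser(strings.NewReader(`{"model":"gpt-3.5-turbo"}`))
	var buf bytes.Buffer

	// attempt 1: the upstream reads the client body directly, and every
	// byte that passes through is captured in buf as a side effect
	first := io.NopCloser(io.TeeReader(client, &buf))
	b1, _ := io.ReadAll(first)

	// attempt 2 (retry): replay the captured bytes without touching the client
	retry := io.NopCloser(bytes.NewReader(buf.Bytes()))
	b2, _ := io.ReadAll(retry)

	fmt.Println(bytes.Equal(b1, b2)) // true
}
```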

View File

@@ -8,216 +8,118 @@ import (
"io"
"log"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/net/context"
)
func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
var errCtx error
record.UpstreamEndpoint = upstream.Endpoint
record.UpstreamSK = upstream.SK
record.Response = ""
// [TODO] record request body
// reverse proxy
remote, err := url.Parse(upstream.Endpoint)
if err != nil {
c.AbortWithError(500, errors.New("can't parse reverse proxy remote URL"))
return err
}
haveResponse := false
path := strings.TrimPrefix(c.Request.URL.Path, "/v1")
remote.Path = upstream.URL.Path + path
log.Println("[proxy.begin]:", remote)
log.Println("[proxy.begin]: shouldResposne:", shouldResponse)
proxy := httputil.NewSingleHostReverseProxy(remote)
proxy.Director = nil
var inBody []byte
proxy.Rewrite = func(proxyRequest *httputil.ProxyRequest) {
in := proxyRequest.In
client := &http.Client{}
request := &http.Request{}
request.ContentLength = c.Request.ContentLength
request.Method = c.Request.Method
request.URL = remote
ctx, cancel := context.WithCancel(context.Background())
proxyRequest.Out = proxyRequest.Out.WithContext(ctx)
out := proxyRequest.Out
// read request body
inBody, err = io.ReadAll(in.Body)
if err != nil {
c.AbortWithError(502, errors.New("reverse proxy middleware failed to read request body "+err.Error()))
return
}
// record chat message from user
record.Body = string(inBody)
requestBody, requestBodyOK := ParseRequestBody(inBody)
// record if parse success
if requestBodyOK == nil {
record.Model = requestBody.Model
}
// set timeout, default is 60 second
timeout := 60 * time.Second
if requestBodyOK == nil && requestBody.Stream {
timeout = 5 * time.Second
}
if len(inBody) > 1024*128 {
timeout = 20 * time.Second
}
if upstream.Timeout > 0 {
// convert upstream.Timeout(second) to nanosecond
timeout = time.Duration(upstream.Timeout) * time.Second
}
// timeout out request
go func() {
time.Sleep(timeout)
if !haveResponse {
log.Println("Timeout upstream", upstream.Endpoint)
errCtx = errors.New("timeout")
if shouldResponse {
c.AbortWithError(502, errCtx)
}
cancel()
}
}()
out.Body = io.NopCloser(bytes.NewReader(inBody))
out.Host = remote.Host
out.URL.Scheme = remote.Scheme
out.URL.Host = remote.Host
out.URL.Path = in.URL.Path
out.Header = http.Header{}
out.Header.Set("Host", remote.Host)
if upstream.SK == "asis" {
out.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
} else {
out.Header.Set("Authorization", "Bearer "+upstream.SK)
}
out.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
}
var buf bytes.Buffer
var contentType string
proxy.ModifyResponse = func(r *http.Response) error {
haveResponse = true
record.Status = r.StatusCode
if !shouldResponse && r.StatusCode != 200 {
log.Println("upstream return not 200 and should not response", r.StatusCode)
return errors.New("upstream return not 200 and should not response")
}
r.Header.Del("Access-Control-Allow-Origin")
r.Header.Del("Access-Control-Allow-Methods")
r.Header.Del("Access-Control-Allow-Headers")
r.Header.Set("Access-Control-Allow-Origin", "*")
r.Header.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
r.Header.Set("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
if r.StatusCode != 200 {
body, err := io.ReadAll(r.Body)
if err != nil {
record.Response = "failed to read response from upstream " + err.Error()
return errors.New(record.Response)
}
record.Response = fmt.Sprintf("openai-api-route upstream return '%s' with '%s'", r.Status, string(body))
record.Status = r.StatusCode
return fmt.Errorf(record.Response)
}
// count success
r.Body = io.NopCloser(io.TeeReader(r.Body, &buf))
contentType = r.Header.Get("content-type")
return nil
}
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
haveResponse = true
record.ResponseTime = time.Now().Sub(record.CreatedAt)
log.Println("Error", err, upstream.SK, upstream.Endpoint)
errCtx = err
// abort to error handle
if shouldResponse {
c.AbortWithError(502, err)
}
log.Println("response is", r.Response)
if record.Status == 0 {
record.Status = 502
}
if record.Response == "" {
record.Response = err.Error()
}
if r.Response != nil {
record.Status = r.Response.StatusCode
}
}
proxy.ServeHTTP(c.Writer, c.Request)
// return context error
if errCtx != nil {
// fix inrequest body
c.Request.Body = io.NopCloser(bytes.NewReader(inBody))
return errCtx
}
resp, err := io.ReadAll(io.NopCloser(&buf))
if err != nil {
record.Response = "failed to read response from upstream " + err.Error()
log.Println(record.Response)
// process header
if upstream.KeepHeader {
request.Header = c.Request.Header
} else {
// record response
// stream mode
if strings.HasPrefix(contentType, "text/event-stream") {
for _, line := range strings.Split(string(resp), "\n") {
chunk := StreamModeChunk{}
line = strings.TrimPrefix(line, "data:")
line = strings.TrimSpace(line)
if line == "" {
continue
}
err := json.Unmarshal([]byte(line), &chunk)
if err != nil {
log.Println(err)
continue
}
if len(chunk.Choices) == 0 {
continue
}
record.Response += chunk.Choices[0].Delta.Content
}
} else if strings.HasPrefix(contentType, "application/json") {
var fetchResp FetchModeResponse
err := json.Unmarshal(resp, &fetchResp)
if err != nil {
log.Println("Error parsing fetch response:", err)
return nil
}
if !strings.HasPrefix(fetchResp.Model, "gpt-") {
log.Println("Not GPT model, skip recording response:", fetchResp.Model)
return nil
}
if len(fetchResp.Choices) == 0 {
log.Println("Error: fetch response choice length is 0")
return nil
}
record.Response = fetchResp.Choices[0].Message.Content
} else {
log.Println("Unknown content type", contentType)
}
request.Header = http.Header{}
}
if len(record.Body) > 1024*512 {
record.Body = ""
// process header authorization
if upstream.SK == "asis" {
request.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
} else {
request.Header.Set("Authorization", "Bearer "+upstream.SK)
}
request.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
request.Header.Set("Host", remote.Host)
request.Header.Set("Content-Length", c.Request.Header.Get("Content-Length"))
request.Body = c.Request.Body
resp, err := client.Do(request)
if err != nil {
body := []byte{}
if resp != nil && resp.Body != nil {
body, _ = io.ReadAll(resp.Body)
}
return errors.New(err.Error() + " " + string(body))
}
defer resp.Body.Close()
record.Status = resp.StatusCode
if resp.StatusCode != 200 {
body, _ := io.ReadAll(resp.Body)
record.Status = resp.StatusCode
errRet := fmt.Errorf("[error]: openai-api-route upstream return '%s' with '%s'", resp.Status, string(body))
log.Println(errRet)
return errRet
}
// copy response header
for k, v := range resp.Header {
c.Header(k, v[0])
}
sendCORSHeaders(c)
respBodyBuffer := bytes.NewBuffer(make([]byte, 0, 4*1024))
respBodyTeeReader := io.TeeReader(resp.Body, respBodyBuffer)
record.ResponseTime = time.Since(record.CreatedAt)
io.Copy(c.Writer, respBodyTeeReader)
record.ElapsedTime = time.Since(record.CreatedAt)
// parse and record response
if strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") {
var fetchResp FetchModeResponse
err := json.NewDecoder(respBodyBuffer).Decode(&fetchResp)
if err == nil {
if len(fetchResp.Choices) > 0 {
record.Response = fetchResp.Choices[0].Message.Content
}
}
} else if strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") {
lines := bytes.Split(respBodyBuffer.Bytes(), []byte("\n"))
for _, line := range lines {
line = bytes.TrimSpace(line)
line = bytes.TrimPrefix(line, []byte("data:"))
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
chunk := StreamModeChunk{}
err = json.Unmarshal(line, &chunk)
if err != nil {
log.Println("[proxy.parseChunkError]:", err)
break
}
if len(chunk.Choices) == 0 {
continue
}
record.Response += chunk.Choices[0].Delta.Content
}
} else if strings.HasPrefix(resp.Header.Get("Content-Type"), "text") {
body, _ := io.ReadAll(respBodyBuffer)
record.Response = string(body)
} else {
log.Println("[proxy.record]: Unknown content type", resp.Header.Get("Content-Type"))
}
return nil
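A minimal sketch of the event-stream recording step above: strip the `data:` prefix from each SSE line, skip blanks and non-JSON sentinels (the code above stops at the first parse error instead), and concatenate the delta contents into the recorded response.
```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// Trimmed-down version of the StreamModeChunk shape used above.
type chunk struct {
	Choices []struct {
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
	} `json:"choices"`
}

func main() {
	stream := "data: {\"choices\":[{\"delta\":{\"content\":\"Hel\"}}]}\n\n" +
		"data: {\"choices\":[{\"delta\":{\"content\":\"lo\"}}]}\n\n" +
		"data: [DONE]\n"
	var response strings.Builder
	for _, line := range strings.Split(stream, "\n") {
		line = strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(line), "data:"))
		if line == "" {
			continue
		}
		var ch chunk
		if err := json.Unmarshal([]byte(line), &ch); err != nil {
			continue // e.g. the [DONE] sentinel
		}
		if len(ch.Choices) > 0 {
			response.WriteString(ch.Choices[0].Delta.Content)
		}
	}
	fmt.Println(response.String()) // Hello
}
```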

View File

@@ -1,22 +1,17 @@
package main
import (
"encoding/json"
"log"
"strings"
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
type Record struct {
ID int64 `gorm:"primaryKey,autoIncrement"`
Hostname string
UpstreamEndpoint string
UpstreamSK string
CreatedAt time.Time
IP string
Body string `gorm:"serializer:json"`
Body string
Model string
Response string
ResponseTime time.Duration
@@ -24,6 +19,7 @@ type Record struct {
Status int
Authorization string // the authorization header sent by the client
UserAgent string
Headers string
}
type StreamModeChunk struct {
@@ -55,61 +51,3 @@ type FetchModeUsage struct {
CompletionTokens int64 `json:"completion_tokens"`
TotalTokens int64 `json:"total_tokens"`
}
func recordAssistantResponse(contentType string, db *gorm.DB, trackID uuid.UUID, body []byte, elapsedTime time.Duration) {
result := ""
// stream mode
if strings.HasPrefix(contentType, "text/event-stream") {
resp := string(body)
for _, line := range strings.Split(resp, "\n") {
chunk := StreamModeChunk{}
line = strings.TrimPrefix(line, "data:")
line = strings.TrimSpace(line)
if line == "" {
continue
}
err := json.Unmarshal([]byte(line), &chunk)
if err != nil {
log.Println(err)
continue
}
if len(chunk.Choices) == 0 {
continue
}
result += chunk.Choices[0].Delta.Content
}
} else if strings.HasPrefix(contentType, "application/json") {
var fetchResp FetchModeResponse
err := json.Unmarshal(body, &fetchResp)
if err != nil {
log.Println("Error parsing fetch response:", err)
return
}
if !strings.HasPrefix(fetchResp.Model, "gpt-") {
log.Println("Not GPT model, skip recording response:", fetchResp.Model)
return
}
if len(fetchResp.Choices) == 0 {
log.Println("Error: fetch response choice length is 0")
return
}
result = fetchResp.Choices[0].Message.Content
} else {
log.Println("Unknown content type", contentType)
return
}
log.Println("Record result:", result)
record := Record{}
if db.Find(&record, "id = ?", trackID).Error != nil {
log.Println("Error find request record with trackID:", trackID)
return
}
record.Response = result
record.ElapsedTime = elapsedTime
if db.Save(&record).Error != nil {
log.Println("Error to save record:", record)
return
}
}

recovery.go (new file, 24 lines)
View File

@@ -0,0 +1,24 @@
package main
import (
"errors"
"log"
"net/http"
"net/http/httputil"
"github.com/gin-gonic/gin"
)
func ServeHTTP(proxy *httputil.ReverseProxy, w gin.ResponseWriter, r *http.Request) (errReturn error) {
// recovery
defer func() {
if err := recover(); err != nil {
log.Println("[serve.panic]: ", err)
errReturn = errors.New("[serve.panic]: Panic recover in reverse proxy serve HTTP")
}
}()
proxy.ServeHTTP(w, r)
return nil
}
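For context: `net/http` handlers, including `httputil.ReverseProxy`, panic with `http.ErrAbortHandler` when the client connection is lost mid-response, and this wrapper turns that panic into an ordinary error the retry loop in main.go can inspect. A self-contained sketch of the same pattern (`safeServe` is an illustrative name):
```go
package main

import (
	"errors"
	"fmt"
	"net/http"
)

// safeServe converts a handler panic into an error return value.
func safeServe(serve func()) (err error) {
	defer func() {
		if r := recover(); r != nil {
			if r == http.ErrAbortHandler {
				err = errors.New("client connection aborted")
				return
			}
			err = fmt.Errorf("panic in proxy: %v", r)
		}
	}()
	serve()
	return nil
}

func main() {
	err := safeServe(func() { panic(http.ErrAbortHandler) })
	fmt.Println(err) // client connection aborted
}
```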

replicate.go (new file, 369 lines)
View File

@@ -0,0 +1,369 @@
package main
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
)
var replicate_model_url_template = "https://api.replicate.com/v1/models/%s/predictions"
func processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
err := _processReplicateRequest(c, upstream, record, shouldResponse)
if shouldResponse {
sendCORSHeaders(c)
if err != nil {
c.AbortWithError(502, err)
}
}
return err
}
func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
// read request body
inBody, err := io.ReadAll(c.Request.Body)
if err != nil {
return errors.New("[processReplicateRequest]: failed to read request body " + err.Error())
}
// record request body
// parse request body
inRequest := &OpenAIChatRequest{}
err = json.Unmarshal(inBody, inRequest)
if err != nil {
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
return errors.New("[processReplicateRequest]: failed to parse request body " + err.Error())
}
record.Model = inRequest.Model
// check allow model
if len(upstream.Allow) > 0 {
isAllow := false
for _, model := range upstream.Allow {
if model == inRequest.Model {
isAllow = true
break
}
}
if !isAllow {
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
return errors.New("[processReplicateRequest]: model not allow")
}
}
// check block model
if len(upstream.Deny) > 0 {
for _, model := range upstream.Deny {
if model == inRequest.Model {
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
return errors.New("[processReplicateRequest]: model deny")
}
}
}
// set url
model_url := fmt.Sprintf(replicate_model_url_template, inRequest.Model)
log.Println("[processReplicateRequest]: model_url:", model_url)
// create request with default value
outRequest := &ReplicateModelRequest{
Stream: false,
Input: ReplicateModelRequestInput{
Prompt: "",
MaxNewTokens: 1024,
Temperature: 0.6,
Top_p: 0.9,
Top_k: 50,
FrequencyPenalty: 0.0,
PresencePenalty: 0.0,
PromptTemplate: "{prompt}",
},
}
// copy value from input request
outRequest.Stream = inRequest.Stream
outRequest.Input.Temperature = inRequest.Temperature
outRequest.Input.FrequencyPenalty = inRequest.FrequencyPenalty
outRequest.Input.PresencePenalty = inRequest.PresencePenalty
// render prompt
systemMessage := ""
userMessage := ""
assistantMessage := ""
for _, message := range inRequest.Messages {
if message.Role == "system" {
if systemMessage != "" {
systemMessage += "\n"
}
systemMessage += message.Content
continue
}
if message.Role == "user" {
if userMessage != "" {
userMessage += "\n"
}
userMessage += message.Content
if systemMessage != "" {
userMessage = systemMessage + "\n" + userMessage
systemMessage = ""
}
continue
}
if message.Role == "assistant" {
if assistantMessage != "" {
assistantMessage += "\n"
}
assistantMessage += message.Content
if outRequest.Input.Prompt != "" {
outRequest.Input.Prompt += "\n"
}
if userMessage != "" {
outRequest.Input.Prompt += fmt.Sprintf("<s> [INST] %s [/INST] %s </s>", userMessage, assistantMessage)
} else {
outRequest.Input.Prompt += fmt.Sprintf("<s> %s </s>", assistantMessage)
}
userMessage = ""
assistantMessage = ""
}
// unknown role
log.Println("[processReplicateRequest]: Warning: unknown role", message.Role)
}
// final user message
if userMessage != "" {
outRequest.Input.Prompt += fmt.Sprintf("<s> [INST] %s [/INST] ", userMessage)
userMessage = ""
}
// final assistant message
if assistantMessage != "" {
outRequest.Input.Prompt += fmt.Sprintf("<s> %s </s>", assistantMessage)
}
log.Println("[processReplicateRequest]: outRequest.Input.Prompt:", outRequest.Input.Prompt)
// send request
outBody, err := json.Marshal(outRequest)
if err != nil {
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
return errors.New("[processReplicateRequest]: failed to marshal request body " + err.Error())
}
// http add headers
req, err := http.NewRequest("POST", model_url, bytes.NewBuffer(outBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Token "+upstream.SK)
// send
resp, err := http.DefaultClient.Do(req)
if err != nil {
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
return errors.New("[processReplicateRequest]: failed to post request " + err.Error())
}
// read response body
outBody, err = io.ReadAll(resp.Body)
if err != nil {
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
return errors.New("[processReplicateRequest]: failed to read response body " + err.Error())
}
// parse response body
outResponse := &ReplicateModelResponse{}
err = json.Unmarshal(outBody, outResponse)
if err != nil {
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
return errors.New("[processReplicateRequest]: failed to parse response body " + err.Error())
}
if outResponse.Stream {
// get result
log.Println("[processReplicateRequest]: outResponse.URLS.Get:", outResponse.URLS.Stream)
req, err := http.NewRequest("GET", outResponse.URLS.Stream, nil)
if err != nil {
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
return errors.New("[processReplicateRequest]: failed to create get request " + err.Error())
}
req.Header.Set("Authorization", "Token "+upstream.SK)
req.Header.Set("Accept", "text/event-stream")
// send
resp, err := http.DefaultClient.Do(req)
if err != nil {
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
return errors.New("[processReplicateRequest]: failed to get request " + err.Error())
}
// get result by chunk
var buffer string = ""
var indexCount int64 = 0
for {
if !strings.Contains(buffer, "\n\n") {
// receive chunk
chunk := make([]byte, 1024)
length, err := resp.Body.Read(chunk)
if err == io.EOF {
break
}
if length == 0 {
break
}
if err != nil {
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
return errors.New("[processReplicateRequest]: failed to read response body " + err.Error())
}
// add chunk to buffer
chunk = bytes.Trim(chunk, "\x00")
buffer += string(chunk)
continue
}
// cut the first chunk by find index
index := strings.Index(buffer, "\n\n")
chunk := buffer[:index]
buffer = buffer[index+2:]
// trim line
chunk = strings.Trim(chunk, "\n")
// ignore hi
if !strings.Contains(chunk, "\n") {
continue
}
// parse chunk to ReplicateModelResultChunk object
chunkObj := &ReplicateModelResultChunk{}
lines := strings.Split(chunk, "\n")
// first line is event
chunkObj.Event = strings.TrimSpace(lines[0])
chunkObj.Event = strings.TrimPrefix(chunkObj.Event, "event: ")
// second line is id
chunkObj.ID = strings.TrimSpace(lines[1])
chunkObj.ID = strings.TrimPrefix(chunkObj.ID, "id: ")
chunkObj.ID = strings.SplitN(chunkObj.ID, ":", 2)[0]
// third line is data
chunkObj.Data = lines[2]
chunkObj.Data = strings.TrimPrefix(chunkObj.Data, "data: ")
record.Response += chunkObj.Data
// done
if chunkObj.Event == "done" {
break
}
sendCORSHeaders(c)
// create OpenAI response chunk
c.SSEvent("", &OpenAIChatResponseChunk{
ID: "",
Model: outResponse.Model,
Choices: []OpenAIChatResponseChunkChoice{
{
Index: indexCount,
Delta: OpenAIChatMessage{
Role: "assistant",
Content: chunkObj.Data,
},
},
},
})
c.Writer.Flush()
indexCount += 1
}
sendCORSHeaders(c)
c.SSEvent("", &OpenAIChatResponseChunk{
ID: "",
Model: outResponse.Model,
Choices: []OpenAIChatResponseChunkChoice{
{
Index: indexCount,
Delta: OpenAIChatMessage{
Role: "assistant",
Content: "",
},
FinishReason: "stop",
},
},
})
c.Writer.Flush()
indexCount += 1
record.Status = 200
return nil
} else {
var result *ReplicateModelResultGet
for {
// get result
log.Println("[processReplicateRequest]: outResponse.URLS.Get:", outResponse.URLS.Get)
req, err := http.NewRequest("GET", outResponse.URLS.Get, nil)
if err != nil {
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
return errors.New("[processReplicateRequest]: failed to create get request " + err.Error())
}
req.Header.Set("Authorization", "Token "+upstream.SK)
// send
resp, err := http.DefaultClient.Do(req)
if err != nil {
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
return errors.New("[processReplicateRequest]: failed to get request " + err.Error())
}
// get result
resultBody, err := io.ReadAll(resp.Body)
if err != nil {
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
return errors.New("[processReplicateRequest]: failed to read response body " + err.Error())
}
// parse response body
result = &ReplicateModelResultGet{}
err = json.Unmarshal(resultBody, result)
if err != nil {
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
return errors.New("[processReplicateRequest]: failed to parse response body " + err.Error())
}
if result.Status == "processing" || result.Status == "starting" {
time.Sleep(3 * time.Second)
continue
}
break
}
// build OpenAI response
openAIResult := &OpenAIChatResponse{
ID: result.ID,
Model: result.Model,
Choices: []OpenAIChatResponseChoice{},
Usage: OpenAIChatResponseUsage{
TotalTokens: result.Metrics.InputTokenCount + result.Metrics.OutputTokenCount,
PromptTokens: result.Metrics.InputTokenCount,
},
}
openAIResult.Choices = append(openAIResult.Choices, OpenAIChatResponseChoice{
Index: 0,
Message: OpenAIChatMessage{
Role: "assistant",
Content: strings.Join(result.Output, ""),
},
FinishReason: "stop",
})
record.Response = strings.Join(result.Output, "")
record.Status = 200
// gin return
sendCORSHeaders(c)
c.JSON(200, openAIResult)
}
return nil
}
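A condensed sketch of the prompt rendering above: system text is folded into the next user turn, each user/assistant pair becomes a `<s> [INST] ... [/INST] ... </s>` block, and a trailing user message is left as an open `[INST]` block. (The full version above also concatenates consecutive messages of the same role.)
```go
package main

import "fmt"

type msg struct{ Role, Content string }

func render(messages []msg) string {
	prompt, system, user := "", "", ""
	for _, m := range messages {
		switch m.Role {
		case "system":
			system = m.Content
		case "user":
			user = m.Content
			if system != "" {
				user = system + "\n" + user // fold system into the user turn
				system = ""
			}
		case "assistant":
			prompt += fmt.Sprintf("<s> [INST] %s [/INST] %s </s>", user, m.Content)
			user = ""
		}
	}
	if user != "" {
		prompt += fmt.Sprintf("<s> [INST] %s [/INST] ", user) // open block for the model to complete
	}
	return prompt
}

func main() {
	fmt.Println(render([]msg{
		{"system", "Be brief."},
		{"user", "Hi"},
		{"assistant", "Hello."},
		{"user", "Bye"},
	}))
	// <s> [INST] Be brief.
	// Hi [/INST] Hello. </s><s> [INST] Bye [/INST]
}
```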

View File

@@ -2,19 +2,34 @@ package main
import (
"log"
"net/url"
"os"
"gopkg.in/yaml.v3"
)
type Config struct {
Address string `yaml:"address"`
Hostname string `yaml:"hostname"`
DBType string `yaml:"dbtype"`
DBAddr string `yaml:"dbaddr"`
Authorization string `yaml:"authorization"`
Timeout int64 `yaml:"timeout"`
StreamTimeout int64 `yaml:"stream_timeout"`
Upstreams []OPENAI_UPSTREAM `yaml:"upstreams"`
}
type OPENAI_UPSTREAM struct {
SK string `yaml:"sk"`
Endpoint string `yaml:"endpoint"`
Timeout int64 `yaml:"timeout"`
SK string `yaml:"sk"`
Endpoint string `yaml:"endpoint"`
Timeout int64 `yaml:"timeout"`
StreamTimeout int64 `yaml:"stream_timeout"`
Allow []string `yaml:"allow"`
Deny []string `yaml:"deny"`
Type string `yaml:"type"`
KeepHeader bool `yaml:"keep_header"`
Authorization string `yaml:"authorization"`
Noauth bool `yaml:"noauth"`
URL *url.URL
}
func readConfig(filepath string) Config {
@@ -32,5 +47,158 @@ func readConfig(filepath string) Config {
log.Fatalf("Error unmarshaling YAML: %s", err)
}
// set default value
if config.Address == "" {
log.Println("Address not set, use default value: :8888")
config.Address = ":8888"
}
if config.DBType == "" {
log.Println("DBType not set, use default value: sqlite")
config.DBType = "sqlite"
}
if config.DBAddr == "" {
log.Println("DBAddr not set, use default value: ./db.sqlite")
config.DBAddr = "./db.sqlite"
}
if config.Timeout == 0 {
log.Println("Timeout not set, use default value: 120")
config.Timeout = 120
}
if config.StreamTimeout == 0 {
log.Println("StreamTimeout not set, use default value: 10")
config.StreamTimeout = 10
}
for i, upstream := range config.Upstreams {
// parse upstream endpoint URL
endpoint, err := url.Parse(upstream.Endpoint)
if err != nil {
log.Fatalf("Can't parse upstream endpoint URL '%s': %s", upstream.Endpoint, err)
}
config.Upstreams[i].URL = endpoint
if config.Upstreams[i].Type == "" {
config.Upstreams[i].Type = "openai"
}
if (config.Upstreams[i].Type != "openai") && (config.Upstreams[i].Type != "replicate") {
log.Fatalf("Unsupported upstream type '%s'", config.Upstreams[i].Type)
}
// apply authorization from global config if not set
if config.Upstreams[i].Authorization == "" && !config.Upstreams[i].Noauth {
config.Upstreams[i].Authorization = config.Authorization
}
if config.Upstreams[i].Timeout == 0 {
config.Upstreams[i].Timeout = config.Timeout
}
if config.Upstreams[i].StreamTimeout == 0 {
config.Upstreams[i].StreamTimeout = config.StreamTimeout
}
}
return config
}
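A minimal sketch of the default cascade in `readConfig`: unset per-upstream fields inherit the global values, and authorization is inherited only when the upstream does not opt out with `noauth` (`upstreamCfg` and `applyDefaults` are illustrative stand-ins for the `Config`/`OPENAI_UPSTREAM` types here).
```go
package main

import "fmt"

type upstreamCfg struct {
	Timeout       int64
	StreamTimeout int64
	Authorization string
	Noauth        bool
}

func applyDefaults(global upstreamCfg, ups []upstreamCfg) []upstreamCfg {
	for i := range ups {
		if ups[i].Authorization == "" && !ups[i].Noauth {
			ups[i].Authorization = global.Authorization // inherit global tokens
		}
		if ups[i].Timeout == 0 {
			ups[i].Timeout = global.Timeout
		}
		if ups[i].StreamTimeout == 0 {
			ups[i].StreamTimeout = global.StreamTimeout
		}
	}
	return ups
}

func main() {
	out := applyDefaults(
		upstreamCfg{Timeout: 120, StreamTimeout: 10, Authorization: "woshimima"},
		[]upstreamCfg{{Timeout: 30}},
	)
	fmt.Printf("%+v\n", out[0]) // {Timeout:30 StreamTimeout:10 Authorization:woshimima Noauth:false}
}
```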
type OpenAIChatRequest struct {
FrequencyPenalty float64 `json:"frequency_penalty"`
PresencePenalty float64 `json:"presence_penalty"`
MaxTokens int64 `json:"max_tokens"`
Model string `json:"model"`
Stream bool `json:"stream"`
Temperature float64 `json:"temperature"`
Messages []OpenAIChatRequestMessage
}
type OpenAIChatRequestMessage struct {
Content string `json:"content"`
Role string `json:"role"`
}
type ReplicateModelRequest struct {
Stream bool `json:"stream"`
Input ReplicateModelRequestInput `json:"input"`
}
type ReplicateModelRequestInput struct {
Prompt string `json:"prompt"`
MaxNewTokens int64 `json:"max_new_tokens"`
Temperature float64 `json:"temperature"`
Top_p float64 `json:"top_p"`
Top_k int64 `json:"top_k"`
PresencePenalty float64 `json:"presence_penalty"`
FrequencyPenalty float64 `json:"frequency_penalty"`
PromptTemplate string `json:"prompt_template"`
}
type ReplicateModelResponse struct {
Model string `json:"model"`
Version string `json:"version"`
Stream bool `json:"stream"`
Error string `json:"error"`
URLS ReplicateModelResponseURLS `json:"urls"`
}
type ReplicateModelResponseURLS struct {
Cancel string `json:"cancel"`
Get string `json:"get"`
Stream string `json:"stream"`
}
type ReplicateModelResultGet struct {
ID string `json:"id"`
Model string `json:"model"`
Version string `json:"version"`
Output []string `json:"output"`
Error string `json:"error"`
Metrics ReplicateModelResultMetrics `json:"metrics"`
Status string `json:"status"`
}
type ReplicateModelResultMetrics struct {
InputTokenCount int64 `json:"input_token_count"`
OutputTokenCount int64 `json:"output_token_count"`
}
type OpenAIChatResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []OpenAIChatResponseChoice `json:"choices"`
Usage OpenAIChatResponseUsage `json:"usage"`
}
type OpenAIChatResponseUsage struct {
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
TotalTokens int64 `json:"total_tokens"`
}
type OpenAIChatResponseChoice struct {
Index int64 `json:"index"`
Message OpenAIChatMessage `json:"message"`
FinishReason string `json:"finish_reason"`
}
type OpenAIChatMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ReplicateModelResultChunk struct {
Event string `json:"event"`
ID string `json:"id"`
Data string `json:"data"`
}
type OpenAIChatResponseChunk struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []OpenAIChatResponseChunkChoice `json:"choices"`
}
type OpenAIChatResponseChunkChoice struct {
Index int64 `json:"index"`
Delta OpenAIChatMessage `json:"delta"`
FinishReason string `json:"finish_reason"`
}