Compare commits

...

5 Commits

Author SHA1 Message Date
0dbd898532 update config yaml 2023-11-29 14:40:20 +08:00
7b74818676 update config to file 2023-11-29 14:31:11 +08:00
0a2a7376f1 update dep 2023-11-28 14:36:14 +08:00
44a966e6f4 clean code 2023-11-28 14:04:42 +08:00
87244d4dc2 update docs 2023-11-28 10:42:06 +08:00
7 changed files with 76 additions and 80 deletions

View File

@@ -8,7 +8,7 @@
- 支持所有类型的接口 (`/v1/*`)
- 提供 Prometheus Metrics 统计接口 (`/v1/metrics`)
- 按照定义顺序请求 OpenAI 上游
- 识别 ChatCompletions Stream 请求,针对 Stream 请求使用 5 秒超时。对于其他请求使用 60 秒超时。
- 识别 ChatCompletions Stream 请求,针对 Stream 请求使用 5 秒超时。具体超时策略请参阅 [超时策略](#超时策略) 一节
- 记录完整的请求内容、使用的上游、IP 地址、响应时间以及 GPT 回复文本
- 请求出错时发送 飞书 或 Matrix 消息通知
@@ -88,6 +88,15 @@ Usage of ./openai-api-route:
```yaml
authorization: woshimima
# 使用 sqlite 作为数据库储存请求记录
dbtype: sqlite
dbaddr: ./db.sqlite
# 使用 postgres 作为数据库储存请求记录
# dbtype: postgres
# dbaddr: "host=127.0.0.1 port=5432 user=postgres dbname=openai_api_route sslmode=disable password=woshimima"
upstreams:
- sk: "secret_key_1"
endpoint: "https://api.openai.com/v2"
@@ -99,3 +108,27 @@ upstreams:
请注意,程序会根据情况修改 timeout 的值
您可以直接运行 `./openai-api-route` 命令,如果数据库不存在,系统会自动创建。
## 超时策略
在处理上游请求时,超时策略是确保服务稳定性和响应性的关键因素。本服务通过配置文件中的 `Upstreams` 部分来定义多个上游服务器。每个上游服务器都有自己的 `Endpoint` 和 `SK`(可能是密钥或特殊标识)。服务会按照配置文件中的顺序依次尝试每个上游服务器,直到请求成功或所有上游服务器都已尝试。
### 单一上游配置
当配置文件中只定义了一个上游服务器时,该上游的超时时间将被设置为 120 秒。这意味着,如果请求没有在 120 秒内得到上游服务器的响应,服务将会中止该请求并可能返回错误。
### 多上游配置
如果配置文件中定义了多个上游服务器,服务将会按照定义的顺序依次尝试每个上游。对于每个上游服务器,服务会检查其 `Endpoint` 和 `SK` 是否有效。如果任一字段为空,服务将返回 500 错误,并记录无效的上游信息。
### 超时策略细节
服务在处理请求时会根据不同的条件设置不同的超时时间。超时时间是指服务等待上游服务器响应的最大时间。以下是超时时间的设置规则:
1. **默认超时时间**:如果没有特殊条件,服务将使用默认的超时时间,即 60 秒。
2. **流式请求**:如果请求体被识别为流式(`requestBody.Stream` 为 `true`),并且请求体检查(`requestBodyOK`)没有发现问题,超时时间将被设置为 5 秒。这适用于那些预期会快速响应的流式请求。
3. **大请求体**:如果请求体的大小超过 128KB即 `len(inBody) > 1024*128`),超时时间将被设置为 20 秒。这考虑到了处理大型数据可能需要更长的时间。
4. **上游超时配置**:如果上游服务器在配置中指定了超时时间(`upstream.Timeout` 大于 0服务将使用该值作为超时时间。这个值是以秒为单位的。

16
config.sample.yaml Normal file
View File

@@ -0,0 +1,16 @@
authorization: woshimima
# 使用 sqlite 作为数据库储存请求记录
dbtype: sqlite
dbaddr: ./db.sqlite
# 使用 postgres 作为数据库储存请求记录
# dbtype: postgres
# dbaddr: "host=127.0.0.1 port=5432 user=postgres dbname=openai_api_route sslmode=disable password=woshimima"
upstreams:
- sk: "secret_key_1"
endpoint: "https://api.openai.com/v2"
- sk: "secret_key_2"
endpoint: "https://api.openai.com/v1"
timeout: 30

1
go.mod
View File

@@ -4,7 +4,6 @@ go 1.20
require (
github.com/gin-gonic/gin v1.9.1
github.com/google/uuid v1.3.0
github.com/penglongli/gin-metrics v0.1.10
golang.org/x/net v0.10.0
gopkg.in/yaml.v3 v3.0.1

2
go.sum
View File

@@ -148,8 +148,6 @@ github.com/google/pprof v0.0.0-20200229191704-1ebb73c60ed3/go.mod h1:ZgVRPoUq/hf
github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=

21
main.go
View File

@@ -18,38 +18,35 @@ import (
var config Config
func main() {
dbType := flag.String("dbtype", "sqlite", "Database type (sqlite or postgres)")
dbAddr := flag.String("database", "./db.sqlite", "Database address, if dbType is postgres, this is the DSN connection string")
configFile := flag.String("config", "./config.yaml", "Config file")
listenAddr := flag.String("addr", ":8888", "Listening address")
listMode := flag.Bool("list", false, "List all upstream")
noauth := flag.Bool("noauth", false, "Do not check incoming authorization header")
flag.Parse()
log.Println("Service starting")
// load all upstreams
config = readConfig(*configFile)
log.Println("Load upstreams number:", len(config.Upstreams))
// connect to database
var db *gorm.DB
var err error
switch *dbType {
switch config.DBType {
case "sqlite":
db, err = gorm.Open(sqlite.Open(*dbAddr), &gorm.Config{
db, err = gorm.Open(sqlite.Open(config.DBAddr), &gorm.Config{
PrepareStmt: true,
SkipDefaultTransaction: true,
})
case "postgres":
db, err = gorm.Open(postgres.Open(*dbAddr), &gorm.Config{
db, err = gorm.Open(postgres.Open(config.DBAddr), &gorm.Config{
PrepareStmt: true,
SkipDefaultTransaction: true,
})
default:
log.Fatalf("Unsupported database type: %s", *dbType)
log.Fatalf("Unsupported database type: '%s'", config.DBType)
}
// load all upstreams
config = readConfig(*configFile)
log.Println("Load upstreams number:", len(config.Upstreams))
db.AutoMigrate(&Record{})
log.Println("Auto migrate database done")
@@ -144,5 +141,5 @@ func main() {
}
})
engine.Run(*listenAddr)
engine.Run(config.Address)
}

View File

@@ -1,13 +1,7 @@
package main
import (
"encoding/json"
"log"
"strings"
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
type Record struct {
@@ -55,61 +49,3 @@ type FetchModeUsage struct {
CompletionTokens int64 `json:"completion_tokens"`
TotalTokens int64 `json:"total_tokens"`
}
func recordAssistantResponse(contentType string, db *gorm.DB, trackID uuid.UUID, body []byte, elapsedTime time.Duration) {
result := ""
// stream mode
if strings.HasPrefix(contentType, "text/event-stream") {
resp := string(body)
for _, line := range strings.Split(resp, "\n") {
chunk := StreamModeChunk{}
line = strings.TrimPrefix(line, "data:")
line = strings.TrimSpace(line)
if line == "" {
continue
}
err := json.Unmarshal([]byte(line), &chunk)
if err != nil {
log.Println(err)
continue
}
if len(chunk.Choices) == 0 {
continue
}
result += chunk.Choices[0].Delta.Content
}
} else if strings.HasPrefix(contentType, "application/json") {
var fetchResp FetchModeResponse
err := json.Unmarshal(body, &fetchResp)
if err != nil {
log.Println("Error parsing fetch response:", err)
return
}
if !strings.HasPrefix(fetchResp.Model, "gpt-") {
log.Println("Not GPT model, skip recording response:", fetchResp.Model)
return
}
if len(fetchResp.Choices) == 0 {
log.Println("Error: fetch response choice length is 0")
return
}
result = fetchResp.Choices[0].Message.Content
} else {
log.Println("Unknown content type", contentType)
return
}
log.Println("Record result:", result)
record := Record{}
if db.Find(&record, "id = ?", trackID).Error != nil {
log.Println("Error find request record with trackID:", trackID)
return
}
record.Response = result
record.ElapsedTime = elapsedTime
if db.Save(&record).Error != nil {
log.Println("Error to save record:", record)
return
}
}

View File

@@ -8,6 +8,9 @@ import (
)
type Config struct {
Address string `yaml:"address"`
DBType string `yaml:"dbtype"`
DBAddr string `yaml:"dbaddr"`
Authorization string `yaml:"authorization"`
Upstreams []OPENAI_UPSTREAM `yaml:"upstreams"`
}
@@ -32,5 +35,19 @@ func readConfig(filepath string) Config {
log.Fatalf("Error unmarshaling YAML: %s", err)
}
// set default value
if config.Address == "" {
log.Println("Address not set, use default value: :8888")
config.Address = ":8888"
}
if config.DBType == "" {
log.Println("DBType not set, use default value: sqlite")
config.DBType = "sqlite"
}
if config.DBAddr == "" {
log.Println("DBAddr not set, use default value: ./db.sqlite")
config.DBAddr = "./db.sqlite"
}
return config
}