5 Commits

Author SHA1 Message Date
dad4ad2b97 update readme.md 2023-11-27 17:28:29 +08:00
fb19d8a353 upstreams.yaml 2023-11-27 17:19:08 +08:00
4125c78f33 record response time 2023-11-17 14:53:49 +08:00
31eed99025 support multiple auth header 2023-11-16 15:48:09 +08:00
6c9eab09e2 record authorization 2023-11-16 15:42:37 +08:00
7 changed files with 74 additions and 71 deletions

1
.gitignore vendored
View File

@@ -1,2 +1,3 @@
openai-api-route openai-api-route
db.sqlite db.sqlite
/upstreams.yaml

View File

@@ -68,30 +68,28 @@
``` ```
Usage of ./openai-api-route: Usage of ./openai-api-route:
-add
添加一个 OpenAI 上游
-addr string -addr string
监听地址(默认为 ":8888" 监听地址(默认为 ":8888"
-upstreams string
上游配置文件(默认为 "./upstreams.yaml"
-database string -database string
数据库地址(默认为 "./db.sqlite" 数据库地址(默认为 "./db.sqlite"
-endpoint string
OpenAI API 基地址(默认为 "https://api.openai.com/v1"
-list -list
列出所有上游 列出所有上游
-noauth -noauth
不检查传入的授权头 不检查传入的授权头
-sk string
OpenAI API 密钥sk-xxxxx
``` ```
以下是一个 `./upstreams.yaml` 文件配置示例
```yaml
- sk: "secret_key_1"
endpoint: "https://api.openai.com/v2"
- sk: "secret_key_2"
endpoint: "https://api.openai.com/v1"
timeout: 30
```
请注意,程序会根据情况修改 timeout 的值
您可以直接运行 `./openai-api-route` 命令,如果数据库不存在,系统会自动创建。 您可以直接运行 `./openai-api-route` 命令,如果数据库不存在,系统会自动创建。
### 上游管理
您可以使用以下命令添加一个上游:
```bash
./openai-api-route -add -sk sk-xxxxx -endpoint https://api.openai.com/v1
```
另外,您还可以直接编辑数据库中的 `openai_upstreams` 表进行 OpenAI 上游的增删改查管理。改动的上游需要重启负载均衡服务后才能生效。

View File

@@ -21,11 +21,13 @@ func handleAuth(c *gin.Context) error {
authorization = strings.Trim(authorization[len("Bearer"):], " ") authorization = strings.Trim(authorization[len("Bearer"):], " ")
log.Println("Received authorization", authorization) log.Println("Received authorization", authorization)
if authorization != authConfig.Value { for _, auth := range strings.Split(authConfig.Value, ",") {
if authorization != strings.Trim(auth, " ") {
err = errors.New("wrong authorization header") err = errors.New("wrong authorization header")
c.AbortWithError(403, err) c.AbortWithError(403, err)
return err return err
} }
}
return nil return nil
} }

26
main.go
View File

@@ -15,12 +15,10 @@ import (
func main() { func main() {
dbAddr := flag.String("database", "./db.sqlite", "Database address") dbAddr := flag.String("database", "./db.sqlite", "Database address")
upstreamsFile := flag.String("upstreams", "./upstreams.yaml", "Upstreams file")
listenAddr := flag.String("addr", ":8888", "Listening address") listenAddr := flag.String("addr", ":8888", "Listening address")
addMode := flag.Bool("add", false, "Add an OpenAI upstream")
listMode := flag.Bool("list", false, "List all upstream") listMode := flag.Bool("list", false, "List all upstream")
sk := flag.String("sk", "", "OpenAI API key (sk-xxxxx)")
noauth := flag.Bool("noauth", false, "Do not check incoming authorization header") noauth := flag.Bool("noauth", false, "Do not check incoming authorization header")
endpoint := flag.String("endpoint", "https://api.openai.com/v1", "OpenAI API base")
flag.Parse() flag.Parse()
log.Println("Service starting") log.Println("Service starting")
@@ -35,8 +33,7 @@ func main() {
} }
// load all upstreams // load all upstreams
upstreams := make([]OPENAI_UPSTREAM, 0) upstreams := readUpstreams(*upstreamsFile)
db.Find(&upstreams)
log.Println("Load upstreams number:", len(upstreams)) log.Println("Load upstreams number:", len(upstreams))
err = initconfig(db) err = initconfig(db)
@@ -48,26 +45,9 @@ func main() {
db.AutoMigrate(&Record{}) db.AutoMigrate(&Record{})
log.Println("Auto migrate database done") log.Println("Auto migrate database done")
if *addMode {
if *sk == "" {
log.Fatal("Missing --sk flag")
}
newUpstream := OPENAI_UPSTREAM{}
newUpstream.SK = *sk
newUpstream.Endpoint = *endpoint
err = db.Create(&newUpstream).Error
if err != nil {
log.Fatal("Can not add upstream", err)
}
log.Println("Successuflly add upstream", *sk, *endpoint)
return
}
if *listMode { if *listMode {
result := make([]OPENAI_UPSTREAM, 0)
db.Find(&result)
fmt.Println("SK\tEndpoint") fmt.Println("SK\tEndpoint")
for _, upstream := range result { for _, upstream := range upstreams {
fmt.Println(upstream.SK, upstream.Endpoint) fmt.Println(upstream.SK, upstream.Endpoint)
} }
return return

View File

@@ -20,9 +20,8 @@ import (
func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error { func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
var errCtx error var errCtx error
record.UpstreamID = upstream.ID record.UpstreamEndpoint = upstream.Endpoint
record.Response = "" record.Response = ""
record.Authorization = upstream.SK
// [TODO] record request body // [TODO] record request body
// reverse proxy // reverse proxy
@@ -134,6 +133,7 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
} }
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
haveResponse = true haveResponse = true
record.ResponseTime = time.Now().Sub(record.CreatedAt)
log.Println("Error", err, upstream.SK, upstream.Endpoint) log.Println("Error", err, upstream.SK, upstream.Endpoint)
errCtx = err errCtx = err

View File

@@ -12,11 +12,13 @@ import (
type Record struct { type Record struct {
ID int64 `gorm:"primaryKey,autoIncrement"` ID int64 `gorm:"primaryKey,autoIncrement"`
UpstreamEndpoint string
CreatedAt time.Time CreatedAt time.Time
IP string IP string
Body string `gorm:"serializer:json"` Body string `gorm:"serializer:json"`
Model string Model string
Response string Response string
ResponseTime time.Duration
ElapsedTime time.Duration ElapsedTime time.Duration
Status int Status int
UpstreamID uint UpstreamID uint

View File

@@ -1,13 +1,33 @@
package main package main
import ( import (
"gorm.io/gorm" "log"
"os"
"gopkg.in/yaml.v3"
) )
// one openai upstream contain a pair of key and endpoint // one openai upstream contain a pair of key and endpoint
type OPENAI_UPSTREAM struct { type OPENAI_UPSTREAM struct {
gorm.Model SK string `yaml:"sk"`
SK string `gorm:"index:idx_sk_endpoint,unique"` // key Endpoint string `yaml:"endpoint"`
Endpoint string `gorm:"index:idx_sk_endpoint,unique"` // endpoint Timeout int64 `yaml:"timeout"`
Timeout int64 // timeout in seconds }
func readUpstreams(filepath string) []OPENAI_UPSTREAM {
var upstreams []OPENAI_UPSTREAM
// read yaml file
data, err := os.ReadFile(filepath)
if err != nil {
log.Fatalf("Error reading YAML file: %s", err)
}
// Unmarshal the YAML into the upstreams slice
err = yaml.Unmarshal(data, &upstreams)
if err != nil {
log.Fatalf("Error unmarshaling YAML: %s", err)
}
return upstreams
} }