convert config to yaml file

This commit is contained in:
2023-11-27 18:07:37 +08:00
parent 3cc507a767
commit de3bea06a7
8 changed files with 28 additions and 83 deletions

View File

@@ -1,3 +1,3 @@
openai-api-route openai-api-route
db.sqlite db.sqlite
/upstreams.yaml /config.yaml

2
.gitignore vendored
View File

@@ -1,3 +1,3 @@
openai-api-route openai-api-route
db.sqlite db.sqlite
/upstreams.yaml /config.yaml

View File

@@ -83,6 +83,8 @@ Usage of ./openai-api-route:
以下是一个 `./upstreams.yaml` 文件配置示例 以下是一个 `./upstreams.yaml` 文件配置示例
```yaml ```yaml
authorization: woshimima
upstreams:
- sk: "secret_key_1" - sk: "secret_key_1"
endpoint: "https://api.openai.com/v2" endpoint: "https://api.openai.com/v2"
- sk: "secret_key_2" - sk: "secret_key_2"

View File

@@ -21,7 +21,7 @@ func handleAuth(c *gin.Context) error {
authorization = strings.Trim(authorization[len("Bearer"):], " ") authorization = strings.Trim(authorization[len("Bearer"):], " ")
log.Println("Received authorization", authorization) log.Println("Received authorization", authorization)
for _, auth := range strings.Split(authConfig.Value, ",") { for _, auth := range strings.Split(config.Authorization, ",") {
if authorization != strings.Trim(auth, " ") { if authorization != strings.Trim(auth, " ") {
err = errors.New("wrong authorization header") err = errors.New("wrong authorization header")
c.AbortWithError(403, err) c.AbortWithError(403, err)

View File

@@ -1,49 +0,0 @@
package main
import (
"errors"
"log"
"gorm.io/gorm"
)
// K-V struct to store program's config
type ConfigKV struct {
gorm.Model
Key string `gorm:"unique"`
Value string
}
// init db
func initconfig(db *gorm.DB) error {
var err error
err = db.AutoMigrate(&ConfigKV{})
if err != nil {
return err
}
// config list and their default values
configs := make(map[string]string)
configs["authorization"] = "woshimima"
configs["policy"] = "main"
for key, value := range configs {
kv := ConfigKV{}
err = db.Take(&kv, "key = ?", key).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Println("Missing config", key, "creating with value", value)
kv.Key = key
kv.Value = value
if err = db.Create(&kv).Error; err != nil {
return err
}
} else {
return err
}
}
}
return nil
}

View File

@@ -1,5 +0,0 @@
package main
// declare global variable
var authConfig ConfigKV

26
main.go
View File

@@ -13,9 +13,12 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
// global config
var config Config
func main() { func main() {
dbAddr := flag.String("database", "./db.sqlite", "Database address") dbAddr := flag.String("database", "./db.sqlite", "Database address")
upstreamsFile := flag.String("upstreams", "./upstreams.yaml", "Upstreams file") configFile := flag.String("config", "./config.yaml", "Config file")
listenAddr := flag.String("addr", ":8888", "Listening address") listenAddr := flag.String("addr", ":8888", "Listening address")
listMode := flag.Bool("list", false, "List all upstream") listMode := flag.Bool("list", false, "List all upstream")
noauth := flag.Bool("noauth", false, "Do not check incoming authorization header") noauth := flag.Bool("noauth", false, "Do not check incoming authorization header")
@@ -33,21 +36,15 @@ func main() {
} }
// load all upstreams // load all upstreams
upstreams := readUpstreams(*upstreamsFile) config = readConfig(*configFile)
log.Println("Load upstreams number:", len(upstreams)) log.Println("Load upstreams number:", len(config.Upstreams))
err = initconfig(db)
if err != nil {
log.Fatal(err)
}
db.AutoMigrate(&OPENAI_UPSTREAM{})
db.AutoMigrate(&Record{}) db.AutoMigrate(&Record{})
log.Println("Auto migrate database done") log.Println("Auto migrate database done")
if *listMode { if *listMode {
fmt.Println("SK\tEndpoint") fmt.Println("SK\tEndpoint")
for _, upstream := range upstreams { for _, upstream := range config.Upstreams {
fmt.Println(upstream.SK, upstream.Endpoint) fmt.Println(upstream.SK, upstream.Endpoint)
} }
return return
@@ -82,9 +79,6 @@ func main() {
ctx.AbortWithStatus(200) ctx.AbortWithStatus(200)
}) })
// get authorization config from db
db.Take(&authConfig, "key = ?", "authorization")
engine.POST("/v1/*any", func(c *gin.Context) { engine.POST("/v1/*any", func(c *gin.Context) {
record := Record{ record := Record{
IP: c.ClientIP(), IP: c.ClientIP(),
@@ -105,15 +99,15 @@ func main() {
} }
} }
for index, upstream := range upstreams { for index, upstream := range config.Upstreams {
if upstream.Endpoint == "" || upstream.SK == "" { if upstream.Endpoint == "" || upstream.SK == "" {
c.AbortWithError(500, fmt.Errorf("invaild upstream '%s' '%s'", upstream.SK, upstream.Endpoint)) c.AbortWithError(500, fmt.Errorf("invaild upstream '%s' '%s'", upstream.SK, upstream.Endpoint))
continue continue
} }
shouldResponse := index == len(upstreams)-1 shouldResponse := index == len(config.Upstreams)-1
if len(upstreams) == 1 { if len(config.Upstreams) == 1 {
upstream.Timeout = 120 upstream.Timeout = 120
} }

View File

@@ -7,15 +7,18 @@ import (
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
// one openai upstream contain a pair of key and endpoint type Config struct {
Authorization string `yaml:"authorization"`
Upstreams []OPENAI_UPSTREAM `yaml:"upstreams"`
}
type OPENAI_UPSTREAM struct { type OPENAI_UPSTREAM struct {
SK string `yaml:"sk"` SK string `yaml:"sk"`
Endpoint string `yaml:"endpoint"` Endpoint string `yaml:"endpoint"`
Timeout int64 `yaml:"timeout"` Timeout int64 `yaml:"timeout"`
} }
func readUpstreams(filepath string) []OPENAI_UPSTREAM { func readConfig(filepath string) Config {
var upstreams []OPENAI_UPSTREAM var config Config
// read yaml file // read yaml file
data, err := os.ReadFile(filepath) data, err := os.ReadFile(filepath)
@@ -24,10 +27,10 @@ func readUpstreams(filepath string) []OPENAI_UPSTREAM {
} }
// Unmarshal the YAML into the upstreams slice // Unmarshal the YAML into the upstreams slice
err = yaml.Unmarshal(data, &upstreams) err = yaml.Unmarshal(data, &config)
if err != nil { if err != nil {
log.Fatalf("Error unmarshaling YAML: %s", err) log.Fatalf("Error unmarshaling YAML: %s", err)
} }
return upstreams return config
} }