From de3bea06a7966ae8c18421837e45f3615f4ec77b Mon Sep 17 00:00:00 2001 From: heimoshuiyu Date: Mon, 27 Nov 2023 18:07:37 +0800 Subject: [PATCH] convert config to yaml file --- .dockerignore | 2 +- .gitignore | 2 +- README.md | 12 +++++++----- auth.go | 2 +- config.go | 49 ------------------------------------------------- global.go | 5 ----- main.go | 26 ++++++++++---------------- structure.go | 13 ++++++++----- 8 files changed, 28 insertions(+), 83 deletions(-) delete mode 100644 config.go delete mode 100644 global.go diff --git a/.dockerignore b/.dockerignore index caa5548..62431d5 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,3 +1,3 @@ openai-api-route db.sqlite -/upstreams.yaml +/config.yaml diff --git a/.gitignore b/.gitignore index caa5548..62431d5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,3 @@ openai-api-route db.sqlite -/upstreams.yaml +/config.yaml diff --git a/README.md b/README.md index 12acdb7..535b98d 100644 --- a/README.md +++ b/README.md @@ -83,11 +83,13 @@ Usage of ./openai-api-route: 以下是一个 `./upstreams.yaml` 文件配置示例 ```yaml -- sk: "secret_key_1" - endpoint: "https://api.openai.com/v2" -- sk: "secret_key_2" - endpoint: "https://api.openai.com/v1" - timeout: 30 +authorization: woshimima +upstreams: + - sk: "secret_key_1" + endpoint: "https://api.openai.com/v2" + - sk: "secret_key_2" + endpoint: "https://api.openai.com/v1" + timeout: 30 ``` 请注意,程序会根据情况修改 timeout 的值 diff --git a/auth.go b/auth.go index a57b615..f5f01fd 100644 --- a/auth.go +++ b/auth.go @@ -21,7 +21,7 @@ func handleAuth(c *gin.Context) error { authorization = strings.Trim(authorization[len("Bearer"):], " ") log.Println("Received authorization", authorization) - for _, auth := range strings.Split(authConfig.Value, ",") { + for _, auth := range strings.Split(config.Authorization, ",") { if authorization != strings.Trim(auth, " ") { err = errors.New("wrong authorization header") c.AbortWithError(403, err) diff --git a/config.go b/config.go deleted file mode 100644 index 21e10b3..0000000 --- a/config.go +++ /dev/null @@ -1,49 +0,0 @@ -package main - -import ( - "errors" - "log" - - "gorm.io/gorm" -) - -// K-V struct to store program's config -type ConfigKV struct { - gorm.Model - Key string `gorm:"unique"` - Value string -} - -// init db -func initconfig(db *gorm.DB) error { - var err error - - err = db.AutoMigrate(&ConfigKV{}) - if err != nil { - return err - } - - // config list and their default values - configs := make(map[string]string) - configs["authorization"] = "woshimima" - configs["policy"] = "main" - - for key, value := range configs { - kv := ConfigKV{} - err = db.Take(&kv, "key = ?", key).Error - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - log.Println("Missing config", key, "creating with value", value) - kv.Key = key - kv.Value = value - if err = db.Create(&kv).Error; err != nil { - return err - } - } else { - return err - } - } - } - - return nil -} diff --git a/global.go b/global.go deleted file mode 100644 index 356744a..0000000 --- a/global.go +++ /dev/null @@ -1,5 +0,0 @@ -package main - -// declare global variable - -var authConfig ConfigKV diff --git a/main.go b/main.go index 911b7ca..d639580 100644 --- a/main.go +++ b/main.go @@ -13,9 +13,12 @@ import ( "gorm.io/gorm" ) +// global config +var config Config + func main() { dbAddr := flag.String("database", "./db.sqlite", "Database address") - upstreamsFile := flag.String("upstreams", "./upstreams.yaml", "Upstreams file") + configFile := flag.String("config", "./config.yaml", "Config file") listenAddr := flag.String("addr", ":8888", "Listening address") listMode := flag.Bool("list", false, "List all upstream") noauth := flag.Bool("noauth", false, "Do not check incoming authorization header") @@ -33,21 +36,15 @@ func main() { } // load all upstreams - upstreams := readUpstreams(*upstreamsFile) - log.Println("Load upstreams number:", len(upstreams)) + config = readConfig(*configFile) + log.Println("Load upstreams number:", len(config.Upstreams)) - err = initconfig(db) - if err != nil { - log.Fatal(err) - } - - db.AutoMigrate(&OPENAI_UPSTREAM{}) db.AutoMigrate(&Record{}) log.Println("Auto migrate database done") if *listMode { fmt.Println("SK\tEndpoint") - for _, upstream := range upstreams { + for _, upstream := range config.Upstreams { fmt.Println(upstream.SK, upstream.Endpoint) } return @@ -82,9 +79,6 @@ func main() { ctx.AbortWithStatus(200) }) - // get authorization config from db - db.Take(&authConfig, "key = ?", "authorization") - engine.POST("/v1/*any", func(c *gin.Context) { record := Record{ IP: c.ClientIP(), @@ -105,15 +99,15 @@ func main() { } } - for index, upstream := range upstreams { + for index, upstream := range config.Upstreams { if upstream.Endpoint == "" || upstream.SK == "" { c.AbortWithError(500, fmt.Errorf("invaild upstream '%s' '%s'", upstream.SK, upstream.Endpoint)) continue } - shouldResponse := index == len(upstreams)-1 + shouldResponse := index == len(config.Upstreams)-1 - if len(upstreams) == 1 { + if len(config.Upstreams) == 1 { upstream.Timeout = 120 } diff --git a/structure.go b/structure.go index 5ff7b3b..d2ea040 100644 --- a/structure.go +++ b/structure.go @@ -7,15 +7,18 @@ import ( "gopkg.in/yaml.v3" ) -// one openai upstream contain a pair of key and endpoint +type Config struct { + Authorization string `yaml:"authorization"` + Upstreams []OPENAI_UPSTREAM `yaml:"upstreams"` +} type OPENAI_UPSTREAM struct { SK string `yaml:"sk"` Endpoint string `yaml:"endpoint"` Timeout int64 `yaml:"timeout"` } -func readUpstreams(filepath string) []OPENAI_UPSTREAM { - var upstreams []OPENAI_UPSTREAM +func readConfig(filepath string) Config { + var config Config // read yaml file data, err := os.ReadFile(filepath) @@ -24,10 +27,10 @@ func readUpstreams(filepath string) []OPENAI_UPSTREAM { } // Unmarshal the YAML into the upstreams slice - err = yaml.Unmarshal(data, &upstreams) + err = yaml.Unmarshal(data, &config) if err != nil { log.Fatalf("Error unmarshaling YAML: %s", err) } - return upstreams + return config }