convert config to yaml file
This commit is contained in:
@@ -1,3 +1,3 @@
|
|||||||
openai-api-route
|
openai-api-route
|
||||||
db.sqlite
|
db.sqlite
|
||||||
/upstreams.yaml
|
/config.yaml
|
||||||
|
|||||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,3 +1,3 @@
|
|||||||
openai-api-route
|
openai-api-route
|
||||||
db.sqlite
|
db.sqlite
|
||||||
/upstreams.yaml
|
/config.yaml
|
||||||
|
|||||||
@@ -83,9 +83,11 @@ Usage of ./openai-api-route:
|
|||||||
以下是一个 `./upstreams.yaml` 文件配置示例
|
以下是一个 `./upstreams.yaml` 文件配置示例
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
- sk: "secret_key_1"
|
authorization: woshimima
|
||||||
|
upstreams:
|
||||||
|
- sk: "secret_key_1"
|
||||||
endpoint: "https://api.openai.com/v2"
|
endpoint: "https://api.openai.com/v2"
|
||||||
- sk: "secret_key_2"
|
- sk: "secret_key_2"
|
||||||
endpoint: "https://api.openai.com/v1"
|
endpoint: "https://api.openai.com/v1"
|
||||||
timeout: 30
|
timeout: 30
|
||||||
```
|
```
|
||||||
|
|||||||
2
auth.go
2
auth.go
@@ -21,7 +21,7 @@ func handleAuth(c *gin.Context) error {
|
|||||||
authorization = strings.Trim(authorization[len("Bearer"):], " ")
|
authorization = strings.Trim(authorization[len("Bearer"):], " ")
|
||||||
log.Println("Received authorization", authorization)
|
log.Println("Received authorization", authorization)
|
||||||
|
|
||||||
for _, auth := range strings.Split(authConfig.Value, ",") {
|
for _, auth := range strings.Split(config.Authorization, ",") {
|
||||||
if authorization != strings.Trim(auth, " ") {
|
if authorization != strings.Trim(auth, " ") {
|
||||||
err = errors.New("wrong authorization header")
|
err = errors.New("wrong authorization header")
|
||||||
c.AbortWithError(403, err)
|
c.AbortWithError(403, err)
|
||||||
|
|||||||
49
config.go
49
config.go
@@ -1,49 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"log"
|
|
||||||
|
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
|
||||||
|
|
||||||
// K-V struct to store program's config
|
|
||||||
type ConfigKV struct {
|
|
||||||
gorm.Model
|
|
||||||
Key string `gorm:"unique"`
|
|
||||||
Value string
|
|
||||||
}
|
|
||||||
|
|
||||||
// init db
|
|
||||||
func initconfig(db *gorm.DB) error {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
err = db.AutoMigrate(&ConfigKV{})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// config list and their default values
|
|
||||||
configs := make(map[string]string)
|
|
||||||
configs["authorization"] = "woshimima"
|
|
||||||
configs["policy"] = "main"
|
|
||||||
|
|
||||||
for key, value := range configs {
|
|
||||||
kv := ConfigKV{}
|
|
||||||
err = db.Take(&kv, "key = ?", key).Error
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
||||||
log.Println("Missing config", key, "creating with value", value)
|
|
||||||
kv.Key = key
|
|
||||||
kv.Value = value
|
|
||||||
if err = db.Create(&kv).Error; err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
// declare global variable
|
|
||||||
|
|
||||||
var authConfig ConfigKV
|
|
||||||
26
main.go
26
main.go
@@ -13,9 +13,12 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// global config
|
||||||
|
var config Config
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
dbAddr := flag.String("database", "./db.sqlite", "Database address")
|
dbAddr := flag.String("database", "./db.sqlite", "Database address")
|
||||||
upstreamsFile := flag.String("upstreams", "./upstreams.yaml", "Upstreams file")
|
configFile := flag.String("config", "./config.yaml", "Config file")
|
||||||
listenAddr := flag.String("addr", ":8888", "Listening address")
|
listenAddr := flag.String("addr", ":8888", "Listening address")
|
||||||
listMode := flag.Bool("list", false, "List all upstream")
|
listMode := flag.Bool("list", false, "List all upstream")
|
||||||
noauth := flag.Bool("noauth", false, "Do not check incoming authorization header")
|
noauth := flag.Bool("noauth", false, "Do not check incoming authorization header")
|
||||||
@@ -33,21 +36,15 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// load all upstreams
|
// load all upstreams
|
||||||
upstreams := readUpstreams(*upstreamsFile)
|
config = readConfig(*configFile)
|
||||||
log.Println("Load upstreams number:", len(upstreams))
|
log.Println("Load upstreams number:", len(config.Upstreams))
|
||||||
|
|
||||||
err = initconfig(db)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
db.AutoMigrate(&OPENAI_UPSTREAM{})
|
|
||||||
db.AutoMigrate(&Record{})
|
db.AutoMigrate(&Record{})
|
||||||
log.Println("Auto migrate database done")
|
log.Println("Auto migrate database done")
|
||||||
|
|
||||||
if *listMode {
|
if *listMode {
|
||||||
fmt.Println("SK\tEndpoint")
|
fmt.Println("SK\tEndpoint")
|
||||||
for _, upstream := range upstreams {
|
for _, upstream := range config.Upstreams {
|
||||||
fmt.Println(upstream.SK, upstream.Endpoint)
|
fmt.Println(upstream.SK, upstream.Endpoint)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@@ -82,9 +79,6 @@ func main() {
|
|||||||
ctx.AbortWithStatus(200)
|
ctx.AbortWithStatus(200)
|
||||||
})
|
})
|
||||||
|
|
||||||
// get authorization config from db
|
|
||||||
db.Take(&authConfig, "key = ?", "authorization")
|
|
||||||
|
|
||||||
engine.POST("/v1/*any", func(c *gin.Context) {
|
engine.POST("/v1/*any", func(c *gin.Context) {
|
||||||
record := Record{
|
record := Record{
|
||||||
IP: c.ClientIP(),
|
IP: c.ClientIP(),
|
||||||
@@ -105,15 +99,15 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for index, upstream := range upstreams {
|
for index, upstream := range config.Upstreams {
|
||||||
if upstream.Endpoint == "" || upstream.SK == "" {
|
if upstream.Endpoint == "" || upstream.SK == "" {
|
||||||
c.AbortWithError(500, fmt.Errorf("invaild upstream '%s' '%s'", upstream.SK, upstream.Endpoint))
|
c.AbortWithError(500, fmt.Errorf("invaild upstream '%s' '%s'", upstream.SK, upstream.Endpoint))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
shouldResponse := index == len(upstreams)-1
|
shouldResponse := index == len(config.Upstreams)-1
|
||||||
|
|
||||||
if len(upstreams) == 1 {
|
if len(config.Upstreams) == 1 {
|
||||||
upstream.Timeout = 120
|
upstream.Timeout = 120
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
13
structure.go
13
structure.go
@@ -7,15 +7,18 @@ import (
|
|||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
// one openai upstream contain a pair of key and endpoint
|
type Config struct {
|
||||||
|
Authorization string `yaml:"authorization"`
|
||||||
|
Upstreams []OPENAI_UPSTREAM `yaml:"upstreams"`
|
||||||
|
}
|
||||||
type OPENAI_UPSTREAM struct {
|
type OPENAI_UPSTREAM struct {
|
||||||
SK string `yaml:"sk"`
|
SK string `yaml:"sk"`
|
||||||
Endpoint string `yaml:"endpoint"`
|
Endpoint string `yaml:"endpoint"`
|
||||||
Timeout int64 `yaml:"timeout"`
|
Timeout int64 `yaml:"timeout"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func readUpstreams(filepath string) []OPENAI_UPSTREAM {
|
func readConfig(filepath string) Config {
|
||||||
var upstreams []OPENAI_UPSTREAM
|
var config Config
|
||||||
|
|
||||||
// read yaml file
|
// read yaml file
|
||||||
data, err := os.ReadFile(filepath)
|
data, err := os.ReadFile(filepath)
|
||||||
@@ -24,10 +27,10 @@ func readUpstreams(filepath string) []OPENAI_UPSTREAM {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Unmarshal the YAML into the upstreams slice
|
// Unmarshal the YAML into the upstreams slice
|
||||||
err = yaml.Unmarshal(data, &upstreams)
|
err = yaml.Unmarshal(data, &config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Error unmarshaling YAML: %s", err)
|
log.Fatalf("Error unmarshaling YAML: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return upstreams
|
return config
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user