From 7f0e0b8a9d20e2d090ca596568844120457427ae Mon Sep 17 00:00:00 2001 From: heimoshuiyu Date: Mon, 10 Jul 2023 18:22:21 +0800 Subject: [PATCH] init --- .gitignore | 2 + Makefile | 2 + auth.go | 30 +++++ config.go | 49 ++++++++ feishu.go | 40 +++++++ global.go | 5 + go.mod | 39 +++++++ go.sum | 95 +++++++++++++++ main.go | 323 +++++++++++++++++++++++++++++++++++++++++++++++++++ matrix.go | 38 ++++++ record.go | 51 ++++++++ structure.go | 17 +++ 12 files changed, 691 insertions(+) create mode 100644 .gitignore create mode 100644 Makefile create mode 100644 auth.go create mode 100644 config.go create mode 100644 feishu.go create mode 100644 global.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 main.go create mode 100644 matrix.go create mode 100644 record.go create mode 100644 structure.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a838100 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +openai-api-route +db.sqlite diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..19f2943 --- /dev/null +++ b/Makefile @@ -0,0 +1,2 @@ +linux: + go build -v -ldflags '-linkmode=external -extldflags=-static' -tags sqlite_omit_load_extension,netgo \ No newline at end of file diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..a21c50f --- /dev/null +++ b/auth.go @@ -0,0 +1,30 @@ +package main + +import ( + "errors" + "log" + "strings" + + "github.com/gin-gonic/gin" +) + +func handleAuth(c *gin.Context) error { + var err error + + authorization := c.Request.Header.Get("Authorization") + if !strings.HasPrefix(authorization, "Bearer") { + err = errors.New("authorization header should start with 'Bearer'") + c.AbortWithError(403, err) + return err + } + + authorization = strings.Trim(authorization[len("Bearer"):], " ") + log.Println("Received authorization", authorization) + if authorization != authConfig.Value { + err = errors.New("wrong authorization header") + c.AbortWithError(403, err) + return err + } + + return nil +} diff --git a/config.go b/config.go new file mode 100644 index 0000000..81b6438 --- /dev/null +++ b/config.go @@ -0,0 +1,49 @@ +package main + +import ( + "errors" + "log" + + "gorm.io/gorm" +) + +// K-V struct to store program's config +type ConfigKV struct { + gorm.Model + Key string `gorm:"unique"` + Value string +} + +// init db +func initconfig(db *gorm.DB) error { + var err error + + err = db.AutoMigrate(&ConfigKV{}) + if err != nil { + return err + } + + // config list and their default values + configs := make(map[string]string) + configs["authorization"] = "woshimima" + configs["policy"] = "random" + + for key, value := range configs { + kv := ConfigKV{} + err = db.Take(&kv, "key = ?", key).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + log.Println("Missing config", key, "creating with value", value) + kv.Key = key + kv.Value = value + if err = db.Create(&kv).Error; err != nil { + return err + } + } else { + return err + } + } + } + + return nil +} diff --git a/feishu.go b/feishu.go new file mode 100644 index 0000000..8d16742 --- /dev/null +++ b/feishu.go @@ -0,0 +1,40 @@ +package main + +import ( + "bytes" + "encoding/json" + "log" + "net/http" + "os" +) + +type FeishuMessage struct { + MsgType string `json:"msg_type"` + Content FeishuMessageContent `json:"content"` +} +type FeishuMessageContent struct { + Text string `json:"text"` +} + +func sendFeishuMessage(content string) error { + messageBytes, err := json.Marshal(&FeishuMessage{ + MsgType: "text", + Content: FeishuMessageContent{ + Text: content, + }, + }) + if err != nil { + log.Println("Failed to send feishu message", err) + } + FEISHU_WEBHOOK := os.Getenv("FEISHU_WEBOOK") + if FEISHU_WEBHOOK == "" { + log.Println("FEISHU_WEBOOK environment not set") + return nil + } + http.Post( + FEISHU_WEBHOOK, + "application/json", + bytes.NewReader(messageBytes), + ) + return nil +} diff --git a/global.go b/global.go new file mode 100644 index 0000000..356744a --- /dev/null +++ b/global.go @@ -0,0 +1,5 @@ +package main + +// declare global variable + +var authConfig ConfigKV diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..10322b5 --- /dev/null +++ b/go.mod @@ -0,0 +1,39 @@ +module openai-api-route + +go 1.20 + +require ( + github.com/gin-gonic/gin v1.9.1 + gorm.io/driver/sqlite v1.5.2 + gorm.io/gorm v1.25.2 +) + +require ( + github.com/bytedance/sonic v1.9.1 // indirect + github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/gabriel-vasile/mimetype v1.4.2 // indirect + github.com/gin-contrib/sse v0.1.0 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.14.0 // indirect + github.com/goccy/go-json v0.10.2 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/cpuid/v2 v2.2.4 // indirect + github.com/leodido/go-urn v1.2.4 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/mattn/go-sqlite3 v1.14.17 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/ugorji/go/codec v1.2.11 // indirect + golang.org/x/arch v0.3.0 // indirect + golang.org/x/crypto v0.9.0 // indirect + golang.org/x/net v0.10.0 // indirect + golang.org/x/sys v0.8.0 // indirect + golang.org/x/text v0.9.0 // indirect + google.golang.org/protobuf v1.30.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..0d393af --- /dev/null +++ b/go.sum @@ -0,0 +1,95 @@ +github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= +github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= +github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= +github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= +github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= +github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= +github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= +github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= +github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= +github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= +github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= +github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= +github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= +github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= +github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= +github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= +golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= +golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= +golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= +google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/sqlite v1.5.2 h1:TpQ+/dqCY4uCigCFyrfnrJnrW9zjpelWVoEVNy5qJkc= +gorm.io/driver/sqlite v1.5.2/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4= +gorm.io/gorm v1.25.2 h1:gs1o6Vsa+oVKG/a9ElL3XgyGfghFfkKA2SInQaCyMho= +gorm.io/gorm v1.25.2/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/main.go b/main.go new file mode 100644 index 0000000..c1df708 --- /dev/null +++ b/main.go @@ -0,0 +1,323 @@ +package main + +import ( + "bytes" + "errors" + "flag" + "fmt" + "io" + "log" + "net/http" + "net/http/httputil" + "net/url" + "strconv" + "strings" + "time" + + "github.com/gin-gonic/gin" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +func main() { + dbAddr := flag.String("database", "./db.sqlite", "Database address") + listenAddr := flag.String("addr", ":8888", "Listening address") + addMode := flag.Bool("add", false, "Add an OpenAI upstream") + listMode := flag.Bool("list", false, "List all upstream") + sk := flag.String("sk", "", "OpenAI API key (sk-xxxxx)") + noauth := flag.Bool("noauth", false, "Do not check incoming authorization header") + endpoint := flag.String("endpoint", "https://api.openai.com/v1", "OpenAI API base") + flag.Parse() + + log.Println("Service starting") + + // connect to database + db, err := gorm.Open(sqlite.Open(*dbAddr), &gorm.Config{ + PrepareStmt: true, + SkipDefaultTransaction: true, + }) + if err != nil { + log.Fatal("Failed to connect to database") + } + + err = initconfig(db) + if err != nil { + log.Fatal(err) + } + + db.AutoMigrate(&OPENAI_UPSTREAM{}) + db.AutoMigrate(&UserMessage{}) + log.Println("Auto migrate database done") + + if *addMode { + if *sk == "" { + log.Fatal("Missing --sk flag") + } + newUpstream := OPENAI_UPSTREAM{} + newUpstream.SK = *sk + newUpstream.Endpoint = *endpoint + err = db.Create(&newUpstream).Error + if err != nil { + log.Fatal("Can not add upstream", err) + } + log.Println("Successuflly add upstream", *sk, *endpoint) + return + } + + if *listMode { + result := make([]OPENAI_UPSTREAM, 0) + db.Find(&result) + fmt.Println("SK\tEndpoint\tSuccess\tFailed\tLast Success Time") + for _, upstream := range result { + fmt.Println(upstream.SK, upstream.Endpoint, upstream.SuccessCount, upstream.FailedCount, upstream.LastCallSuccessTime) + } + return + } + + // init gin + engine := gin.Default() + + // error handle middleware + engine.Use(func(c *gin.Context) { + c.Next() + if len(c.Errors) == 0 { + return + } + errText := strings.Join(c.Errors.Errors(), "\n") + c.JSON(-1, gin.H{ + "error": errText, + }) + }) + + // get authorization config from db + db.Take(&authConfig, "key = ?", "authorization") + + engine.POST("/v1/*any", func(c *gin.Context) { + // check authorization header + if !*noauth { + if handleAuth(c) != nil { + return + } + } + + // get load balance policy + policy := ConfigKV{Value: "main"} + db.Take(&policy, "key = ?", "policy") + log.Println("policy is", policy.Value) + + upstream := OPENAI_UPSTREAM{} + + // choose openai upstream + switch policy.Value { + case "main": + db.Order("failed_count, success_count desc").First(&upstream) + case "random": + // randomly select one upstream + db.Order("random()").Take(&upstream) + case "random_available": + // randomly select one non-failed upstream + db.Where("failed_count = ?", 0).Order("random()").Take(&upstream) + case "round_robin": + // iterates each upstream + db.Order("last_call_success_time").First(&upstream) + case "round_robin_available": + db.Where("failed_count = ?", 0).Order("last_call_success_time").First(&upstream) + default: + c.AbortWithError(500, fmt.Errorf("unknown load balance policy '%s'", policy.Value)) + } + + // do check + log.Println("upstream is", upstream.SK, upstream.Endpoint) + if upstream.Endpoint == "" || upstream.SK == "" { + c.AbortWithError(500, fmt.Errorf("invaild upstream from '%s' policy", policy.Value)) + return + } + + // reverse proxy + remote, err := url.Parse(upstream.Endpoint) + if err != nil { + c.AbortWithError(500, errors.New("can't parse reverse proxy remote URL")) + return + } + proxy := httputil.NewSingleHostReverseProxy(remote) + proxy.Director = nil + proxy.Rewrite = func(proxyRequest *httputil.ProxyRequest) { + in := proxyRequest.In + out := proxyRequest.Out + + // read request body + body, err := io.ReadAll(in.Body) + if err != nil { + c.AbortWithError(502, errors.New("reverse proxy middleware failed to read request body "+err.Error())) + return + } + + // record chat message from user + go recordUserMessage(c, db, body) + + out.Body = io.NopCloser(bytes.NewReader(body)) + + out.Host = remote.Host + out.URL.Scheme = remote.Scheme + out.URL.Host = remote.Host + out.URL.Path = in.URL.Path + out.Header = http.Header{} + out.Header.Set("Host", remote.Host) + out.Header.Set("Authorization", "Bearer "+upstream.SK) + out.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + } + proxy.ModifyResponse = func(r *http.Response) error { + if r.StatusCode != 200 { + body, err := io.ReadAll(r.Body) + if err != nil { + return errors.New("failed to read response from upstream " + err.Error()) + } + return fmt.Errorf("upstream return '%s' with '%s'", r.Status, string(body)) + } + // count success + go db.Model(&upstream).Updates(map[string]interface{}{ + "success_count": gorm.Expr("success_count + ?", 1), + "last_call_success_time": time.Now(), + }) + return nil + } + proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { + log.Println("Error", err, upstream.SK, upstream.Endpoint) + + // abort to error handle + c.AbortWithError(502, err) + + // count failed + if err.Error() != "context canceled" { + go db.Model(&upstream).Update("failed_count", gorm.Expr("failed_count + ?", 1)) + } + + // send notification + upstreams := []OPENAI_UPSTREAM{} + db.Find(&upstreams) + upstreamDescriptions := make([]string, 0) + for _, upstream := range upstreams { + upstreamDescriptions = append(upstreamDescriptions, fmt.Sprintf("ID: %d, %s: %s 成功次数: %d, 失败次数: %d, 最后成功调用: %s", + upstream.ID, upstream.SK, upstream.Endpoint, upstream.SuccessCount, upstream.FailedCount, upstream.LastCallSuccessTime, + )) + } + content := fmt.Sprintf("[%s] OpenAI 转发出错 ID: %d... 密钥: [%s] 上游: [%s] 错误: %s\n---\n%s", + c.ClientIP(), + upstream.ID, upstream.SK[:10], upstream.Endpoint, err.Error(), + strings.Join(upstreamDescriptions, "\n"), + ) + go sendMatrixMessage(content) + if err.Error() != "context canceled" { + go sendFeishuMessage(content) + } + + log.Println("response is", r.Response) + } + proxy.ServeHTTP(c.Writer, c.Request) + }) + + // --------------------------------- + // admin APIs + engine.POST("/admin/login", func(c *gin.Context) { + // check authorization headers + if handleAuth(c) != nil { + return + } + c.JSON(200, gin.H{ + "message": "success", + }) + }) + engine.GET("/admin/upstreams", func(c *gin.Context) { + // check authorization headers + if handleAuth(c) != nil { + return + } + upstreams := make([]OPENAI_UPSTREAM, 0) + db.Find(&upstreams) + c.JSON(200, upstreams) + }) + engine.POST("/admin/upstreams", func(c *gin.Context) { + // check authorization headers + if handleAuth(c) != nil { + return + } + newUpstream := OPENAI_UPSTREAM{} + err := c.BindJSON(&newUpstream) + if err != nil { + c.AbortWithError(502, errors.New("can't parse OPENAI_UPSTREAM object")) + return + } + if newUpstream.SK == "" || newUpstream.Endpoint == "" { + c.AbortWithError(403, errors.New("can't create new OPENAI_UPSTREAM with empty sk or endpoint")) + return + } + log.Println("Saveing new OPENAI_UPSTREAM", newUpstream) + err = db.Create(&newUpstream).Error + if err != nil { + c.AbortWithError(403, err) + return + } + }) + engine.DELETE("/admin/upstreams/:id", func(ctx *gin.Context) { + // check authorization headers + if handleAuth(ctx) != nil { + return + } + id, err := strconv.Atoi(ctx.Param("id")) + if err != nil { + ctx.AbortWithError(502, err) + return + } + upstream := OPENAI_UPSTREAM{} + upstream.ID = uint(id) + db.Delete(&upstream) + ctx.JSON(200, gin.H{ + "message": "success", + }) + }) + engine.PUT("/admin/upstreams/:id", func(c *gin.Context) { + // check authorization headers + if handleAuth(c) != nil { + return + } + upstream := OPENAI_UPSTREAM{} + err := c.BindJSON(&upstream) + if err != nil { + c.AbortWithError(502, errors.New("can't parse OPENAI_UPSTREAM object")) + return + } + if upstream.SK == "" || upstream.Endpoint == "" { + c.AbortWithError(403, errors.New("can't create new OPENAI_UPSTREAM with empty sk or endpoint")) + return + } + id, err := strconv.Atoi(c.Param("id")) + if err != nil { + c.AbortWithError(502, err) + return + } + upstream.ID = uint(id) + log.Println("Saveing new OPENAI_UPSTREAM", upstream) + err = db.Create(&upstream).Error + if err != nil { + c.AbortWithError(403, err) + return + } + c.JSON(200, gin.H{ + "message": "success", + }) + }) + engine.GET("/admin/user_messages", func(c *gin.Context) { + // check authorization headers + if handleAuth(c) != nil { + return + } + userMessages := []UserMessage{} + err := db.Order("id desc").Limit(100).Find(&userMessages).Error + if err != nil { + c.AbortWithError(502, err) + return + } + c.JSON(200, userMessages) + }) + engine.Run(*listenAddr) +} diff --git a/matrix.go b/matrix.go new file mode 100644 index 0000000..4d62611 --- /dev/null +++ b/matrix.go @@ -0,0 +1,38 @@ +package main + +import ( + "bytes" + "encoding/json" + "log" + "net/http" + "os" +) + +type MatrixMessage struct { + Message string `json:"message"` + Body string `json:"body"` +} + +func sendMatrixMessage(content string) error { + messageBytes, marshalErr := json.Marshal(&MatrixMessage{ + Message: "m.text", + Body: content, + }) + if marshalErr != nil { + log.Println("Failed to send matrix message", marshalErr) + return marshalErr + } + + MATRIX_API := os.Getenv("MATRIX_API") + if MATRIX_API == "" { + log.Println("MATRIX_API envitonment not set") + return nil + } + + http.Post( + MATRIX_API, + "application/json", + bytes.NewReader(messageBytes), + ) + return nil +} diff --git a/record.go b/record.go new file mode 100644 index 0000000..08b56e4 --- /dev/null +++ b/record.go @@ -0,0 +1,51 @@ +package main + +import ( + "encoding/json" + "log" + "strings" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +type UserMessage struct { + gorm.Model + ModelName string + Content string +} + +// sturcture to parse request +type ChatRequest struct { + Model string `json:"model"` + Messages []Message `json:"messages"` +} +type Message struct { + Content string `json:"content"` +} + +func recordUserMessage(c *gin.Context, db *gorm.DB, body []byte) { + bodyJson := ChatRequest{} + err := json.Unmarshal(body, &bodyJson) + if err != nil { + c.AbortWithError(502, err) + return + } + model := bodyJson.Model + if !strings.HasPrefix(model, "gpt-") { + return + } + // get message content + if len(bodyJson.Messages) == 0 { + return + } + content := bodyJson.Messages[len(bodyJson.Messages)-1].Content + + log.Println("Record user message", model, content) + + userMessage := UserMessage{ + ModelName: model, + Content: content, + } + db.Create(&userMessage) +} diff --git a/structure.go b/structure.go new file mode 100644 index 0000000..ca29e9b --- /dev/null +++ b/structure.go @@ -0,0 +1,17 @@ +package main + +import ( + "time" + + "gorm.io/gorm" +) + +// one openai upstream contain a pair of key and endpoint +type OPENAI_UPSTREAM struct { + gorm.Model + SK string `gorm:"index:idx_sk_endpoint,unique"` // key + Endpoint string `gorm:"index:idx_sk_endpoint,unique"` // endpoint + SuccessCount int64 + FailedCount int64 + LastCallSuccessTime time.Time +}