Files
openai-api-route/main.go

137 lines
3.3 KiB
Go

package main
import (
"flag"
"fmt"
"log"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/penglongli/gin-metrics/ginmetrics"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// config is the process-wide configuration. main assigns it once from
// readConfig, and the HTTP handlers read it (e.g. config.Upstreams).
var (
	config Config
)
// main parses the command-line flags, opens the SQLite record database,
// loads the upstream configuration and serves the proxy routes under /v1/.
//
// Flags:
//
//	-database  SQLite database path (default ./db.sqlite)
//	-config    config file path (default ./config.yaml)
//	-addr      listening address (default :8888)
//	-list      print the configured upstreams and exit
//	-noauth    skip handleAuth on incoming requests
//
// Each POST /v1/* request is passed to processRequest for the configured
// upstreams in order until one succeeds; only the last upstream is called
// with shouldResponse=true (presumably so that only its failure is written to
// the client — confirm in processRequest). Every handled request is saved as
// a Record, and a non-200 result is reported via sendFeishuMessage and
// sendMatrixMessage.
func main() {
	dbAddr := flag.String("database", "./db.sqlite", "Database address")
	configFile := flag.String("config", "./config.yaml", "Config file")
	listenAddr := flag.String("addr", ":8888", "Listening address")
	listMode := flag.Bool("list", false, "List all upstream")
	noauth := flag.Bool("noauth", false, "Do not check incoming authorization header")
	flag.Parse()

	log.Println("Service starting")

	// connect to database
	db, err := gorm.Open(sqlite.Open(*dbAddr), &gorm.Config{
		PrepareStmt:            true,
		SkipDefaultTransaction: true,
	})
	if err != nil {
		log.Fatalf("Failed to connect to database: %v", err)
	}

	// load all upstreams
	config = readConfig(*configFile)
	log.Println("Load upstreams number:", len(config.Upstreams))

	if err := db.AutoMigrate(&Record{}); err != nil {
		// Keep serving: proxying does not depend on the schema, only saving
		// records does, and each failed save is logged below.
		log.Println("Auto migrate database failed:", err)
	} else {
		log.Println("Auto migrate database done")
	}

	if *listMode {
		fmt.Println("SK\tEndpoint")
		for _, upstream := range config.Upstreams {
			// Tab-separated to match the header line.
			fmt.Printf("%s\t%s\n", upstream.SK, upstream.Endpoint)
		}
		return
	}

	// init gin
	engine := gin.Default()

	// metrics
	m := ginmetrics.GetMonitor()
	m.SetMetricPath("/v1/metrics")
	m.Use(engine)

	// error handle middleware: reports errors collected in c.Errors as JSON.
	engine.Use(func(c *gin.Context) {
		c.Next()
		if len(c.Errors) == 0 {
			return
		}
		errText := strings.Join(c.Errors.Errors(), "\n")
		// If a handler already wrote a response (e.g. a later upstream
		// succeeded after an earlier one was rejected), appending JSON would
		// corrupt that body, so only log the errors.
		if c.Writer.Written() {
			log.Println("Errors after response was written:", errText)
			return
		}
		// -1 keeps the status already set (e.g. by AbortWithError).
		c.JSON(-1, gin.H{
			"error": errText,
		})
	})

	// CORS handler
	engine.OPTIONS("/v1/*any", func(ctx *gin.Context) {
		header := ctx.Writer.Header()
		header.Set("Access-Control-Allow-Origin", "*")
		header.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
		header.Set("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
		ctx.AbortWithStatus(200)
	})

	engine.POST("/v1/*any", func(c *gin.Context) {
		record := Record{
			IP:            c.ClientIP(),
			CreatedAt:     time.Now(),
			Authorization: c.Request.Header.Get("Authorization"),
		}
		defer func() {
			if r := recover(); r != nil {
				log.Println("Error:", r)
				// %v formats any panic value; %s mangles non-string values.
				c.AbortWithError(500, fmt.Errorf("%v", r))
			}
		}()

		// check authorization header
		if !*noauth {
			if handleAuth(c) != nil {
				return
			}
		}

		if len(config.Upstreams) == 0 {
			c.AbortWithError(500, fmt.Errorf("no upstream configured"))
		}

		for index, upstream := range config.Upstreams {
			if upstream.Endpoint == "" || upstream.SK == "" {
				// The error text is sent to the client by the error
				// middleware, so never include the upstream SK in it.
				log.Printf("Invalid upstream #%d (endpoint %q): empty endpoint or SK", index, upstream.Endpoint)
				c.AbortWithError(500, fmt.Errorf("invalid upstream #%d: empty endpoint or SK", index))
				continue
			}
			shouldResponse := index == len(config.Upstreams)-1
			if len(config.Upstreams) == 1 {
				// No fallback upstream exists; use a fixed Timeout of 120.
				// upstream is the range copy, so the shared config is not
				// modified (assuming Upstreams holds values, as &upstream
				// below suggests).
				upstream.Timeout = 120
			}
			// err must be local: this handler runs concurrently for many
			// requests, and assigning to main's err was a data race.
			err := processRequest(c, &upstream, &record, shouldResponse)
			if err != nil {
				log.Println("Error from upstream", upstream.Endpoint, "should retry", err)
				continue
			}
			break
		}

		log.Println("Record result:", record.Status, record.Response)
		record.ElapsedTime = time.Since(record.CreatedAt)
		if err := db.Create(&record).Error; err != nil {
			log.Println("Error to save record:", record, err)
		}
		if record.Status != 200 {
			errMessage := fmt.Sprintf("IP: %s request %s error %d with %s", record.IP, record.Model, record.Status, record.Response)
			go sendFeishuMessage(errMessage)
			go sendMatrixMessage(errMessage)
		}
	})

	if err := engine.Run(*listenAddr); err != nil {
		log.Fatalf("Failed to run server: %v", err)
	}
}