Compare commits

...

6 Commits

Author SHA1 Message Date
2bbe98e694 fix duplicated cors headers 2024-01-04 19:03:58 +08:00
9fdbf259c0 truncate requset body if too long 2024-01-04 18:41:26 +08:00
97926087bb async record request 2024-01-04 18:37:01 +08:00
fc5a8d55fa fix: process error content type 2023-12-22 15:08:24 +08:00
b1e3a97aad fix: cors and content-type on error 2023-12-22 14:24:11 +08:00
04a2e4c12d use cors middleware 2023-12-22 13:23:41 +08:00
4 changed files with 70 additions and 24 deletions

View File

@@ -14,7 +14,6 @@ func handleAuth(c *gin.Context) error {
authorization := c.Request.Header.Get("Authorization")
if !strings.HasPrefix(authorization, "Bearer") {
err = errors.New("authorization header should start with 'Bearer'")
c.AbortWithError(403, err)
return err
}
@@ -24,7 +23,6 @@ func handleAuth(c *gin.Context) error {
for _, auth := range strings.Split(config.Authorization, ",") {
if authorization != strings.Trim(auth, " ") {
err = errors.New("wrong authorization header")
c.AbortWithError(403, err)
return err
}
}

22
cors.go Normal file
View File

@@ -0,0 +1,22 @@
package main
import (
"github.com/gin-gonic/gin"
)
// this function is aborded
func corsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// set cors header
header := c.Request.Header
if header.Get("Access-Control-Allow-Origin") == "" {
c.Header("Access-Control-Allow-Origin", "*")
}
if header.Get("Access-Control-Allow-Methods") == "" {
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
}
if header.Get("Access-Control-Allow-Headers") == "" {
c.Header("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
}
}
}

46
main.go
View File

@@ -73,6 +73,9 @@ func main() {
m.SetMetricPath("/v1/metrics")
m.Use(engine)
// CORS middleware
// engine.Use(corsMiddleware())
// error handle middleware
engine.Use(func(c *gin.Context) {
c.Next()
@@ -87,10 +90,10 @@ func main() {
// CORS handler
engine.OPTIONS("/v1/*any", func(ctx *gin.Context) {
header := ctx.Writer.Header()
header.Set("Access-Control-Allow-Origin", "*")
header.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
header.Set("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
// set cros header
ctx.Header("Access-Control-Allow-Origin", "*")
ctx.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
ctx.Header("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
ctx.AbortWithStatus(200)
})
@@ -106,16 +109,21 @@ func main() {
Authorization: c.Request.Header.Get("Authorization"),
UserAgent: c.Request.Header.Get("User-Agent"),
}
defer func() {
if err := recover(); err != nil {
log.Println("Error:", err)
c.AbortWithError(500, fmt.Errorf("%s", err))
}
}()
/*
defer func() {
if err := recover(); err != nil {
log.Println("Error:", err)
c.AbortWithError(500, fmt.Errorf("%s", err))
}
}()
*/
// check authorization header
if !*noauth {
if handleAuth(c) != nil {
err := handleAuth(c)
if err != nil {
c.Header("Content-Type", "application/json")
c.AbortWithError(403, err)
return
}
}
@@ -143,9 +151,19 @@ func main() {
log.Println("Record result:", record.Status, record.Response)
record.ElapsedTime = time.Now().Sub(record.CreatedAt)
if db.Create(&record).Error != nil {
log.Println("Error to save record:", record)
}
// async record request
go func() {
// turncate request if too long
if len(record.Body) > 1024*128 {
log.Println("Warning: Truncate request body")
record.Body = record.Body[:1024*128]
}
if db.Create(&record).Error != nil {
log.Println("Error to save record:", record)
}
}()
if record.Status != 200 {
errMessage := fmt.Sprintf("IP: %s request %s error %d with %s", record.IP, record.Model, record.Status, record.Response)
go sendFeishuMessage(errMessage)

View File

@@ -28,7 +28,6 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
// reverse proxy
remote, err := url.Parse(upstream.Endpoint)
if err != nil {
c.AbortWithError(500, errors.New("can't parse reverse proxy remote URL"))
return err
}
@@ -38,6 +37,7 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
proxy.Director = nil
var inBody []byte
proxy.Rewrite = func(proxyRequest *httputil.ProxyRequest) {
in := proxyRequest.In
ctx, cancel := context.WithCancel(context.Background())
@@ -48,7 +48,7 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
// read request body
inBody, err = io.ReadAll(in.Body)
if err != nil {
c.AbortWithError(502, errors.New("reverse proxy middleware failed to read request body "+err.Error()))
errCtx = errors.New("reverse proxy middleware failed to read request body " + err.Error())
return
}
@@ -80,6 +80,7 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
log.Println("Timeout upstream", upstream.Endpoint)
errCtx = errors.New("timeout")
if shouldResponse {
c.Header("Content-Type", "application/json")
c.AbortWithError(502, errCtx)
}
cancel()
@@ -109,16 +110,22 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
haveResponse = true
record.ResponseTime = time.Now().Sub(record.CreatedAt)
record.Status = r.StatusCode
// handle reverse proxy cors header if upstream do not set that
if r.Header.Get("Access-Control-Allow-Origin") == "" {
c.Header("Access-Control-Allow-Origin", "*")
}
if r.Header.Get("Access-Control-Allow-Methods") == "" {
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
}
if r.Header.Get("Access-Control-Allow-Headers") == "" {
c.Header("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
}
if !shouldResponse && r.StatusCode != 200 {
log.Println("upstream return not 200 and should not response", r.StatusCode)
return errors.New("upstream return not 200 and should not response")
}
r.Header.Del("Access-Control-Allow-Origin")
r.Header.Del("Access-Control-Allow-Methods")
r.Header.Del("Access-Control-Allow-Headers")
r.Header.Set("Access-Control-Allow-Origin", "*")
r.Header.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
r.Header.Set("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
if r.StatusCode != 200 {
body, err := io.ReadAll(r.Body)
@@ -144,6 +151,7 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
// abort to error handle
if shouldResponse {
c.Header("Content-Type", "application/json")
c.AbortWithError(502, err)
}