Compare commits

...

6 Commits

Author SHA1 Message Date
2bbe98e694 fix duplicated cors headers 2024-01-04 19:03:58 +08:00
9fdbf259c0 truncate requset body if too long 2024-01-04 18:41:26 +08:00
97926087bb async record request 2024-01-04 18:37:01 +08:00
fc5a8d55fa fix: process error content type 2023-12-22 15:08:24 +08:00
b1e3a97aad fix: cors and content-type on error 2023-12-22 14:24:11 +08:00
04a2e4c12d use cors middleware 2023-12-22 13:23:41 +08:00
4 changed files with 70 additions and 24 deletions

View File

@@ -14,7 +14,6 @@ func handleAuth(c *gin.Context) error {
authorization := c.Request.Header.Get("Authorization") authorization := c.Request.Header.Get("Authorization")
if !strings.HasPrefix(authorization, "Bearer") { if !strings.HasPrefix(authorization, "Bearer") {
err = errors.New("authorization header should start with 'Bearer'") err = errors.New("authorization header should start with 'Bearer'")
c.AbortWithError(403, err)
return err return err
} }
@@ -24,7 +23,6 @@ func handleAuth(c *gin.Context) error {
for _, auth := range strings.Split(config.Authorization, ",") { for _, auth := range strings.Split(config.Authorization, ",") {
if authorization != strings.Trim(auth, " ") { if authorization != strings.Trim(auth, " ") {
err = errors.New("wrong authorization header") err = errors.New("wrong authorization header")
c.AbortWithError(403, err)
return err return err
} }
} }

22
cors.go Normal file
View File

@@ -0,0 +1,22 @@
package main
import (
"github.com/gin-gonic/gin"
)
// this function is aborded
func corsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// set cors header
header := c.Request.Header
if header.Get("Access-Control-Allow-Origin") == "" {
c.Header("Access-Control-Allow-Origin", "*")
}
if header.Get("Access-Control-Allow-Methods") == "" {
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
}
if header.Get("Access-Control-Allow-Headers") == "" {
c.Header("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
}
}
}

46
main.go
View File

@@ -73,6 +73,9 @@ func main() {
m.SetMetricPath("/v1/metrics") m.SetMetricPath("/v1/metrics")
m.Use(engine) m.Use(engine)
// CORS middleware
// engine.Use(corsMiddleware())
// error handle middleware // error handle middleware
engine.Use(func(c *gin.Context) { engine.Use(func(c *gin.Context) {
c.Next() c.Next()
@@ -87,10 +90,10 @@ func main() {
// CORS handler // CORS handler
engine.OPTIONS("/v1/*any", func(ctx *gin.Context) { engine.OPTIONS("/v1/*any", func(ctx *gin.Context) {
header := ctx.Writer.Header() // set cros header
header.Set("Access-Control-Allow-Origin", "*") ctx.Header("Access-Control-Allow-Origin", "*")
header.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH") ctx.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
header.Set("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type") ctx.Header("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
ctx.AbortWithStatus(200) ctx.AbortWithStatus(200)
}) })
@@ -106,16 +109,21 @@ func main() {
Authorization: c.Request.Header.Get("Authorization"), Authorization: c.Request.Header.Get("Authorization"),
UserAgent: c.Request.Header.Get("User-Agent"), UserAgent: c.Request.Header.Get("User-Agent"),
} }
defer func() { /*
if err := recover(); err != nil { defer func() {
log.Println("Error:", err) if err := recover(); err != nil {
c.AbortWithError(500, fmt.Errorf("%s", err)) log.Println("Error:", err)
} c.AbortWithError(500, fmt.Errorf("%s", err))
}() }
}()
*/
// check authorization header // check authorization header
if !*noauth { if !*noauth {
if handleAuth(c) != nil { err := handleAuth(c)
if err != nil {
c.Header("Content-Type", "application/json")
c.AbortWithError(403, err)
return return
} }
} }
@@ -143,9 +151,19 @@ func main() {
log.Println("Record result:", record.Status, record.Response) log.Println("Record result:", record.Status, record.Response)
record.ElapsedTime = time.Now().Sub(record.CreatedAt) record.ElapsedTime = time.Now().Sub(record.CreatedAt)
if db.Create(&record).Error != nil {
log.Println("Error to save record:", record) // async record request
} go func() {
// turncate request if too long
if len(record.Body) > 1024*128 {
log.Println("Warning: Truncate request body")
record.Body = record.Body[:1024*128]
}
if db.Create(&record).Error != nil {
log.Println("Error to save record:", record)
}
}()
if record.Status != 200 { if record.Status != 200 {
errMessage := fmt.Sprintf("IP: %s request %s error %d with %s", record.IP, record.Model, record.Status, record.Response) errMessage := fmt.Sprintf("IP: %s request %s error %d with %s", record.IP, record.Model, record.Status, record.Response)
go sendFeishuMessage(errMessage) go sendFeishuMessage(errMessage)

View File

@@ -28,7 +28,6 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
// reverse proxy // reverse proxy
remote, err := url.Parse(upstream.Endpoint) remote, err := url.Parse(upstream.Endpoint)
if err != nil { if err != nil {
c.AbortWithError(500, errors.New("can't parse reverse proxy remote URL"))
return err return err
} }
@@ -38,6 +37,7 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
proxy.Director = nil proxy.Director = nil
var inBody []byte var inBody []byte
proxy.Rewrite = func(proxyRequest *httputil.ProxyRequest) { proxy.Rewrite = func(proxyRequest *httputil.ProxyRequest) {
in := proxyRequest.In in := proxyRequest.In
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@@ -48,7 +48,7 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
// read request body // read request body
inBody, err = io.ReadAll(in.Body) inBody, err = io.ReadAll(in.Body)
if err != nil { if err != nil {
c.AbortWithError(502, errors.New("reverse proxy middleware failed to read request body "+err.Error())) errCtx = errors.New("reverse proxy middleware failed to read request body " + err.Error())
return return
} }
@@ -80,6 +80,7 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
log.Println("Timeout upstream", upstream.Endpoint) log.Println("Timeout upstream", upstream.Endpoint)
errCtx = errors.New("timeout") errCtx = errors.New("timeout")
if shouldResponse { if shouldResponse {
c.Header("Content-Type", "application/json")
c.AbortWithError(502, errCtx) c.AbortWithError(502, errCtx)
} }
cancel() cancel()
@@ -109,16 +110,22 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
haveResponse = true haveResponse = true
record.ResponseTime = time.Now().Sub(record.CreatedAt) record.ResponseTime = time.Now().Sub(record.CreatedAt)
record.Status = r.StatusCode record.Status = r.StatusCode
// handle reverse proxy cors header if upstream do not set that
if r.Header.Get("Access-Control-Allow-Origin") == "" {
c.Header("Access-Control-Allow-Origin", "*")
}
if r.Header.Get("Access-Control-Allow-Methods") == "" {
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
}
if r.Header.Get("Access-Control-Allow-Headers") == "" {
c.Header("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
}
if !shouldResponse && r.StatusCode != 200 { if !shouldResponse && r.StatusCode != 200 {
log.Println("upstream return not 200 and should not response", r.StatusCode) log.Println("upstream return not 200 and should not response", r.StatusCode)
return errors.New("upstream return not 200 and should not response") return errors.New("upstream return not 200 and should not response")
} }
r.Header.Del("Access-Control-Allow-Origin")
r.Header.Del("Access-Control-Allow-Methods")
r.Header.Del("Access-Control-Allow-Headers")
r.Header.Set("Access-Control-Allow-Origin", "*")
r.Header.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
r.Header.Set("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
if r.StatusCode != 200 { if r.StatusCode != 200 {
body, err := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
@@ -144,6 +151,7 @@ func processRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, s
// abort to error handle // abort to error handle
if shouldResponse { if shouldResponse {
c.Header("Content-Type", "application/json")
c.AbortWithError(502, err) c.AbortWithError(502, err)
} }