diff --git a/main.go b/main.go index 09eacde..6c1df2b 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "encoding/json" "errors" "flag" "fmt" @@ -14,7 +15,6 @@ import ( "time" "github.com/gin-gonic/gin" - "github.com/google/uuid" "gorm.io/driver/sqlite" "gorm.io/gorm" ) @@ -96,8 +96,11 @@ func main() { db.Take(&authConfig, "key = ?", "authorization") engine.POST("/v1/*any", func(c *gin.Context) { - begin := time.Now() - trackID := uuid.New() + record := Record{ + IP: c.ClientIP(), + CreatedAt: time.Now(), + } + // check authorization header if !*noauth { if handleAuth(c) != nil { @@ -138,6 +141,8 @@ func main() { return } + record.UpstreamID = upstream.ID + // reverse proxy remote, err := url.Parse(upstream.Endpoint) if err != nil { @@ -158,7 +163,7 @@ func main() { } // record chat message from user - go recordUserMessage(c, db, trackID, body) + record.Body = string(body) out.Body = io.NopCloser(bytes.NewReader(body)) @@ -174,18 +179,17 @@ func main() { var buf bytes.Buffer var contentType string proxy.ModifyResponse = func(r *http.Response) error { + record.Status = r.StatusCode if r.StatusCode != 200 { body, err := io.ReadAll(r.Body) if err != nil { - return errors.New("failed to read response from upstream " + err.Error()) + record.Response = "failed to read response from upstream " + err.Error() + return errors.New(record.Response) } - return fmt.Errorf("upstream return '%s' with '%s'", r.Status, string(body)) + record.Response = fmt.Sprintf("upstream return '%s' with '%s'", r.Status, string(body)) + return fmt.Errorf(record.Response) } // count success - go db.Model(&upstream).Updates(map[string]interface{}{ - "success_count": gorm.Expr("success_count + ?", 1), - "last_call_success_time": time.Now(), - }) r.Body = io.NopCloser(io.TeeReader(r.Body, &buf)) contentType = r.Header.Get("content-type") return nil @@ -212,8 +216,6 @@ func main() { ) go sendMatrixMessage(content) if err.Error() != "context canceled" && r.Response.StatusCode != 400 { - // count failed - go db.Model(&upstream).Update("failed_count", gorm.Expr("failed_count + ?", 1)) go sendFeishuMessage(content) } @@ -222,9 +224,63 @@ func main() { proxy.ServeHTTP(c.Writer, c.Request) resp, err := io.ReadAll(io.NopCloser(&buf)) if err != nil { - log.Println("Failed to read from response tee buffer", err) + record.Response = "failed to read response from upstream " + err.Error() + log.Println(record.Response) + } else { + + // record response + // stream mode + if strings.HasPrefix(contentType, "text/event-stream") { + for _, line := range strings.Split(string(resp), "\n") { + chunk := StreamModeChunk{} + line = strings.TrimPrefix(line, "data:") + line = strings.TrimSpace(line) + if line == "" { + continue + } + + err := json.Unmarshal([]byte(line), &chunk) + if err != nil { + log.Println(err) + continue + } + + if len(chunk.Choices) == 0 { + continue + } + record.Response += chunk.Choices[0].Delta.Content + } + } else if strings.HasPrefix(contentType, "application/json") { + var fetchResp FetchModeResponse + err := json.Unmarshal(resp, &fetchResp) + if err != nil { + log.Println("Error parsing fetch response:", err) + return + } + if !strings.HasPrefix(fetchResp.Model, "gpt-") { + log.Println("Not GPT model, skip recording response:", fetchResp.Model) + return + } + if len(fetchResp.Choices) == 0 { + log.Println("Error: fetch response choice length is 0") + return + } + record.Response = fetchResp.Choices[0].Message.Content + } else { + log.Println("Unknown content type", contentType) + return + } + } + + if len(record.Body) > 1024*512 { + record.Body = "" + } + + log.Println("Record result:", record.Response) + record.ElapsedTime = time.Now().Sub(record.CreatedAt) + if db.Create(&record).Error != nil { + log.Println("Error to save record:", record) } - go recordAssistantResponse(contentType, db, trackID, resp, time.Now().Sub(begin)) }) engine.Run(*listenAddr) diff --git a/record.go b/record.go index 30af0a5..5ff89ca 100644 --- a/record.go +++ b/record.go @@ -6,31 +6,19 @@ import ( "strings" "time" - "github.com/gin-gonic/gin" "github.com/google/uuid" "gorm.io/gorm" ) type Record struct { - ID uuid.UUID `gorm:"type:uuid"` + ID int64 `gorm:"primaryKey,autoIncrement"` CreatedAt time.Time IP string - Body string + Body string `gorm:"serializer:json"` Response string ElapsedTime time.Duration -} - -func recordUserMessage(c *gin.Context, db *gorm.DB, trackID uuid.UUID, body []byte) { - bodyStr := string(body) - requestRecord := Record{ - Body: bodyStr, - ID: trackID, - IP: c.ClientIP(), - } - err := db.Create(&requestRecord).Error - if err != nil { - log.Println("Error record request:", err) - } + Status int + UpstreamID uint } type StreamModeChunk struct {