From fd390577d5f00a76a3d54311cfa217f37e52ebde Mon Sep 17 00:00:00 2001 From: heimoshuiyu Date: Tue, 18 Jul 2023 18:35:53 +0800 Subject: [PATCH] record response and uuid --- go.mod | 1 + go.sum | 2 ++ main.go | 17 +++++++-- record.go | 102 +++++++++++++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 115 insertions(+), 7 deletions(-) diff --git a/go.mod b/go.mod index 10322b5..f1786df 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.14.0 // indirect github.com/goccy/go-json v0.10.2 // indirect + github.com/google/uuid v1.3.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect diff --git a/go.sum b/go.sum index 0d393af..6621c14 100644 --- a/go.sum +++ b/go.sum @@ -26,6 +26,8 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= diff --git a/main.go b/main.go index 7717bcd..87b17c7 100644 --- a/main.go +++ b/main.go @@ -15,6 +15,7 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/google/uuid" "gorm.io/driver/sqlite" "gorm.io/gorm" ) @@ -46,7 +47,7 @@ func main() { } db.AutoMigrate(&OPENAI_UPSTREAM{}) - db.AutoMigrate(&RequestRecord{}) + db.AutoMigrate(&Record{}) log.Println("Auto migrate database done") if *addMode { @@ -96,6 +97,7 @@ func main() { db.Take(&authConfig, "key = ?", "authorization") engine.POST("/v1/*any", func(c *gin.Context) { + trackID := uuid.New() // check authorization header if !*noauth { if handleAuth(c) != nil { @@ -156,7 +158,7 @@ func main() { } // record chat message from user - go recordUserMessage(c, db, body) + go recordUserMessage(c, db, trackID, body) out.Body = io.NopCloser(bytes.NewReader(body)) @@ -169,6 +171,8 @@ func main() { out.Header.Set("Authorization", "Bearer "+upstream.SK) out.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) } + var buf bytes.Buffer + var contentType string proxy.ModifyResponse = func(r *http.Response) error { if r.StatusCode != 200 { body, err := io.ReadAll(r.Body) @@ -182,6 +186,8 @@ func main() { "success_count": gorm.Expr("success_count + ?", 1), "last_call_success_time": time.Now(), }) + r.Body = io.NopCloser(io.TeeReader(r.Body, &buf)) + contentType = r.Header.Get("content-type") return nil } proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { @@ -214,6 +220,11 @@ func main() { log.Println("response is", r.Response) } proxy.ServeHTTP(c.Writer, c.Request) + resp, err := io.ReadAll(io.NopCloser(&buf)) + if err != nil { + log.Println("Failed to read from response tee buffer", err) + } + go recordAssistantResponse(contentType, db, trackID, resp) }) // --------------------------------- @@ -311,7 +322,7 @@ func main() { if handleAuth(c) != nil { return } - requestRecords := []RequestRecord{} + requestRecords := []Record{} err := db.Order("id desc").Limit(100).Find(&requestRecords).Error if err != nil { c.AbortWithError(502, err) diff --git a/record.go b/record.go index 342219c..65b14a2 100644 --- a/record.go +++ b/record.go @@ -1,24 +1,118 @@ package main import ( + "encoding/json" "log" + "strings" "github.com/gin-gonic/gin" + "github.com/google/uuid" "gorm.io/gorm" ) -type RequestRecord struct { +type Record struct { gorm.Model - Body string + ID uuid.UUID `gorm:"type:uuid"` + Body string + Response string } -func recordUserMessage(c *gin.Context, db *gorm.DB, body []byte) { +func recordUserMessage(c *gin.Context, db *gorm.DB, trackID uuid.UUID, body []byte) { bodyStr := string(body) - requestRecord := RequestRecord{ + requestRecord := Record{ Body: bodyStr, + ID: trackID, } err := db.Create(&requestRecord).Error if err != nil { log.Println("Error record request:", err) } } + +type StreamModeChunk struct { + Choices []StreamModeChunkChoice `json:"choices"` +} +type StreamModeChunkChoice struct { + Delta StreamModeDelta `json:"delta"` + FinishReason string `json:"finish_reason"` +} +type StreamModeDelta struct { + Content string `json:"content"` +} + +type FetchModeResponse struct { + Model string `json:"model"` + Choices []FetchModeChoice `json:"choices"` + Usage FetchModeUsage `json:"usage"` +} +type FetchModeChoice struct { + Message FetchModeMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} +type FetchModeMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} +type FetchModeUsage struct { + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + TotalTokens int64 `json:"total_tokens"` +} + +func recordAssistantResponse(contentType string, db *gorm.DB, trackID uuid.UUID, body []byte) { + result := "" + // stream mode + if strings.HasPrefix(contentType, "text/event-stream") { + resp := string(body) + var chunk StreamModeChunk + for _, line := range strings.Split(resp, "\n") { + line = strings.TrimPrefix(line, "data:") + line = strings.TrimSpace(line) + if line == "" { + continue + } + + err := json.Unmarshal([]byte(line), &chunk) + if err != nil { + log.Println(err) + continue + } + + if len(chunk.Choices) == 0 { + continue + } + result += chunk.Choices[0].Delta.Content + log.Println(line) + } + } else if strings.HasPrefix(contentType, "application/json") { + var fetchResp FetchModeResponse + err := json.Unmarshal(body, &fetchResp) + if err != nil { + log.Println("Error parsing fetch response:", err) + return + } + if !strings.HasPrefix(fetchResp.Model, "gpt-") { + log.Println("Not GPT model, skip recording response") + return + } + if len(fetchResp.Choices) == 0 { + log.Println("Error: fetch response choice length is 0") + return + } + result = fetchResp.Choices[0].Message.Content + } else { + log.Println("Unknown content type", contentType) + return + } + log.Println("Record result:", result) + record := Record{} + if db.Find(&record, "id = ?", trackID).Error != nil { + log.Println("Error find request record with trackID:", trackID) + return + } + record.Response = result + if db.Save(&record).Error != nil { + log.Println("Error to save record:", record) + return + } +}