record response and uuid
This commit is contained in:
1
go.mod
1
go.mod
@@ -17,6 +17,7 @@ require (
|
|||||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||||
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
||||||
github.com/goccy/go-json v0.10.2 // indirect
|
github.com/goccy/go-json v0.10.2 // indirect
|
||||||
|
github.com/google/uuid v1.3.0 // indirect
|
||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
github.com/jinzhu/now v1.1.5 // indirect
|
github.com/jinzhu/now v1.1.5 // indirect
|
||||||
github.com/json-iterator/go v1.1.12 // indirect
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -26,6 +26,8 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
|
|||||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
|
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||||
|
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||||
|
|||||||
17
main.go
17
main.go
@@ -15,6 +15,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/uuid"
|
||||||
"gorm.io/driver/sqlite"
|
"gorm.io/driver/sqlite"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
@@ -46,7 +47,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
db.AutoMigrate(&OPENAI_UPSTREAM{})
|
db.AutoMigrate(&OPENAI_UPSTREAM{})
|
||||||
db.AutoMigrate(&RequestRecord{})
|
db.AutoMigrate(&Record{})
|
||||||
log.Println("Auto migrate database done")
|
log.Println("Auto migrate database done")
|
||||||
|
|
||||||
if *addMode {
|
if *addMode {
|
||||||
@@ -96,6 +97,7 @@ func main() {
|
|||||||
db.Take(&authConfig, "key = ?", "authorization")
|
db.Take(&authConfig, "key = ?", "authorization")
|
||||||
|
|
||||||
engine.POST("/v1/*any", func(c *gin.Context) {
|
engine.POST("/v1/*any", func(c *gin.Context) {
|
||||||
|
trackID := uuid.New()
|
||||||
// check authorization header
|
// check authorization header
|
||||||
if !*noauth {
|
if !*noauth {
|
||||||
if handleAuth(c) != nil {
|
if handleAuth(c) != nil {
|
||||||
@@ -156,7 +158,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// record chat message from user
|
// record chat message from user
|
||||||
go recordUserMessage(c, db, body)
|
go recordUserMessage(c, db, trackID, body)
|
||||||
|
|
||||||
out.Body = io.NopCloser(bytes.NewReader(body))
|
out.Body = io.NopCloser(bytes.NewReader(body))
|
||||||
|
|
||||||
@@ -169,6 +171,8 @@ func main() {
|
|||||||
out.Header.Set("Authorization", "Bearer "+upstream.SK)
|
out.Header.Set("Authorization", "Bearer "+upstream.SK)
|
||||||
out.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
out.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||||
}
|
}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
var contentType string
|
||||||
proxy.ModifyResponse = func(r *http.Response) error {
|
proxy.ModifyResponse = func(r *http.Response) error {
|
||||||
if r.StatusCode != 200 {
|
if r.StatusCode != 200 {
|
||||||
body, err := io.ReadAll(r.Body)
|
body, err := io.ReadAll(r.Body)
|
||||||
@@ -182,6 +186,8 @@ func main() {
|
|||||||
"success_count": gorm.Expr("success_count + ?", 1),
|
"success_count": gorm.Expr("success_count + ?", 1),
|
||||||
"last_call_success_time": time.Now(),
|
"last_call_success_time": time.Now(),
|
||||||
})
|
})
|
||||||
|
r.Body = io.NopCloser(io.TeeReader(r.Body, &buf))
|
||||||
|
contentType = r.Header.Get("content-type")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
@@ -214,6 +220,11 @@ func main() {
|
|||||||
log.Println("response is", r.Response)
|
log.Println("response is", r.Response)
|
||||||
}
|
}
|
||||||
proxy.ServeHTTP(c.Writer, c.Request)
|
proxy.ServeHTTP(c.Writer, c.Request)
|
||||||
|
resp, err := io.ReadAll(io.NopCloser(&buf))
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to read from response tee buffer", err)
|
||||||
|
}
|
||||||
|
go recordAssistantResponse(contentType, db, trackID, resp)
|
||||||
})
|
})
|
||||||
|
|
||||||
// ---------------------------------
|
// ---------------------------------
|
||||||
@@ -311,7 +322,7 @@ func main() {
|
|||||||
if handleAuth(c) != nil {
|
if handleAuth(c) != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
requestRecords := []RequestRecord{}
|
requestRecords := []Record{}
|
||||||
err := db.Order("id desc").Limit(100).Find(&requestRecords).Error
|
err := db.Order("id desc").Limit(100).Find(&requestRecords).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithError(502, err)
|
c.AbortWithError(502, err)
|
||||||
|
|||||||
100
record.go
100
record.go
@@ -1,24 +1,118 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"log"
|
"log"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/uuid"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RequestRecord struct {
|
type Record struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
|
ID uuid.UUID `gorm:"type:uuid"`
|
||||||
Body string
|
Body string
|
||||||
|
Response string
|
||||||
}
|
}
|
||||||
|
|
||||||
func recordUserMessage(c *gin.Context, db *gorm.DB, body []byte) {
|
func recordUserMessage(c *gin.Context, db *gorm.DB, trackID uuid.UUID, body []byte) {
|
||||||
bodyStr := string(body)
|
bodyStr := string(body)
|
||||||
requestRecord := RequestRecord{
|
requestRecord := Record{
|
||||||
Body: bodyStr,
|
Body: bodyStr,
|
||||||
|
ID: trackID,
|
||||||
}
|
}
|
||||||
err := db.Create(&requestRecord).Error
|
err := db.Create(&requestRecord).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("Error record request:", err)
|
log.Println("Error record request:", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type StreamModeChunk struct {
|
||||||
|
Choices []StreamModeChunkChoice `json:"choices"`
|
||||||
|
}
|
||||||
|
type StreamModeChunkChoice struct {
|
||||||
|
Delta StreamModeDelta `json:"delta"`
|
||||||
|
FinishReason string `json:"finish_reason"`
|
||||||
|
}
|
||||||
|
type StreamModeDelta struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FetchModeResponse struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Choices []FetchModeChoice `json:"choices"`
|
||||||
|
Usage FetchModeUsage `json:"usage"`
|
||||||
|
}
|
||||||
|
type FetchModeChoice struct {
|
||||||
|
Message FetchModeMessage `json:"message"`
|
||||||
|
FinishReason string `json:"finish_reason"`
|
||||||
|
}
|
||||||
|
type FetchModeMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
type FetchModeUsage struct {
|
||||||
|
PromptTokens int64 `json:"prompt_tokens"`
|
||||||
|
CompletionTokens int64 `json:"completion_tokens"`
|
||||||
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func recordAssistantResponse(contentType string, db *gorm.DB, trackID uuid.UUID, body []byte) {
|
||||||
|
result := ""
|
||||||
|
// stream mode
|
||||||
|
if strings.HasPrefix(contentType, "text/event-stream") {
|
||||||
|
resp := string(body)
|
||||||
|
var chunk StreamModeChunk
|
||||||
|
for _, line := range strings.Split(resp, "\n") {
|
||||||
|
line = strings.TrimPrefix(line, "data:")
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err := json.Unmarshal([]byte(line), &chunk)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(chunk.Choices) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result += chunk.Choices[0].Delta.Content
|
||||||
|
log.Println(line)
|
||||||
|
}
|
||||||
|
} else if strings.HasPrefix(contentType, "application/json") {
|
||||||
|
var fetchResp FetchModeResponse
|
||||||
|
err := json.Unmarshal(body, &fetchResp)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Error parsing fetch response:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(fetchResp.Model, "gpt-") {
|
||||||
|
log.Println("Not GPT model, skip recording response")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(fetchResp.Choices) == 0 {
|
||||||
|
log.Println("Error: fetch response choice length is 0")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
result = fetchResp.Choices[0].Message.Content
|
||||||
|
} else {
|
||||||
|
log.Println("Unknown content type", contentType)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Println("Record result:", result)
|
||||||
|
record := Record{}
|
||||||
|
if db.Find(&record, "id = ?", trackID).Error != nil {
|
||||||
|
log.Println("Error find request record with trackID:", trackID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
record.Response = result
|
||||||
|
if db.Save(&record).Error != nil {
|
||||||
|
log.Println("Error to save record:", record)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user