From 3385f9af08340295adb9744c9853c2f242c3038a Mon Sep 17 00:00:00 2001 From: heimoshuiyu Date: Tue, 23 Jan 2024 16:38:45 +0800 Subject: [PATCH] fix: replicate mistral prompt --- replicate.go | 53 ++++++++++++++++++++++++++++++++++++++++++++++++++-- structure.go | 2 +- 2 files changed, 52 insertions(+), 3 deletions(-) diff --git a/replicate.go b/replicate.go index 6bc547b..a09f318 100644 --- a/replicate.go +++ b/replicate.go @@ -86,7 +86,7 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record Top_k: 50, FrequencyPenalty: 0.0, PresencePenalty: 0.0, - PromptTemplate: "[INST] {prompt} [/INST] ", + PromptTemplate: "{prompt}", }, } @@ -97,9 +97,58 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record outRequest.Input.PresencePenalty = inRequest.PresencePenalty // render prompt + systemMessage := "" + userMessage := "" + assistantMessage := "" for _, message := range inRequest.Messages { - outRequest.Input.Prompt += message.Content + "\n" + if message.Role == "system" { + if systemMessage != "" { + systemMessage += "\n" + } + systemMessage += message.Content + continue + } + if message.Role == "user" { + if userMessage != "" { + userMessage += "\n" + } + userMessage += message.Content + if systemMessage != "" { + userMessage = systemMessage + "\n" + userMessage + systemMessage = "" + } + continue + } + if message.Role == "assistant" { + if assistantMessage != "" { + assistantMessage += "\n" + } + assistantMessage += message.Content + + if outRequest.Input.Prompt != "" { + outRequest.Input.Prompt += "\n" + } + if userMessage != "" { + outRequest.Input.Prompt += fmt.Sprintf(" [INST] %s [/INST] %s ", userMessage, assistantMessage) + } else { + outRequest.Input.Prompt += fmt.Sprintf(" %s ", assistantMessage) + } + userMessage = "" + assistantMessage = "" + } + // unknown role + log.Println("[processReplicateRequest]: Warning: unknown role", message.Role) } + // final user message + if userMessage != "" { + outRequest.Input.Prompt += fmt.Sprintf(" [INST] %s [/INST] ", userMessage) + userMessage = "" + } + // final assistant message + if assistantMessage != "" { + outRequest.Input.Prompt += fmt.Sprintf(" %s ", assistantMessage) + } + log.Println("[processReplicateRequest]: outRequest.Input.Prompt:", outRequest.Input.Prompt) // send request outBody, err := json.Marshal(outRequest) diff --git a/structure.go b/structure.go index 37cc721..e7ac3c3 100644 --- a/structure.go +++ b/structure.go @@ -85,7 +85,7 @@ type OpenAIChatRequest struct { type OpenAIChatRequestMessage struct { Content string `json:"content"` - User string `json:"user"` + Role string `json:"role"` } type ReplicateModelRequest struct {