fix: replicate mistral prompt

This commit is contained in:
2024-01-23 16:38:45 +08:00
parent 8fa7fa79be
commit 3385f9af08
2 changed files with 52 additions and 3 deletions

View File

@@ -86,7 +86,7 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
Top_k: 50, Top_k: 50,
FrequencyPenalty: 0.0, FrequencyPenalty: 0.0,
PresencePenalty: 0.0, PresencePenalty: 0.0,
PromptTemplate: "<s>[INST] {prompt} [/INST] ", PromptTemplate: "{prompt}",
}, },
} }
@@ -97,9 +97,58 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
outRequest.Input.PresencePenalty = inRequest.PresencePenalty outRequest.Input.PresencePenalty = inRequest.PresencePenalty
// render prompt // render prompt
systemMessage := ""
userMessage := ""
assistantMessage := ""
for _, message := range inRequest.Messages { for _, message := range inRequest.Messages {
outRequest.Input.Prompt += message.Content + "\n" if message.Role == "system" {
if systemMessage != "" {
systemMessage += "\n"
} }
systemMessage += message.Content
continue
}
if message.Role == "user" {
if userMessage != "" {
userMessage += "\n"
}
userMessage += message.Content
if systemMessage != "" {
userMessage = systemMessage + "\n" + userMessage
systemMessage = ""
}
continue
}
if message.Role == "assistant" {
if assistantMessage != "" {
assistantMessage += "\n"
}
assistantMessage += message.Content
if outRequest.Input.Prompt != "" {
outRequest.Input.Prompt += "\n"
}
if userMessage != "" {
outRequest.Input.Prompt += fmt.Sprintf("<s> [INST] %s [/INST] %s </s>", userMessage, assistantMessage)
} else {
outRequest.Input.Prompt += fmt.Sprintf("<s> %s </s>", assistantMessage)
}
userMessage = ""
assistantMessage = ""
}
// unknown role
log.Println("[processReplicateRequest]: Warning: unknown role", message.Role)
}
// final user message
if userMessage != "" {
outRequest.Input.Prompt += fmt.Sprintf("<s> [INST] %s [/INST] ", userMessage)
userMessage = ""
}
// final assistant message
if assistantMessage != "" {
outRequest.Input.Prompt += fmt.Sprintf("<s> %s </s>", assistantMessage)
}
log.Println("[processReplicateRequest]: outRequest.Input.Prompt:", outRequest.Input.Prompt)
// send request // send request
outBody, err := json.Marshal(outRequest) outBody, err := json.Marshal(outRequest)

View File

@@ -85,7 +85,7 @@ type OpenAIChatRequest struct {
type OpenAIChatRequestMessage struct { type OpenAIChatRequestMessage struct {
Content string `json:"content"` Content string `json:"content"`
User string `json:"user"` Role string `json:"role"`
} }
type ReplicateModelRequest struct { type ReplicateModelRequest struct {