fix: replicate mistral prompt
This commit is contained in:
53
replicate.go
53
replicate.go
@@ -86,7 +86,7 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
|
|||||||
Top_k: 50,
|
Top_k: 50,
|
||||||
FrequencyPenalty: 0.0,
|
FrequencyPenalty: 0.0,
|
||||||
PresencePenalty: 0.0,
|
PresencePenalty: 0.0,
|
||||||
PromptTemplate: "<s>[INST] {prompt} [/INST] ",
|
PromptTemplate: "{prompt}",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -97,9 +97,58 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
|
|||||||
outRequest.Input.PresencePenalty = inRequest.PresencePenalty
|
outRequest.Input.PresencePenalty = inRequest.PresencePenalty
|
||||||
|
|
||||||
// render prompt
|
// render prompt
|
||||||
|
systemMessage := ""
|
||||||
|
userMessage := ""
|
||||||
|
assistantMessage := ""
|
||||||
for _, message := range inRequest.Messages {
|
for _, message := range inRequest.Messages {
|
||||||
outRequest.Input.Prompt += message.Content + "\n"
|
if message.Role == "system" {
|
||||||
|
if systemMessage != "" {
|
||||||
|
systemMessage += "\n"
|
||||||
|
}
|
||||||
|
systemMessage += message.Content
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if message.Role == "user" {
|
||||||
|
if userMessage != "" {
|
||||||
|
userMessage += "\n"
|
||||||
|
}
|
||||||
|
userMessage += message.Content
|
||||||
|
if systemMessage != "" {
|
||||||
|
userMessage = systemMessage + "\n" + userMessage
|
||||||
|
systemMessage = ""
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if message.Role == "assistant" {
|
||||||
|
if assistantMessage != "" {
|
||||||
|
assistantMessage += "\n"
|
||||||
|
}
|
||||||
|
assistantMessage += message.Content
|
||||||
|
|
||||||
|
if outRequest.Input.Prompt != "" {
|
||||||
|
outRequest.Input.Prompt += "\n"
|
||||||
|
}
|
||||||
|
if userMessage != "" {
|
||||||
|
outRequest.Input.Prompt += fmt.Sprintf("<s> [INST] %s [/INST] %s </s>", userMessage, assistantMessage)
|
||||||
|
} else {
|
||||||
|
outRequest.Input.Prompt += fmt.Sprintf("<s> %s </s>", assistantMessage)
|
||||||
|
}
|
||||||
|
userMessage = ""
|
||||||
|
assistantMessage = ""
|
||||||
|
}
|
||||||
|
// unknown role
|
||||||
|
log.Println("[processReplicateRequest]: Warning: unknown role", message.Role)
|
||||||
}
|
}
|
||||||
|
// final user message
|
||||||
|
if userMessage != "" {
|
||||||
|
outRequest.Input.Prompt += fmt.Sprintf("<s> [INST] %s [/INST] ", userMessage)
|
||||||
|
userMessage = ""
|
||||||
|
}
|
||||||
|
// final assistant message
|
||||||
|
if assistantMessage != "" {
|
||||||
|
outRequest.Input.Prompt += fmt.Sprintf("<s> %s </s>", assistantMessage)
|
||||||
|
}
|
||||||
|
log.Println("[processReplicateRequest]: outRequest.Input.Prompt:", outRequest.Input.Prompt)
|
||||||
|
|
||||||
// send request
|
// send request
|
||||||
outBody, err := json.Marshal(outRequest)
|
outBody, err := json.Marshal(outRequest)
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ type OpenAIChatRequest struct {
|
|||||||
|
|
||||||
type OpenAIChatRequestMessage struct {
|
type OpenAIChatRequestMessage struct {
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
User string `json:"user"`
|
Role string `json:"role"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ReplicateModelRequest struct {
|
type ReplicateModelRequest struct {
|
||||||
|
|||||||
Reference in New Issue
Block a user