Compare commits

...

4 Commits

Author SHA1 Message Date
3385f9af08 fix: replicate mistral prompt 2024-01-23 16:38:45 +08:00
8fa7fa79be less: replicate log 2024-01-23 15:56:48 +08:00
49169452fe fix: read config upstream type default value 2024-01-23 15:52:46 +08:00
33f341026f fix: replicate response format 2024-01-23 15:43:48 +08:00
2 changed files with 69 additions and 22 deletions

View File

@@ -19,6 +19,7 @@ var replicate_model_url_template = "https://api.replicate.com/v1/models/%s/predi
func processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error { func processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
err := _processReplicateRequest(c, upstream, record, shouldResponse) err := _processReplicateRequest(c, upstream, record, shouldResponse)
if shouldResponse { if shouldResponse {
sendCORSHeaders(c)
if err != nil { if err != nil {
c.AbortWithError(502, err) c.AbortWithError(502, err)
} }
@@ -85,7 +86,7 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
Top_k: 50, Top_k: 50,
FrequencyPenalty: 0.0, FrequencyPenalty: 0.0,
PresencePenalty: 0.0, PresencePenalty: 0.0,
PromptTemplate: "<s>[INST] {prompt} [/INST] ", PromptTemplate: "{prompt}",
}, },
} }
@@ -96,9 +97,58 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
outRequest.Input.PresencePenalty = inRequest.PresencePenalty outRequest.Input.PresencePenalty = inRequest.PresencePenalty
// render prompt // render prompt
systemMessage := ""
userMessage := ""
assistantMessage := ""
for _, message := range inRequest.Messages { for _, message := range inRequest.Messages {
outRequest.Input.Prompt += message.Content + "\n" if message.Role == "system" {
if systemMessage != "" {
systemMessage += "\n"
} }
systemMessage += message.Content
continue
}
if message.Role == "user" {
if userMessage != "" {
userMessage += "\n"
}
userMessage += message.Content
if systemMessage != "" {
userMessage = systemMessage + "\n" + userMessage
systemMessage = ""
}
continue
}
if message.Role == "assistant" {
if assistantMessage != "" {
assistantMessage += "\n"
}
assistantMessage += message.Content
if outRequest.Input.Prompt != "" {
outRequest.Input.Prompt += "\n"
}
if userMessage != "" {
outRequest.Input.Prompt += fmt.Sprintf("<s> [INST] %s [/INST] %s </s>", userMessage, assistantMessage)
} else {
outRequest.Input.Prompt += fmt.Sprintf("<s> %s </s>", assistantMessage)
}
userMessage = ""
assistantMessage = ""
}
// unknown role
log.Println("[processReplicateRequest]: Warning: unknown role", message.Role)
}
// final user message
if userMessage != "" {
outRequest.Input.Prompt += fmt.Sprintf("<s> [INST] %s [/INST] ", userMessage)
userMessage = ""
}
// final assistant message
if assistantMessage != "" {
outRequest.Input.Prompt += fmt.Sprintf("<s> %s </s>", assistantMessage)
}
log.Println("[processReplicateRequest]: outRequest.Input.Prompt:", outRequest.Input.Prompt)
// send request // send request
outBody, err := json.Marshal(outRequest) outBody, err := json.Marshal(outRequest)
@@ -189,14 +239,10 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
// parse chunk to ReplicateModelResultChunk object // parse chunk to ReplicateModelResultChunk object
chunkObj := &ReplicateModelResultChunk{} chunkObj := &ReplicateModelResultChunk{}
log.Println("[processReplicateRequest]: chunk:", chunk)
lines := strings.Split(chunk, "\n") lines := strings.Split(chunk, "\n")
log.Println("[processReplicateRequest]: lines:", lines)
// first line is event // first line is event
chunkObj.Event = strings.TrimSpace(lines[0]) chunkObj.Event = strings.TrimSpace(lines[0])
chunkObj.Event = strings.TrimPrefix(chunkObj.Event, "event: ") chunkObj.Event = strings.TrimPrefix(chunkObj.Event, "event: ")
fmt.Printf("[processReplicateRequest]: chunkObj.Event: '%s'\n", chunkObj.Event)
fmt.Printf("Length: %d\n", len(chunkObj.Event))
// second line is id // second line is id
chunkObj.ID = strings.TrimSpace(lines[1]) chunkObj.ID = strings.TrimSpace(lines[1])
chunkObj.ID = strings.TrimPrefix(chunkObj.ID, "id: ") chunkObj.ID = strings.TrimPrefix(chunkObj.ID, "id: ")
@@ -212,14 +258,16 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
break break
} }
sendCORSHeaders(c)
// create OpenAI response chunk // create OpenAI response chunk
c.SSEvent("", &OpenAIChatResponse{ c.SSEvent("", &OpenAIChatResponseChunk{
ID: chunkObj.Event, ID: "",
Model: outResponse.Model, Model: outResponse.Model,
Choices: []OpenAIChatResponseChoice{ Choices: []OpenAIChatResponseChunkChoice{
{ {
Index: indexCount, Index: indexCount,
Message: OpenAIChatMessage{ Delta: OpenAIChatMessage{
Role: "assistant", Role: "assistant",
Content: chunkObj.Data, Content: chunkObj.Data,
}, },
@@ -229,13 +277,14 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
c.Writer.Flush() c.Writer.Flush()
indexCount += 1 indexCount += 1
} }
c.SSEvent("", &OpenAIChatResponse{ sendCORSHeaders(c)
c.SSEvent("", &OpenAIChatResponseChunk{
ID: "", ID: "",
Model: outResponse.Model, Model: outResponse.Model,
Choices: []OpenAIChatResponseChoice{ Choices: []OpenAIChatResponseChunkChoice{
{ {
Index: indexCount, Index: indexCount,
Message: OpenAIChatMessage{ Delta: OpenAIChatMessage{
Role: "assistant", Role: "assistant",
Content: "", Content: "",
}, },
@@ -273,8 +322,6 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
return errors.New("[processReplicateRequest]: failed to read response body " + err.Error()) return errors.New("[processReplicateRequest]: failed to read response body " + err.Error())
} }
log.Println("[processReplicateRequest]: resultBody:", string(resultBody))
// parse response body // parse response body
result = &ReplicateModelResultGet{} result = &ReplicateModelResultGet{}
err = json.Unmarshal(resultBody, result) err = json.Unmarshal(resultBody, result)
@@ -282,7 +329,6 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody)) c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
return errors.New("[processReplicateRequest]: failed to parse response body " + err.Error()) return errors.New("[processReplicateRequest]: failed to parse response body " + err.Error())
} }
log.Println("[processReplicateRequest]: result:", result)
if result.Status == "processing" || result.Status == "starting" { if result.Status == "processing" || result.Status == "starting" {
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
@@ -315,6 +361,7 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
record.Status = 200 record.Status = 200
// gin return // gin return
sendCORSHeaders(c)
c.JSON(200, openAIResult) c.JSON(200, openAIResult)
} }

View File

@@ -62,11 +62,11 @@ func readConfig(filepath string) Config {
log.Fatalf("Can't parse upstream endpoint URL '%s': %s", upstream.Endpoint, err) log.Fatalf("Can't parse upstream endpoint URL '%s': %s", upstream.Endpoint, err)
} }
config.Upstreams[i].URL = endpoint config.Upstreams[i].URL = endpoint
if upstream.Type == "" { if config.Upstreams[i].Type == "" {
upstream.Type = "openai" config.Upstreams[i].Type = "openai"
} }
if (upstream.Type != "openai") && (upstream.Type != "replicate") { if (config.Upstreams[i].Type != "openai") && (config.Upstreams[i].Type != "replicate") {
log.Fatalf("Unsupported upstream type '%s'", upstream.Type) log.Fatalf("Unsupported upstream type '%s'", config.Upstreams[i].Type)
} }
} }
@@ -85,7 +85,7 @@ type OpenAIChatRequest struct {
type OpenAIChatRequestMessage struct { type OpenAIChatRequestMessage struct {
Content string `json:"content"` Content string `json:"content"`
User string `json:"user"` Role string `json:"role"`
} }
type ReplicateModelRequest struct { type ReplicateModelRequest struct {
@@ -176,5 +176,5 @@ type OpenAIChatResponseChunk struct {
type OpenAIChatResponseChunkChoice struct { type OpenAIChatResponseChunkChoice struct {
Index int64 `json:"index"` Index int64 `json:"index"`
Delta OpenAIChatMessage `json:"delta"` Delta OpenAIChatMessage `json:"delta"`
FinishReason *string `json:"finish_reason"` FinishReason string `json:"finish_reason"`
} }