Compare commits

...

4 Commits

Author SHA1 Message Date
3385f9af08 fix: replicate mistral prompt 2024-01-23 16:38:45 +08:00
8fa7fa79be less: replicate log 2024-01-23 15:56:48 +08:00
49169452fe fix: read config upstream type default value 2024-01-23 15:52:46 +08:00
33f341026f fix: replicate response format 2024-01-23 15:43:48 +08:00
2 changed files with 69 additions and 22 deletions

View File

@@ -19,6 +19,7 @@ var replicate_model_url_template = "https://api.replicate.com/v1/models/%s/predi
 func processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
     err := _processReplicateRequest(c, upstream, record, shouldResponse)
     if shouldResponse {
+        sendCORSHeaders(c)
         if err != nil {
             c.AbortWithError(502, err)
         }
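
This hunk, like several below, adds a sendCORSHeaders(c) call before the response is written. The helper's body is not part of this diff; the following is only a guess at what such a gin helper might look like, with every header value assumed rather than taken from the repository:

package main

import "github.com/gin-gonic/gin"

// sendCORSHeaders: hypothetical sketch, the real helper lives elsewhere in the repository.
func sendCORSHeaders(c *gin.Context) {
    // Let browser clients on other origins read the proxied response.
    c.Header("Access-Control-Allow-Origin", "*")
    c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
    c.Header("Access-Control-Allow-Headers", "Authorization, Content-Type")
}
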
@@ -85,7 +86,7 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
             Top_k: 50,
             FrequencyPenalty: 0.0,
             PresencePenalty: 0.0,
-            PromptTemplate: "<s>[INST] {prompt} [/INST] ",
+            PromptTemplate: "{prompt}",
         },
     }
@@ -96,9 +97,58 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
     outRequest.Input.PresencePenalty = inRequest.PresencePenalty
     // render prompt
+    systemMessage := ""
+    userMessage := ""
+    assistantMessage := ""
     for _, message := range inRequest.Messages {
-        outRequest.Input.Prompt += message.Content + "\n"
+        if message.Role == "system" {
+            if systemMessage != "" {
+                systemMessage += "\n"
+            }
+            systemMessage += message.Content
+            continue
+        }
+        if message.Role == "user" {
+            if userMessage != "" {
+                userMessage += "\n"
+            }
+            userMessage += message.Content
+            if systemMessage != "" {
+                userMessage = systemMessage + "\n" + userMessage
+                systemMessage = ""
+            }
+            continue
+        }
+        if message.Role == "assistant" {
+            if assistantMessage != "" {
+                assistantMessage += "\n"
+            }
+            assistantMessage += message.Content
+            if outRequest.Input.Prompt != "" {
+                outRequest.Input.Prompt += "\n"
+            }
+            if userMessage != "" {
+                outRequest.Input.Prompt += fmt.Sprintf("<s> [INST] %s [/INST] %s </s>", userMessage, assistantMessage)
+            } else {
+                outRequest.Input.Prompt += fmt.Sprintf("<s> %s </s>", assistantMessage)
+            }
+            userMessage = ""
+            assistantMessage = ""
+        }
+        // unknown role
+        log.Println("[processReplicateRequest]: Warning: unknown role", message.Role)
     }
+    // final user message
+    if userMessage != "" {
+        outRequest.Input.Prompt += fmt.Sprintf("<s> [INST] %s [/INST] ", userMessage)
+        userMessage = ""
+    }
+    // final assistant message
+    if assistantMessage != "" {
+        outRequest.Input.Prompt += fmt.Sprintf("<s> %s </s>", assistantMessage)
+    }
+    log.Println("[processReplicateRequest]: outRequest.Input.Prompt:", outRequest.Input.Prompt)
     // send request
     outBody, err := json.Marshal(outRequest)
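
To make the effect of the rewritten loop above concrete, here is a simplified standalone sketch (not part of the diff) that applies the same role-based rendering to a small conversation; the Message type and the renderMistralPrompt name are illustrative only, and the unknown-role logging is omitted:

package main

import "fmt"

// Message mirrors the role/content pair read from the incoming OpenAI-style request.
type Message struct {
    Role    string
    Content string
}

// renderMistralPrompt reproduces the rendering introduced in the hunk above:
// system text is folded into the next user turn, each completed user/assistant
// pair becomes "<s> [INST] user [/INST] assistant </s>", and a trailing user
// message becomes an open "<s> [INST] user [/INST] " for the model to complete.
func renderMistralPrompt(messages []Message) string {
    prompt, system, user, assistant := "", "", "", ""
    for _, m := range messages {
        switch m.Role {
        case "system":
            if system != "" {
                system += "\n"
            }
            system += m.Content
        case "user":
            if user != "" {
                user += "\n"
            }
            user += m.Content
            if system != "" {
                user = system + "\n" + user
                system = ""
            }
        case "assistant":
            if assistant != "" {
                assistant += "\n"
            }
            assistant += m.Content
            if prompt != "" {
                prompt += "\n"
            }
            if user != "" {
                prompt += fmt.Sprintf("<s> [INST] %s [/INST] %s </s>", user, assistant)
            } else {
                prompt += fmt.Sprintf("<s> %s </s>", assistant)
            }
            user, assistant = "", ""
        }
    }
    if user != "" {
        prompt += fmt.Sprintf("<s> [INST] %s [/INST] ", user)
    }
    if assistant != "" {
        prompt += fmt.Sprintf("<s> %s </s>", assistant)
    }
    return prompt
}

func main() {
    fmt.Printf("%q\n", renderMistralPrompt([]Message{
        {Role: "system", Content: "You are a helpful assistant."},
        {Role: "user", Content: "Hi"},
        {Role: "assistant", Content: "Hello!"},
        {Role: "user", Content: "What is Go?"},
    }))
    // Prints: "<s> [INST] You are a helpful assistant.\nHi [/INST] Hello! </s><s> [INST] What is Go? [/INST] "
}
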
@@ -189,14 +239,10 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
         // parse chunk to ReplicateModelResultChunk object
         chunkObj := &ReplicateModelResultChunk{}
-        log.Println("[processReplicateRequest]: chunk:", chunk)
         lines := strings.Split(chunk, "\n")
-        log.Println("[processReplicateRequest]: lines:", lines)
         // first line is event
         chunkObj.Event = strings.TrimSpace(lines[0])
         chunkObj.Event = strings.TrimPrefix(chunkObj.Event, "event: ")
-        fmt.Printf("[processReplicateRequest]: chunkObj.Event: '%s'\n", chunkObj.Event)
-        fmt.Printf("Length: %d\n", len(chunkObj.Event))
         // second line is id
         chunkObj.ID = strings.TrimSpace(lines[1])
         chunkObj.ID = strings.TrimPrefix(chunkObj.ID, "id: ")
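
The code above splits each streamed chunk on newlines and strips the "event: " and "id: " prefixes, treating every chunk as one server-sent event; chunkObj.Data, used further down, is presumably filled from a "data: " line handled the same way outside this hunk. A standalone sketch of that line-splitting approach, with an invented sample payload:

package main

import (
    "fmt"
    "strings"
)

func main() {
    // Example chunk in the shape the code above expects; the values are made up for illustration.
    chunk := "event: output\nid: 1707000000\ndata: Hello\n"

    lines := strings.Split(chunk, "\n")
    event := strings.TrimPrefix(strings.TrimSpace(lines[0]), "event: ")
    id := strings.TrimPrefix(strings.TrimSpace(lines[1]), "id: ")
    data := strings.TrimPrefix(strings.TrimSpace(lines[2]), "data: ")

    fmt.Println(event, id, data) // output 1707000000 Hello
}
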
@@ -212,14 +258,16 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
             break
         }
+        sendCORSHeaders(c)
         // create OpenAI response chunk
-        c.SSEvent("", &OpenAIChatResponse{
-            ID: chunkObj.Event,
+        c.SSEvent("", &OpenAIChatResponseChunk{
+            ID: "",
             Model: outResponse.Model,
-            Choices: []OpenAIChatResponseChoice{
+            Choices: []OpenAIChatResponseChunkChoice{
                 {
                     Index: indexCount,
-                    Message: OpenAIChatMessage{
+                    Delta: OpenAIChatMessage{
                         Role: "assistant",
                         Content: chunkObj.Data,
                     },
@@ -229,13 +277,14 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
         c.Writer.Flush()
         indexCount += 1
     }
-    c.SSEvent("", &OpenAIChatResponse{
+    sendCORSHeaders(c)
+    c.SSEvent("", &OpenAIChatResponseChunk{
         ID: "",
         Model: outResponse.Model,
-        Choices: []OpenAIChatResponseChoice{
+        Choices: []OpenAIChatResponseChunkChoice{
             {
                 Index: indexCount,
-                Message: OpenAIChatMessage{
+                Delta: OpenAIChatMessage{
                     Role: "assistant",
                     Content: "",
                 },
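
The two hunks above switch the streamed events from the non-chunk OpenAIChatResponse shape to OpenAIChatResponseChunk with a delta field. A standalone sketch of the JSON body this produces per event, with abbreviated stand-in types and example values rather than the repository's definitions:

package main

import (
    "encoding/json"
    "fmt"
)

// Abbreviated stand-ins for the repository's chunk types.
type ChatMessage struct {
    Role    string `json:"role"`
    Content string `json:"content"`
}
type ChunkChoice struct {
    Index        int64       `json:"index"`
    Delta        ChatMessage `json:"delta"`
    FinishReason string      `json:"finish_reason"`
}
type ResponseChunk struct {
    ID      string        `json:"id"`
    Model   string        `json:"model"`
    Choices []ChunkChoice `json:"choices"`
}

func main() {
    b, _ := json.Marshal(ResponseChunk{
        Model: "mistral-7b-instruct-v0.2", // example value only
        Choices: []ChunkChoice{
            {Index: 0, Delta: ChatMessage{Role: "assistant", Content: "Hello"}},
        },
    })
    fmt.Println(string(b))
    // {"id":"","model":"mistral-7b-instruct-v0.2","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":""}]}
}

An OpenAI-style streaming client reads choices[0].delta.content from each event, which is why the delta field and the chunk-specific types matter here.
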
@@ -273,8 +322,6 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
         return errors.New("[processReplicateRequest]: failed to read response body " + err.Error())
     }
-    log.Println("[processReplicateRequest]: resultBody:", string(resultBody))
     // parse reponse body
     result = &ReplicateModelResultGet{}
     err = json.Unmarshal(resultBody, result)
@@ -282,7 +329,6 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
         c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
         return errors.New("[processReplicateRequest]: failed to parse response body " + err.Error())
     }
-    log.Println("[processReplicateRequest]: result:", result)
     if result.Status == "processing" || result.Status == "starting" {
         time.Sleep(3 * time.Second)
@@ -315,6 +361,7 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
     record.Status = 200
     // gin return
+    sendCORSHeaders(c)
     c.JSON(200, openAIResult)
 }

View File

@@ -62,11 +62,11 @@ func readConfig(filepath string) Config {
             log.Fatalf("Can't parse upstream endpoint URL '%s': %s", upstream.Endpoint, err)
         }
         config.Upstreams[i].URL = endpoint
-        if upstream.Type == "" {
-            upstream.Type = "openai"
+        if config.Upstreams[i].Type == "" {
+            config.Upstreams[i].Type = "openai"
         }
-        if (upstream.Type != "openai") && (upstream.Type != "replicate") {
-            log.Fatalf("Unsupported upstream type '%s'", upstream.Type)
+        if (config.Upstreams[i].Type != "openai") && (config.Upstreams[i].Type != "replicate") {
+            log.Fatalf("Unsupported upstream type '%s'", config.Upstreams[i].Type)
         }
     }
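
The hunk above fixes a classic Go pitfall: the range variable upstream is a copy of the slice element, so assigning to upstream.Type never touched config.Upstreams[i] and the "openai" default was lost. A minimal standalone illustration of the difference, with a simplified type:

package main

import "fmt"

type Upstream struct{ Type string }

func main() {
    upstreams := []Upstream{{Type: ""}}

    // Writing to the range variable only mutates a copy.
    for _, u := range upstreams {
        if u.Type == "" {
            u.Type = "openai"
        }
    }
    fmt.Printf("%q\n", upstreams[0].Type) // ""

    // Indexing into the slice mutates the element itself.
    for i := range upstreams {
        if upstreams[i].Type == "" {
            upstreams[i].Type = "openai"
        }
    }
    fmt.Printf("%q\n", upstreams[0].Type) // "openai"
}
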
@@ -85,7 +85,7 @@ type OpenAIChatRequest struct {
 type OpenAIChatRequestMessage struct {
     Content string `json:"content"`
     User string `json:"user"`
+    Role string `json:"role"`
 }
 type ReplicateModelRequest struct {
@@ -176,5 +176,5 @@ type OpenAIChatResponseChunk struct {
 type OpenAIChatResponseChunkChoice struct {
     Index int64 `json:"index"`
     Delta OpenAIChatMessage `json:"delta"`
-    FinishReason *string `json:"finish_reason"`
+    FinishReason string `json:"finish_reason"`
 }
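
For context on the last hunk: under encoding/json, the removed *string field marshals a nil value as "finish_reason":null, while the plain string field always emits "finish_reason":"". A standalone sketch of the difference, with illustrative type names:

package main

import (
    "encoding/json"
    "fmt"
)

type withPointer struct {
    FinishReason *string `json:"finish_reason"`
}
type withString struct {
    FinishReason string `json:"finish_reason"`
}

func main() {
    a, _ := json.Marshal(withPointer{})
    b, _ := json.Marshal(withString{})
    fmt.Println(string(a)) // {"finish_reason":null}
    fmt.Println(string(b)) // {"finish_reason":""}
}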