Compare commits
4 Commits
b1a9d6b685
...
3385f9af08
| Author | SHA1 | Date | |
|---|---|---|---|
|
3385f9af08
|
|||
|
8fa7fa79be
|
|||
|
49169452fe
|
|||
|
33f341026f
|
79
replicate.go
79
replicate.go
@@ -19,6 +19,7 @@ var replicate_model_url_template = "https://api.replicate.com/v1/models/%s/predi
|
||||
func processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
|
||||
err := _processReplicateRequest(c, upstream, record, shouldResponse)
|
||||
if shouldResponse {
|
||||
sendCORSHeaders(c)
|
||||
if err != nil {
|
||||
c.AbortWithError(502, err)
|
||||
}
|
||||
@@ -85,7 +86,7 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
|
||||
Top_k: 50,
|
||||
FrequencyPenalty: 0.0,
|
||||
PresencePenalty: 0.0,
|
||||
PromptTemplate: "<s>[INST] {prompt} [/INST] ",
|
||||
PromptTemplate: "{prompt}",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -96,9 +97,58 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
|
||||
outRequest.Input.PresencePenalty = inRequest.PresencePenalty
|
||||
|
||||
// render prompt
|
||||
systemMessage := ""
|
||||
userMessage := ""
|
||||
assistantMessage := ""
|
||||
for _, message := range inRequest.Messages {
|
||||
outRequest.Input.Prompt += message.Content + "\n"
|
||||
if message.Role == "system" {
|
||||
if systemMessage != "" {
|
||||
systemMessage += "\n"
|
||||
}
|
||||
systemMessage += message.Content
|
||||
continue
|
||||
}
|
||||
if message.Role == "user" {
|
||||
if userMessage != "" {
|
||||
userMessage += "\n"
|
||||
}
|
||||
userMessage += message.Content
|
||||
if systemMessage != "" {
|
||||
userMessage = systemMessage + "\n" + userMessage
|
||||
systemMessage = ""
|
||||
}
|
||||
continue
|
||||
}
|
||||
if message.Role == "assistant" {
|
||||
if assistantMessage != "" {
|
||||
assistantMessage += "\n"
|
||||
}
|
||||
assistantMessage += message.Content
|
||||
|
||||
if outRequest.Input.Prompt != "" {
|
||||
outRequest.Input.Prompt += "\n"
|
||||
}
|
||||
if userMessage != "" {
|
||||
outRequest.Input.Prompt += fmt.Sprintf("<s> [INST] %s [/INST] %s </s>", userMessage, assistantMessage)
|
||||
} else {
|
||||
outRequest.Input.Prompt += fmt.Sprintf("<s> %s </s>", assistantMessage)
|
||||
}
|
||||
userMessage = ""
|
||||
assistantMessage = ""
|
||||
}
|
||||
// unknown role
|
||||
log.Println("[processReplicateRequest]: Warning: unknown role", message.Role)
|
||||
}
|
||||
// final user message
|
||||
if userMessage != "" {
|
||||
outRequest.Input.Prompt += fmt.Sprintf("<s> [INST] %s [/INST] ", userMessage)
|
||||
userMessage = ""
|
||||
}
|
||||
// final assistant message
|
||||
if assistantMessage != "" {
|
||||
outRequest.Input.Prompt += fmt.Sprintf("<s> %s </s>", assistantMessage)
|
||||
}
|
||||
log.Println("[processReplicateRequest]: outRequest.Input.Prompt:", outRequest.Input.Prompt)
|
||||
|
||||
// send request
|
||||
outBody, err := json.Marshal(outRequest)
|
||||
@@ -189,14 +239,10 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
|
||||
|
||||
// parse chunk to ReplicateModelResultChunk object
|
||||
chunkObj := &ReplicateModelResultChunk{}
|
||||
log.Println("[processReplicateRequest]: chunk:", chunk)
|
||||
lines := strings.Split(chunk, "\n")
|
||||
log.Println("[processReplicateRequest]: lines:", lines)
|
||||
// first line is event
|
||||
chunkObj.Event = strings.TrimSpace(lines[0])
|
||||
chunkObj.Event = strings.TrimPrefix(chunkObj.Event, "event: ")
|
||||
fmt.Printf("[processReplicateRequest]: chunkObj.Event: '%s'\n", chunkObj.Event)
|
||||
fmt.Printf("Length: %d\n", len(chunkObj.Event))
|
||||
// second line is id
|
||||
chunkObj.ID = strings.TrimSpace(lines[1])
|
||||
chunkObj.ID = strings.TrimPrefix(chunkObj.ID, "id: ")
|
||||
@@ -212,14 +258,16 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
|
||||
break
|
||||
}
|
||||
|
||||
sendCORSHeaders(c)
|
||||
|
||||
// create OpenAI response chunk
|
||||
c.SSEvent("", &OpenAIChatResponse{
|
||||
ID: chunkObj.Event,
|
||||
c.SSEvent("", &OpenAIChatResponseChunk{
|
||||
ID: "",
|
||||
Model: outResponse.Model,
|
||||
Choices: []OpenAIChatResponseChoice{
|
||||
Choices: []OpenAIChatResponseChunkChoice{
|
||||
{
|
||||
Index: indexCount,
|
||||
Message: OpenAIChatMessage{
|
||||
Delta: OpenAIChatMessage{
|
||||
Role: "assistant",
|
||||
Content: chunkObj.Data,
|
||||
},
|
||||
@@ -229,13 +277,14 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
|
||||
c.Writer.Flush()
|
||||
indexCount += 1
|
||||
}
|
||||
c.SSEvent("", &OpenAIChatResponse{
|
||||
sendCORSHeaders(c)
|
||||
c.SSEvent("", &OpenAIChatResponseChunk{
|
||||
ID: "",
|
||||
Model: outResponse.Model,
|
||||
Choices: []OpenAIChatResponseChoice{
|
||||
Choices: []OpenAIChatResponseChunkChoice{
|
||||
{
|
||||
Index: indexCount,
|
||||
Message: OpenAIChatMessage{
|
||||
Delta: OpenAIChatMessage{
|
||||
Role: "assistant",
|
||||
Content: "",
|
||||
},
|
||||
@@ -273,8 +322,6 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
|
||||
return errors.New("[processReplicateRequest]: failed to read response body " + err.Error())
|
||||
}
|
||||
|
||||
log.Println("[processReplicateRequest]: resultBody:", string(resultBody))
|
||||
|
||||
// parse reponse body
|
||||
result = &ReplicateModelResultGet{}
|
||||
err = json.Unmarshal(resultBody, result)
|
||||
@@ -282,7 +329,6 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||
return errors.New("[processReplicateRequest]: failed to parse response body " + err.Error())
|
||||
}
|
||||
log.Println("[processReplicateRequest]: result:", result)
|
||||
|
||||
if result.Status == "processing" || result.Status == "starting" {
|
||||
time.Sleep(3 * time.Second)
|
||||
@@ -315,6 +361,7 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
|
||||
record.Status = 200
|
||||
|
||||
// gin return
|
||||
sendCORSHeaders(c)
|
||||
c.JSON(200, openAIResult)
|
||||
|
||||
}
|
||||
|
||||
12
structure.go
12
structure.go
@@ -62,11 +62,11 @@ func readConfig(filepath string) Config {
|
||||
log.Fatalf("Can't parse upstream endpoint URL '%s': %s", upstream.Endpoint, err)
|
||||
}
|
||||
config.Upstreams[i].URL = endpoint
|
||||
if upstream.Type == "" {
|
||||
upstream.Type = "openai"
|
||||
if config.Upstreams[i].Type == "" {
|
||||
config.Upstreams[i].Type = "openai"
|
||||
}
|
||||
if (upstream.Type != "openai") && (upstream.Type != "replicate") {
|
||||
log.Fatalf("Unsupported upstream type '%s'", upstream.Type)
|
||||
if (config.Upstreams[i].Type != "openai") && (config.Upstreams[i].Type != "replicate") {
|
||||
log.Fatalf("Unsupported upstream type '%s'", config.Upstreams[i].Type)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,7 +85,7 @@ type OpenAIChatRequest struct {
|
||||
|
||||
type OpenAIChatRequestMessage struct {
|
||||
Content string `json:"content"`
|
||||
User string `json:"user"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
type ReplicateModelRequest struct {
|
||||
@@ -176,5 +176,5 @@ type OpenAIChatResponseChunk struct {
|
||||
type OpenAIChatResponseChunkChoice struct {
|
||||
Index int64 `json:"index"`
|
||||
Delta OpenAIChatMessage `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user