Compare commits
4 Commits
b1a9d6b685
...
3385f9af08
| Author | SHA1 | Date | |
|---|---|---|---|
|
3385f9af08
|
|||
|
8fa7fa79be
|
|||
|
49169452fe
|
|||
|
33f341026f
|
79
replicate.go
79
replicate.go
@@ -19,6 +19,7 @@ var replicate_model_url_template = "https://api.replicate.com/v1/models/%s/predi
|
|||||||
func processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
|
func processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
|
||||||
err := _processReplicateRequest(c, upstream, record, shouldResponse)
|
err := _processReplicateRequest(c, upstream, record, shouldResponse)
|
||||||
if shouldResponse {
|
if shouldResponse {
|
||||||
|
sendCORSHeaders(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithError(502, err)
|
c.AbortWithError(502, err)
|
||||||
}
|
}
|
||||||
@@ -85,7 +86,7 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
|
|||||||
Top_k: 50,
|
Top_k: 50,
|
||||||
FrequencyPenalty: 0.0,
|
FrequencyPenalty: 0.0,
|
||||||
PresencePenalty: 0.0,
|
PresencePenalty: 0.0,
|
||||||
PromptTemplate: "<s>[INST] {prompt} [/INST] ",
|
PromptTemplate: "{prompt}",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -96,9 +97,58 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
|
|||||||
outRequest.Input.PresencePenalty = inRequest.PresencePenalty
|
outRequest.Input.PresencePenalty = inRequest.PresencePenalty
|
||||||
|
|
||||||
// render prompt
|
// render prompt
|
||||||
|
systemMessage := ""
|
||||||
|
userMessage := ""
|
||||||
|
assistantMessage := ""
|
||||||
for _, message := range inRequest.Messages {
|
for _, message := range inRequest.Messages {
|
||||||
outRequest.Input.Prompt += message.Content + "\n"
|
if message.Role == "system" {
|
||||||
|
if systemMessage != "" {
|
||||||
|
systemMessage += "\n"
|
||||||
|
}
|
||||||
|
systemMessage += message.Content
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if message.Role == "user" {
|
||||||
|
if userMessage != "" {
|
||||||
|
userMessage += "\n"
|
||||||
|
}
|
||||||
|
userMessage += message.Content
|
||||||
|
if systemMessage != "" {
|
||||||
|
userMessage = systemMessage + "\n" + userMessage
|
||||||
|
systemMessage = ""
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if message.Role == "assistant" {
|
||||||
|
if assistantMessage != "" {
|
||||||
|
assistantMessage += "\n"
|
||||||
|
}
|
||||||
|
assistantMessage += message.Content
|
||||||
|
|
||||||
|
if outRequest.Input.Prompt != "" {
|
||||||
|
outRequest.Input.Prompt += "\n"
|
||||||
|
}
|
||||||
|
if userMessage != "" {
|
||||||
|
outRequest.Input.Prompt += fmt.Sprintf("<s> [INST] %s [/INST] %s </s>", userMessage, assistantMessage)
|
||||||
|
} else {
|
||||||
|
outRequest.Input.Prompt += fmt.Sprintf("<s> %s </s>", assistantMessage)
|
||||||
|
}
|
||||||
|
userMessage = ""
|
||||||
|
assistantMessage = ""
|
||||||
|
}
|
||||||
|
// unknown role
|
||||||
|
log.Println("[processReplicateRequest]: Warning: unknown role", message.Role)
|
||||||
}
|
}
|
||||||
|
// final user message
|
||||||
|
if userMessage != "" {
|
||||||
|
outRequest.Input.Prompt += fmt.Sprintf("<s> [INST] %s [/INST] ", userMessage)
|
||||||
|
userMessage = ""
|
||||||
|
}
|
||||||
|
// final assistant message
|
||||||
|
if assistantMessage != "" {
|
||||||
|
outRequest.Input.Prompt += fmt.Sprintf("<s> %s </s>", assistantMessage)
|
||||||
|
}
|
||||||
|
log.Println("[processReplicateRequest]: outRequest.Input.Prompt:", outRequest.Input.Prompt)
|
||||||
|
|
||||||
// send request
|
// send request
|
||||||
outBody, err := json.Marshal(outRequest)
|
outBody, err := json.Marshal(outRequest)
|
||||||
@@ -189,14 +239,10 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
|
|||||||
|
|
||||||
// parse chunk to ReplicateModelResultChunk object
|
// parse chunk to ReplicateModelResultChunk object
|
||||||
chunkObj := &ReplicateModelResultChunk{}
|
chunkObj := &ReplicateModelResultChunk{}
|
||||||
log.Println("[processReplicateRequest]: chunk:", chunk)
|
|
||||||
lines := strings.Split(chunk, "\n")
|
lines := strings.Split(chunk, "\n")
|
||||||
log.Println("[processReplicateRequest]: lines:", lines)
|
|
||||||
// first line is event
|
// first line is event
|
||||||
chunkObj.Event = strings.TrimSpace(lines[0])
|
chunkObj.Event = strings.TrimSpace(lines[0])
|
||||||
chunkObj.Event = strings.TrimPrefix(chunkObj.Event, "event: ")
|
chunkObj.Event = strings.TrimPrefix(chunkObj.Event, "event: ")
|
||||||
fmt.Printf("[processReplicateRequest]: chunkObj.Event: '%s'\n", chunkObj.Event)
|
|
||||||
fmt.Printf("Length: %d\n", len(chunkObj.Event))
|
|
||||||
// second line is id
|
// second line is id
|
||||||
chunkObj.ID = strings.TrimSpace(lines[1])
|
chunkObj.ID = strings.TrimSpace(lines[1])
|
||||||
chunkObj.ID = strings.TrimPrefix(chunkObj.ID, "id: ")
|
chunkObj.ID = strings.TrimPrefix(chunkObj.ID, "id: ")
|
||||||
@@ -212,14 +258,16 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sendCORSHeaders(c)
|
||||||
|
|
||||||
// create OpenAI response chunk
|
// create OpenAI response chunk
|
||||||
c.SSEvent("", &OpenAIChatResponse{
|
c.SSEvent("", &OpenAIChatResponseChunk{
|
||||||
ID: chunkObj.Event,
|
ID: "",
|
||||||
Model: outResponse.Model,
|
Model: outResponse.Model,
|
||||||
Choices: []OpenAIChatResponseChoice{
|
Choices: []OpenAIChatResponseChunkChoice{
|
||||||
{
|
{
|
||||||
Index: indexCount,
|
Index: indexCount,
|
||||||
Message: OpenAIChatMessage{
|
Delta: OpenAIChatMessage{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: chunkObj.Data,
|
Content: chunkObj.Data,
|
||||||
},
|
},
|
||||||
@@ -229,13 +277,14 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
|
|||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
indexCount += 1
|
indexCount += 1
|
||||||
}
|
}
|
||||||
c.SSEvent("", &OpenAIChatResponse{
|
sendCORSHeaders(c)
|
||||||
|
c.SSEvent("", &OpenAIChatResponseChunk{
|
||||||
ID: "",
|
ID: "",
|
||||||
Model: outResponse.Model,
|
Model: outResponse.Model,
|
||||||
Choices: []OpenAIChatResponseChoice{
|
Choices: []OpenAIChatResponseChunkChoice{
|
||||||
{
|
{
|
||||||
Index: indexCount,
|
Index: indexCount,
|
||||||
Message: OpenAIChatMessage{
|
Delta: OpenAIChatMessage{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: "",
|
Content: "",
|
||||||
},
|
},
|
||||||
@@ -273,8 +322,6 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
|
|||||||
return errors.New("[processReplicateRequest]: failed to read response body " + err.Error())
|
return errors.New("[processReplicateRequest]: failed to read response body " + err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Println("[processReplicateRequest]: resultBody:", string(resultBody))
|
|
||||||
|
|
||||||
// parse reponse body
|
// parse reponse body
|
||||||
result = &ReplicateModelResultGet{}
|
result = &ReplicateModelResultGet{}
|
||||||
err = json.Unmarshal(resultBody, result)
|
err = json.Unmarshal(resultBody, result)
|
||||||
@@ -282,7 +329,6 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
|
|||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
|
||||||
return errors.New("[processReplicateRequest]: failed to parse response body " + err.Error())
|
return errors.New("[processReplicateRequest]: failed to parse response body " + err.Error())
|
||||||
}
|
}
|
||||||
log.Println("[processReplicateRequest]: result:", result)
|
|
||||||
|
|
||||||
if result.Status == "processing" || result.Status == "starting" {
|
if result.Status == "processing" || result.Status == "starting" {
|
||||||
time.Sleep(3 * time.Second)
|
time.Sleep(3 * time.Second)
|
||||||
@@ -315,6 +361,7 @@ func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record
|
|||||||
record.Status = 200
|
record.Status = 200
|
||||||
|
|
||||||
// gin return
|
// gin return
|
||||||
|
sendCORSHeaders(c)
|
||||||
c.JSON(200, openAIResult)
|
c.JSON(200, openAIResult)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
12
structure.go
12
structure.go
@@ -62,11 +62,11 @@ func readConfig(filepath string) Config {
|
|||||||
log.Fatalf("Can't parse upstream endpoint URL '%s': %s", upstream.Endpoint, err)
|
log.Fatalf("Can't parse upstream endpoint URL '%s': %s", upstream.Endpoint, err)
|
||||||
}
|
}
|
||||||
config.Upstreams[i].URL = endpoint
|
config.Upstreams[i].URL = endpoint
|
||||||
if upstream.Type == "" {
|
if config.Upstreams[i].Type == "" {
|
||||||
upstream.Type = "openai"
|
config.Upstreams[i].Type = "openai"
|
||||||
}
|
}
|
||||||
if (upstream.Type != "openai") && (upstream.Type != "replicate") {
|
if (config.Upstreams[i].Type != "openai") && (config.Upstreams[i].Type != "replicate") {
|
||||||
log.Fatalf("Unsupported upstream type '%s'", upstream.Type)
|
log.Fatalf("Unsupported upstream type '%s'", config.Upstreams[i].Type)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -85,7 +85,7 @@ type OpenAIChatRequest struct {
|
|||||||
|
|
||||||
type OpenAIChatRequestMessage struct {
|
type OpenAIChatRequestMessage struct {
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
User string `json:"user"`
|
Role string `json:"role"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ReplicateModelRequest struct {
|
type ReplicateModelRequest struct {
|
||||||
@@ -176,5 +176,5 @@ type OpenAIChatResponseChunk struct {
|
|||||||
type OpenAIChatResponseChunkChoice struct {
|
type OpenAIChatResponseChunkChoice struct {
|
||||||
Index int64 `json:"index"`
|
Index int64 `json:"index"`
|
||||||
Delta OpenAIChatMessage `json:"delta"`
|
Delta OpenAIChatMessage `json:"delta"`
|
||||||
FinishReason *string `json:"finish_reason"`
|
FinishReason string `json:"finish_reason"`
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user