// openai-api-route/replicate.go
package main

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
)

var replicate_model_url_template = "https://api.replicate.com/v1/models/%s/predictions"
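
// processReplicateRequest forwards an OpenAI-style chat completion request to
// Replicate's predictions API. When shouldResponse is true, any upstream error
// is also written back to the client as a 502.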
func processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
	err := _processReplicateRequest(c, upstream, record, shouldResponse)
	if shouldResponse && err != nil {
		c.AbortWithError(502, err)
	}
	return err
}
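
// _processReplicateRequest does the actual work: it records and validates the
// incoming request against the upstream allow/deny lists, translates it into a
// Replicate prediction request, and then either relays the SSE stream or polls
// the prediction URL until it completes. On error it restores c.Request.Body so
// later handlers can still read the original payload.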
func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
	// read request body
	inBody, err := io.ReadAll(c.Request.Body)
	if err != nil {
		return errors.New("[processReplicateRequest]: failed to read request body " + err.Error())
	}
	// record request body
	record.Body = string(inBody)
	// parse request body
	inRequest := &OpenAIChatRequest{}
	err = json.Unmarshal(inBody, inRequest)
	if err != nil {
		c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
		return errors.New("[processReplicateRequest]: failed to parse request body " + err.Error())
	}
	record.Model = inRequest.Model
	// check the allow list
	if len(upstream.Allow) > 0 {
		isAllow := false
		for _, model := range upstream.Allow {
			if model == inRequest.Model {
				isAllow = true
				break
			}
		}
		if !isAllow {
			c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
			return errors.New("[processReplicateRequest]: model not allowed")
		}
	}
	// check the deny list
	if len(upstream.Deny) > 0 {
		for _, model := range upstream.Deny {
			if model == inRequest.Model {
				c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
				return errors.New("[processReplicateRequest]: model denied")
			}
		}
	}
	// build the Replicate predictions URL for this model
	model_url := fmt.Sprintf(replicate_model_url_template, inRequest.Model)
	log.Println("[processReplicateRequest]: model_url:", model_url)
	// create the outgoing request with default values
	outRequest := &ReplicateModelRequest{
		Stream: false,
		Input: ReplicateModelRequestInput{
			Prompt:           "",
			MaxNewTokens:     1024,
			Temperature:      0.6,
			Top_p:            0.9,
			Top_k:            50,
			FrequencyPenalty: 0.0,
			PresencePenalty:  0.0,
			PromptTemplate:   "<s>[INST] {prompt} [/INST] ",
		},
	}
	// copy values from the incoming request
	outRequest.Stream = inRequest.Stream
	outRequest.Input.Temperature = inRequest.Temperature
	outRequest.Input.FrequencyPenalty = inRequest.FrequencyPenalty
	outRequest.Input.PresencePenalty = inRequest.PresencePenalty
	// render prompt
	for _, message := range inRequest.Messages {
		outRequest.Input.Prompt += message.Content + "\n"
	}
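	// Note: only stream, temperature and the penalty settings are copied from the
	// OpenAI request; max_new_tokens, top_p, top_k and the Llama-style prompt
	// template keep the defaults above, and all chat messages are flattened into a
	// single prompt string, so role information is not preserved.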
	// marshal the Replicate request
	outBody, err := json.Marshal(outRequest)
	if err != nil {
		c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
		return errors.New("[processReplicateRequest]: failed to marshal request body " + err.Error())
	}
	// build the prediction request and add headers
	req, err := http.NewRequest("POST", model_url, bytes.NewBuffer(outBody))
	if err != nil {
		c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
		return errors.New("[processReplicateRequest]: failed to create post request " + err.Error())
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Token "+upstream.SK)
	// send
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
		return errors.New("[processReplicateRequest]: failed to post request " + err.Error())
	}
	defer resp.Body.Close()
	// read response body
	outBody, err = io.ReadAll(resp.Body)
	if err != nil {
		c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
		return errors.New("[processReplicateRequest]: failed to read response body " + err.Error())
	}
	// parse response body
	outResponse := &ReplicateModelResponse{}
	err = json.Unmarshal(outBody, outResponse)
	if err != nil {
		c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
		return errors.New("[processReplicateRequest]: failed to parse response body " + err.Error())
	}
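	// The create-prediction response carries two follow-up URLs: urls.stream for
	// server-sent events when streaming was requested, and urls.get for polling
	// the prediction until it finishes. The two branches below handle each case.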
	if outResponse.Stream {
		// open the SSE stream
		log.Println("[processReplicateRequest]: outResponse.URLS.Stream:", outResponse.URLS.Stream)
		req, err := http.NewRequest("GET", outResponse.URLS.Stream, nil)
		if err != nil {
			c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
			return errors.New("[processReplicateRequest]: failed to create get request " + err.Error())
		}
		req.Header.Set("Authorization", "Token "+upstream.SK)
		req.Header.Set("Accept", "text/event-stream")
		// send
		resp, err := http.DefaultClient.Do(req)
		if err != nil {
			c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
			return errors.New("[processReplicateRequest]: failed to get request " + err.Error())
		}
		defer resp.Body.Close()
		// read the stream chunk by chunk
		buffer := ""
		var indexCount int64
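		// SSE events are separated by a blank line ("\n\n"). Bytes are accumulated
		// in buffer until a full event is available, then each event is split into
		// its "event:", "id:" and "data:" lines below.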
		for {
			if !strings.Contains(buffer, "\n\n") {
				// receive the next chunk from the stream
				chunk := make([]byte, 1024)
				length, err := resp.Body.Read(chunk)
				if length > 0 {
					buffer += string(chunk[:length])
				}
				if err == io.EOF {
					break
				}
				if err != nil {
					c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
					return errors.New("[processReplicateRequest]: failed to read response body " + err.Error())
				}
				if length == 0 {
					break
				}
				continue
			}
			// cut the first complete event off the buffer
			index := strings.Index(buffer, "\n\n")
			chunk := buffer[:index]
			buffer = buffer[index+2:]
			// trim surrounding newlines
			chunk = strings.Trim(chunk, "\n")
			// ignore single-line keep-alive chunks such as the initial "hi"
			if !strings.Contains(chunk, "\n") {
				continue
			}
			// parse the event into a ReplicateModelResultChunk
			chunkObj := &ReplicateModelResultChunk{}
			log.Println("[processReplicateRequest]: chunk:", chunk)
			lines := strings.Split(chunk, "\n")
			log.Println("[processReplicateRequest]: lines:", lines)
			if len(lines) < 3 {
				continue
			}
			// first line is the event name
			chunkObj.Event = strings.TrimSpace(lines[0])
			chunkObj.Event = strings.TrimPrefix(chunkObj.Event, "event: ")
			log.Printf("[processReplicateRequest]: chunkObj.Event: '%s' (len %d)", chunkObj.Event, len(chunkObj.Event))
			// second line is the event id
			chunkObj.ID = strings.TrimSpace(lines[1])
			chunkObj.ID = strings.TrimPrefix(chunkObj.ID, "id: ")
			chunkObj.ID = strings.SplitN(chunkObj.ID, ":", 2)[0]
			// third line is the data payload
			chunkObj.Data = lines[2]
			chunkObj.Data = strings.TrimPrefix(chunkObj.Data, "data: ")
			record.Response += chunkObj.Data
			// stop on the terminal "done" event
			if chunkObj.Event == "done" {
				break
			}
			// forward the data as an OpenAI-style response chunk
			c.SSEvent("", &OpenAIChatResponse{
				ID:    chunkObj.ID,
				Model: outResponse.Model,
				Choices: []OpenAIChatResponseChoice{
					{
						Index: indexCount,
						Message: OpenAIChatMessage{
							Role:    "assistant",
							Content: chunkObj.Data,
						},
					},
				},
			})
			c.Writer.Flush()
			indexCount++
		}
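		// Emit a final empty chunk with FinishReason "stop" so OpenAI-style clients
		// know the stream has ended.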
		c.SSEvent("", &OpenAIChatResponse{
			ID:    "",
			Model: outResponse.Model,
			Choices: []OpenAIChatResponseChoice{
				{
					Index: indexCount,
					Message: OpenAIChatMessage{
						Role:    "assistant",
						Content: "",
					},
					FinishReason: "stop",
				},
			},
		})
		c.Writer.Flush()
		indexCount++
		record.Status = 200
		return nil
	} else {
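		// Non-streaming path: poll urls.get every 3 seconds until the prediction
		// leaves the "starting"/"processing" states, then return the whole output
		// as a single OpenAI-style chat completion.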
		var result *ReplicateModelResultGet
		for {
			// fetch the current prediction state
			log.Println("[processReplicateRequest]: outResponse.URLS.Get:", outResponse.URLS.Get)
			req, err := http.NewRequest("GET", outResponse.URLS.Get, nil)
			if err != nil {
				c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
				return errors.New("[processReplicateRequest]: failed to create get request " + err.Error())
			}
			req.Header.Set("Authorization", "Token "+upstream.SK)
			// send
			resp, err := http.DefaultClient.Do(req)
			if err != nil {
				c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
				return errors.New("[processReplicateRequest]: failed to get request " + err.Error())
			}
			// read the result
			resultBody, err := io.ReadAll(resp.Body)
			resp.Body.Close()
			if err != nil {
				c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
				return errors.New("[processReplicateRequest]: failed to read response body " + err.Error())
			}
			log.Println("[processReplicateRequest]: resultBody:", string(resultBody))
			// parse response body
			result = &ReplicateModelResultGet{}
			err = json.Unmarshal(resultBody, result)
			if err != nil {
				c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
				return errors.New("[processReplicateRequest]: failed to parse response body " + err.Error())
			}
			log.Println("[processReplicateRequest]: result:", result)
			if result.Status == "processing" || result.Status == "starting" {
				time.Sleep(3 * time.Second)
				continue
			}
			break
		}
		// build the OpenAI response
		openAIResult := &OpenAIChatResponse{
			ID:      result.ID,
			Model:   result.Model,
			Choices: []OpenAIChatResponseChoice{},
			Usage: OpenAIChatResponseUsage{
				TotalTokens:  result.Metrics.InputTokenCount + result.Metrics.OutputTokenCount,
				PromptTokens: result.Metrics.InputTokenCount,
			},
		}
		openAIResult.Choices = append(openAIResult.Choices, OpenAIChatResponseChoice{
			Index: 0,
			Message: OpenAIChatMessage{
				Role:    "assistant",
				Content: strings.Join(result.Output, ""),
			},
			FinishReason: "stop",
		})
		// record the generated text (record.Body already holds the request body)
		record.Response = strings.Join(result.Output, "")
		record.Status = 200
		// gin return
		c.JSON(200, openAIResult)
	}
	return nil
}
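
// A minimal wiring sketch (not part of this file): one way a gin route might
// invoke processReplicateRequest. The route path, the model name and the way
// OPENAI_UPSTREAM and Record are constructed here are illustrative assumptions,
// not the project's actual setup.
//
//	r := gin.Default()
//	r.POST("/v1/chat/completions", func(c *gin.Context) {
//		// hypothetical upstream configuration for illustration only
//		upstream := &OPENAI_UPSTREAM{SK: "r8_xxx", Allow: []string{"meta/llama-2-70b-chat"}}
//		record := &Record{}
//		_ = processReplicateRequest(c, upstream, record, true)
//	})
//	_ = r.Run(":8080")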