package main import ( "bytes" "encoding/json" "errors" "fmt" "io" "log" "net/http" "strings" "time" "github.com/gin-gonic/gin" ) var replicate_model_url_template = "https://api.replicate.com/v1/models/%s/predictions" func processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error { err := _processReplicateRequest(c, upstream, record, shouldResponse) if shouldResponse { if err != nil { c.AbortWithError(502, err) } } return err } func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error { // read request body inBody, err := io.ReadAll(c.Request.Body) if err != nil { return errors.New("[processReplicateRequest]: failed to read request body " + err.Error()) } // record request body record.Body = string(inBody) // parse request body inRequest := &OpenAIChatRequest{} err = json.Unmarshal(inBody, inRequest) if err != nil { c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody)) return errors.New("[processReplicateRequest]: failed to parse request body " + err.Error()) } record.Model = inRequest.Model // check allow model if len(upstream.Allow) > 0 { isAllow := false for _, model := range upstream.Allow { if model == inRequest.Model { isAllow = true break } } if !isAllow { c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody)) return errors.New("[processReplicateRequest]: model not allow") } } // check block model if len(upstream.Deny) > 0 { for _, model := range upstream.Deny { if model == inRequest.Model { c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody)) return errors.New("[processReplicateRequest]: model deny") } } } // set url model_url := fmt.Sprintf(replicate_model_url_template, inRequest.Model) log.Println("[processReplicateRequest]: model_url:", model_url) // create request with default value outRequest := &ReplicateModelRequest{ Stream: false, Input: ReplicateModelRequestInput{ Prompt: "", MaxNewTokens: 1024, Temperature: 0.6, Top_p: 0.9, Top_k: 50, FrequencyPenalty: 0.0, PresencePenalty: 0.0, PromptTemplate: "[INST] {prompt} [/INST] ", }, } // copy value from input request outRequest.Stream = inRequest.Stream outRequest.Input.Temperature = inRequest.Temperature outRequest.Input.FrequencyPenalty = inRequest.FrequencyPenalty outRequest.Input.PresencePenalty = inRequest.PresencePenalty // render prompt for _, message := range inRequest.Messages { outRequest.Input.Prompt += message.Content + "\n" } // send request outBody, err := json.Marshal(outRequest) if err != nil { c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody)) return errors.New("[processReplicateRequest]: failed to marshal request body " + err.Error()) } // http add headers req, err := http.NewRequest("POST", model_url, bytes.NewBuffer(outBody)) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Token "+upstream.SK) // send resp, err := http.DefaultClient.Do(req) if err != nil { c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody)) return errors.New("[processReplicateRequest]: failed to post request " + err.Error()) } // read response body outBody, err = io.ReadAll(resp.Body) if err != nil { c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody)) return errors.New("[processReplicateRequest]: failed to read response body " + err.Error()) } // parse reponse body outResponse := &ReplicateModelResponse{} err = json.Unmarshal(outBody, outResponse) if err != nil { c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody)) return errors.New("[processReplicateRequest]: failed to parse response body " + err.Error()) } if outResponse.Stream { // get result log.Println("[processReplicateRequest]: outResponse.URLS.Get:", outResponse.URLS.Stream) req, err := http.NewRequest("GET", outResponse.URLS.Stream, nil) if err != nil { c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody)) return errors.New("[processReplicateRequest]: failed to create get request " + err.Error()) } req.Header.Set("Authorization", "Token "+upstream.SK) req.Header.Set("Accept", "text/event-stream") // send resp, err := http.DefaultClient.Do(req) if err != nil { c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody)) return errors.New("[processReplicateRequest]: failed to get request " + err.Error()) } // get result by chunk var buffer string = "" var indexCount int64 = 0 for { if !strings.Contains(buffer, "\n\n") { // receive chunk chunk := make([]byte, 1024) length, err := resp.Body.Read(chunk) if err == io.EOF { break } if length == 0 { break } if err != nil { c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody)) return errors.New("[processReplicateRequest]: failed to read response body " + err.Error()) } // add chunk to buffer chunk = bytes.Trim(chunk, "\x00") buffer += string(chunk) continue } // cut the first chunk by find index index := strings.Index(buffer, "\n\n") chunk := buffer[:index] buffer = buffer[index+2:] // trim line chunk = strings.Trim(chunk, "\n") // ignore hi if !strings.Contains(chunk, "\n") { continue } // parse chunk to ReplicateModelResultChunk object chunkObj := &ReplicateModelResultChunk{} log.Println("[processReplicateRequest]: chunk:", chunk) lines := strings.Split(chunk, "\n") log.Println("[processReplicateRequest]: lines:", lines) // first line is event chunkObj.Event = strings.TrimSpace(lines[0]) chunkObj.Event = strings.TrimPrefix(chunkObj.Event, "event: ") fmt.Printf("[processReplicateRequest]: chunkObj.Event: '%s'\n", chunkObj.Event) fmt.Printf("Length: %d\n", len(chunkObj.Event)) // second line is id chunkObj.ID = strings.TrimSpace(lines[1]) chunkObj.ID = strings.TrimPrefix(chunkObj.ID, "id: ") chunkObj.ID = strings.SplitN(chunkObj.ID, ":", 2)[0] // third line is data chunkObj.Data = lines[2] chunkObj.Data = strings.TrimPrefix(chunkObj.Data, "data: ") record.Response += chunkObj.Data // done if chunkObj.Event == "done" { break } // create OpenAI response chunk c.SSEvent("", &OpenAIChatResponse{ ID: chunkObj.Event, Model: outResponse.Model, Choices: []OpenAIChatResponseChoice{ { Index: indexCount, Message: OpenAIChatMessage{ Role: "assistant", Content: chunkObj.Data, }, }, }, }) c.Writer.Flush() indexCount += 1 } c.SSEvent("", &OpenAIChatResponse{ ID: "", Model: outResponse.Model, Choices: []OpenAIChatResponseChoice{ { Index: indexCount, Message: OpenAIChatMessage{ Role: "assistant", Content: "", }, FinishReason: "stop", }, }, }) c.Writer.Flush() indexCount += 1 record.Status = 200 return nil } else { var result *ReplicateModelResultGet for { // get result log.Println("[processReplicateRequest]: outResponse.URLS.Get:", outResponse.URLS.Get) req, err := http.NewRequest("GET", outResponse.URLS.Get, nil) if err != nil { c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody)) return errors.New("[processReplicateRequest]: failed to create get request " + err.Error()) } req.Header.Set("Authorization", "Token "+upstream.SK) // send resp, err := http.DefaultClient.Do(req) if err != nil { c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody)) return errors.New("[processReplicateRequest]: failed to get request " + err.Error()) } // get result resultBody, err := io.ReadAll(resp.Body) if err != nil { c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody)) return errors.New("[processReplicateRequest]: failed to read response body " + err.Error()) } log.Println("[processReplicateRequest]: resultBody:", string(resultBody)) // parse reponse body result = &ReplicateModelResultGet{} err = json.Unmarshal(resultBody, result) if err != nil { c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody)) return errors.New("[processReplicateRequest]: failed to parse response body " + err.Error()) } log.Println("[processReplicateRequest]: result:", result) if result.Status == "processing" || result.Status == "starting" { time.Sleep(3 * time.Second) continue } break } // build openai resposne openAIResult := &OpenAIChatResponse{ ID: result.ID, Model: result.Model, Choices: []OpenAIChatResponseChoice{}, Usage: OpenAIChatResponseUsage{ TotalTokens: result.Metrics.InputTokenCount + result.Metrics.OutputTokenCount, PromptTokens: result.Metrics.InputTokenCount, }, } openAIResult.Choices = append(openAIResult.Choices, OpenAIChatResponseChoice{ Index: 0, Message: OpenAIChatMessage{ Role: "assistant", Content: strings.Join(result.Output, ""), }, FinishReason: "stop", }) record.Body = strings.Join(result.Output, "") record.Status = 200 // gin return c.JSON(200, openAIResult) } return nil }