add: support replicate
README.md
@@ -11,6 +11,7 @@
- Detects ChatCompletions Stream requests and applies a 5-second timeout to them; see the [Timeout Policy](#timeout-policy) section for the full policy
- Records the full request body, the upstream used, the client IP address, the response time, and the GPT reply text
- Sends a Feishu or Matrix notification when a request fails
- Supports models hosted on the Replicate platform

This document describes in detail how to use the load balancer and its API methods and endpoints.

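For orientation, here is a minimal Go client against the proxy. The listen address and the `/v1/chat/completions` path are assumptions about a typical deployment, not taken from this commit:

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Assumed address and path; adjust to your deployment.
	body := []byte(`{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]}`)
	resp, err := http.Post("http://127.0.0.1:8080/v1/chat/completions", "application/json", bytes.NewBuffer(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	reply, _ := io.ReadAll(resp.Body)
	fmt.Println(string(reply))
}
```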
@@ -98,6 +99,9 @@ dbaddr: ./db.sqlite
# dbaddr: "host=127.0.0.1 port=5432 user=postgres dbname=openai_api_route sslmode=disable password=woshimima"

upstreams:
  - sk: "key_for_replicate"
    type: replicate
    allow: ["mistralai/mixtral-8x7b-instruct-v0.1"]
  - sk: "secret_key_1"
    endpoint: "https://api.openai.com/v2"
  - sk: "secret_key_2"
@@ -109,6 +113,12 @@ upstreams:

You can run `./openai-api-route` directly; if the database does not exist, it is created automatically.

## Model Allow and Deny Lists

If an allow or deny list is set for an upstream, the load balancer only permits or blocks those models for that upstream. The allow list is checked first, then the deny list, as sketched below.

If you mix models from the OpenAI and Replicate platforms, you will likely want to give the OpenAI and Replicate upstreams separate allow lists; otherwise a request for an OpenAI model may be routed to the Replicate platform.
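For reference, a runnable sketch of that check; `modelPermitted` is illustrative and mirrors the allow/deny logic used in `replicate.go` below, it is not a function from this repository:

```go
package main

import "fmt"

// modelPermitted: the allow list is checked first, then the deny list.
func modelPermitted(model string, allow, deny []string) bool {
	if len(allow) > 0 {
		found := false
		for _, m := range allow {
			if m == model {
				found = true
				break
			}
		}
		if !found {
			return false
		}
	}
	for _, m := range deny {
		if m == model {
			return false
		}
	}
	return true
}

func main() {
	allow := []string{"gpt-3.5-turbo"}
	deny := []string{"gpt-4"}
	fmt.Println(modelPermitted("gpt-3.5-turbo", allow, deny)) // true
	fmt.Println(modelPermitted("gpt-4", allow, deny))         // false: not in the allow list
}
```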
## Timeout Policy

When handling upstream requests, the timeout policy is key to keeping the service stable and responsive. The service defines multiple upstream servers in the `Upstreams` section of the configuration file. Each upstream has its own `Endpoint` and `SK` (a secret key or similar identifier). The service tries each upstream in the order given in the configuration until a request succeeds or every upstream has been tried, as sketched below.

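A runnable sketch of that failover order; the `upstreamCfg` type and `tryUpstream` helper are illustrative stand-ins for the real types in `structure.go` and the dispatch in `main.go` below:

```go
package main

import (
	"errors"
	"fmt"
)

// Illustrative stand-in for the configured upstream entry.
type upstreamCfg struct {
	SK      string
	Timeout int64
}

// tryUpstream is a hypothetical stand-in for processRequest / processReplicateRequest.
func tryUpstream(u upstreamCfg) error {
	if u.SK == "bad" {
		return errors.New("upstream failed")
	}
	return nil
}

func main() {
	upstreams := []upstreamCfg{{SK: "bad", Timeout: 30}, {SK: "good"}}
	for i, u := range upstreams {
		if u.Timeout == 0 {
			u.Timeout = 120 // default timeout, as in main.go below
		}
		if err := tryUpstream(u); err != nil {
			fmt.Printf("upstream %d failed: %v, trying next\n", i, err)
			continue // fall through to the next upstream in config order
		}
		fmt.Printf("upstream %d succeeded\n", i)
		break
	}
}
```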
@@ -11,10 +11,13 @@ dbaddr: ./db.sqlite
upstreams:
  - sk: "secret_key_1"
    endpoint: "https://api.openai.com/v2"
    allow: ["gpt-3.5-turbo"] # optional model allow list
  - sk: "secret_key_2"
    endpoint: "https://api.openai.com/v1"
    timeout: 30
    allow: ["gpt-3.5-turbo"] # optional model allow list
    deny: ["gpt-4"] # optional model deny list
    # if both lists are set, the allow list is checked first, then the deny list
  - sk: "key_for_replicate"
    type: replicate
    allow: ["mistralai/mixtral-8x7b-instruct-v0.1"]

main.go
@@ -123,9 +123,9 @@ func main() {
 	}

 	for index, upstream := range config.Upstreams {
-		if upstream.Endpoint == "" || upstream.SK == "" {
+		if upstream.SK == "" {
 			sendCORSHeaders(c)
-			c.AbortWithError(500, fmt.Errorf("[processRequest.begin]: invalid upstream '%s' '%s'", upstream.SK, upstream.Endpoint))
+			c.AbortWithError(500, fmt.Errorf("[processRequest.begin]: invalid SK (secret key) '%s'", upstream.SK))
 			continue
 		}

@@ -135,7 +135,14 @@
 			upstream.Timeout = 120
 		}

-		err = processRequest(c, &upstream, &record, shouldResponse)
+		if upstream.Type == "replicate" {
+			err = processReplicateRequest(c, &upstream, &record, shouldResponse)
+		} else if upstream.Type == "openai" {
+			err = processRequest(c, &upstream, &record, shouldResponse)
+		} else {
+			err = fmt.Errorf("[processRequest.begin]: unsupported upstream type '%s'", upstream.Type)
+		}

 		if err != nil {
 			if err == http.ErrAbortHandler {
 				abortErr := "[processRequest.done]: AbortHandler, client's connection lost?, no upstream will try, stop here"

replicate.go (new file)
@@ -0,0 +1,323 @@
package main

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
)

var replicate_model_url_template = "https://api.replicate.com/v1/models/%s/predictions"

func processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
	err := _processReplicateRequest(c, upstream, record, shouldResponse)
	if shouldResponse {
		if err != nil {
			c.AbortWithError(502, err)
		}
	}
	return err
}

func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
	// read request body
	inBody, err := io.ReadAll(c.Request.Body)
	if err != nil {
		return errors.New("[processReplicateRequest]: failed to read request body " + err.Error())
	}

	// record request body
	record.Body = string(inBody)

	// parse request body
	inRequest := &OpenAIChatRequest{}
	err = json.Unmarshal(inBody, inRequest)
	if err != nil {
		c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
		return errors.New("[processReplicateRequest]: failed to parse request body " + err.Error())
	}

	record.Model = inRequest.Model

	// check the allow list
	if len(upstream.Allow) > 0 {
		isAllow := false
		for _, model := range upstream.Allow {
			if model == inRequest.Model {
				isAllow = true
				break
			}
		}
		if !isAllow {
			c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
			return errors.New("[processReplicateRequest]: model not allowed")
		}
	}
	// check the deny list
	if len(upstream.Deny) > 0 {
		for _, model := range upstream.Deny {
			if model == inRequest.Model {
				c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
				return errors.New("[processReplicateRequest]: model denied")
			}
		}
	}

	// build the model URL
	model_url := fmt.Sprintf(replicate_model_url_template, inRequest.Model)
	log.Println("[processReplicateRequest]: model_url:", model_url)

	// create the outgoing request with default values
	outRequest := &ReplicateModelRequest{
		Stream: false,
		Input: ReplicateModelRequestInput{
			Prompt:           "",
			MaxNewTokens:     1024,
			Temperature:      0.6,
			Top_p:            0.9,
			Top_k:            50,
			FrequencyPenalty: 0.0,
			PresencePenalty:  0.0,
			PromptTemplate:   "<s>[INST] {prompt} [/INST] ",
		},
	}

	// copy values from the incoming request
	outRequest.Stream = inRequest.Stream
	outRequest.Input.Temperature = inRequest.Temperature
	outRequest.Input.FrequencyPenalty = inRequest.FrequencyPenalty
	outRequest.Input.PresencePenalty = inRequest.PresencePenalty

	// render the prompt by concatenating all message contents
	for _, message := range inRequest.Messages {
		outRequest.Input.Prompt += message.Content + "\n"
	}

	// marshal the outgoing request
	outBody, err := json.Marshal(outRequest)
	if err != nil {
		c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
		return errors.New("[processReplicateRequest]: failed to marshal request body " + err.Error())
	}

	// build the HTTP request and add headers
	req, err := http.NewRequest("POST", model_url, bytes.NewBuffer(outBody))
	if err != nil {
		c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
		return errors.New("[processReplicateRequest]: failed to create post request " + err.Error())
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Token "+upstream.SK)
	// send
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
		return errors.New("[processReplicateRequest]: failed to post request " + err.Error())
	}
	defer resp.Body.Close()

	// read response body
	outBody, err = io.ReadAll(resp.Body)
	if err != nil {
		c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
		return errors.New("[processReplicateRequest]: failed to read response body " + err.Error())
	}

	// parse response body
	outResponse := &ReplicateModelResponse{}
	err = json.Unmarshal(outBody, outResponse)
	if err != nil {
		c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
		return errors.New("[processReplicateRequest]: failed to parse response body " + err.Error())
	}

	if outResponse.Stream {
		// open the stream URL returned by the prediction
		log.Println("[processReplicateRequest]: outResponse.URLS.Stream:", outResponse.URLS.Stream)
		req, err := http.NewRequest("GET", outResponse.URLS.Stream, nil)
		if err != nil {
			c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
			return errors.New("[processReplicateRequest]: failed to create get request " + err.Error())
		}
		req.Header.Set("Authorization", "Token "+upstream.SK)
		req.Header.Set("Accept", "text/event-stream")
		// send
		resp, err := http.DefaultClient.Do(req)
		if err != nil {
			c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
			return errors.New("[processReplicateRequest]: failed to get request " + err.Error())
		}
		defer resp.Body.Close()

		// read the result chunk by chunk
		var buffer string = ""
		var indexCount int64 = 0
		for {
			if !strings.Contains(buffer, "\n\n") {
				// receive a chunk
				chunk := make([]byte, 1024)
				length, err := resp.Body.Read(chunk)
				if err == io.EOF {
					break
				}
				if length == 0 {
					break
				}
				if err != nil {
					c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
					return errors.New("[processReplicateRequest]: failed to read response body " + err.Error())
				}
				// append only the bytes actually read to the buffer
				buffer += string(chunk[:length])
				continue
			}

			// cut the first event off the buffer at the blank line
			index := strings.Index(buffer, "\n\n")
			chunk := buffer[:index]
			buffer = buffer[index+2:]

			// trim surrounding newlines
			chunk = strings.Trim(chunk, "\n")

			// ignore single-line keep-alive events
			if !strings.Contains(chunk, "\n") {
				continue
			}

			// parse the chunk into a ReplicateModelResultChunk object
			chunkObj := &ReplicateModelResultChunk{}
			log.Println("[processReplicateRequest]: chunk:", chunk)
			lines := strings.Split(chunk, "\n")
			log.Println("[processReplicateRequest]: lines:", lines)
			// first line is the event
			chunkObj.Event = strings.TrimSpace(lines[0])
			chunkObj.Event = strings.TrimPrefix(chunkObj.Event, "event: ")
			log.Printf("[processReplicateRequest]: chunkObj.Event: '%s' (length %d)", chunkObj.Event, len(chunkObj.Event))
			// second line is the id
			chunkObj.ID = strings.TrimSpace(lines[1])
			chunkObj.ID = strings.TrimPrefix(chunkObj.ID, "id: ")
			chunkObj.ID = strings.SplitN(chunkObj.ID, ":", 2)[0]
			// third line is the data
			chunkObj.Data = lines[2]
			chunkObj.Data = strings.TrimPrefix(chunkObj.Data, "data: ")

			record.Response += chunkObj.Data

			// done
			if chunkObj.Event == "done" {
				break
			}

			// create an OpenAI-style response chunk
			c.SSEvent("", &OpenAIChatResponse{
				ID:    chunkObj.ID,
				Model: outResponse.Model,
				Choices: []OpenAIChatResponseChoice{
					{
						Index: indexCount,
						Message: OpenAIChatMessage{
							Role:    "assistant",
							Content: chunkObj.Data,
						},
					},
				},
			})
			c.Writer.Flush()
			indexCount += 1
		}
		// final chunk carrying finish_reason "stop"
		c.SSEvent("", &OpenAIChatResponse{
			ID:    "",
			Model: outResponse.Model,
			Choices: []OpenAIChatResponseChoice{
				{
					Index: indexCount,
					Message: OpenAIChatMessage{
						Role:    "assistant",
						Content: "",
					},
					FinishReason: "stop",
				},
			},
		})
		c.Writer.Flush()
		indexCount += 1
		record.Status = 200
		return nil

	} else {
		var result *ReplicateModelResultGet

		for {
			// poll the get URL until the prediction finishes
			log.Println("[processReplicateRequest]: outResponse.URLS.Get:", outResponse.URLS.Get)
			req, err := http.NewRequest("GET", outResponse.URLS.Get, nil)
			if err != nil {
				c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
				return errors.New("[processReplicateRequest]: failed to create get request " + err.Error())
			}
			req.Header.Set("Authorization", "Token "+upstream.SK)
			// send
			resp, err := http.DefaultClient.Do(req)
			if err != nil {
				c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
				return errors.New("[processReplicateRequest]: failed to get request " + err.Error())
			}
			// read the result
			resultBody, err := io.ReadAll(resp.Body)
			resp.Body.Close()
			if err != nil {
				c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
				return errors.New("[processReplicateRequest]: failed to read response body " + err.Error())
			}

			log.Println("[processReplicateRequest]: resultBody:", string(resultBody))

			// parse response body
			result = &ReplicateModelResultGet{}
			err = json.Unmarshal(resultBody, result)
			if err != nil {
				c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
				return errors.New("[processReplicateRequest]: failed to parse response body " + err.Error())
			}
			log.Println("[processReplicateRequest]: result:", result)

			if result.Status == "processing" || result.Status == "starting" {
				time.Sleep(3 * time.Second)
				continue
			}

			break
		}

		// build the OpenAI-style response
		openAIResult := &OpenAIChatResponse{
			ID:      result.ID,
			Model:   result.Model,
			Choices: []OpenAIChatResponseChoice{},
			Usage: OpenAIChatResponseUsage{
				TotalTokens:      result.Metrics.InputTokenCount + result.Metrics.OutputTokenCount,
				PromptTokens:     result.Metrics.InputTokenCount,
				CompletionTokens: result.Metrics.OutputTokenCount,
			},
		}
		openAIResult.Choices = append(openAIResult.Choices, OpenAIChatResponseChoice{
			Index: 0,
			Message: OpenAIChatMessage{
				Role:    "assistant",
				Content: strings.Join(result.Output, ""),
			},
			FinishReason: "stop",
		})

		// record the reply text (the streaming branch above uses record.Response as well)
		record.Response = strings.Join(result.Output, "")
		record.Status = 200

		// gin return
		c.JSON(200, openAIResult)

	}

	return nil
}

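For reference, a small self-contained sketch of the SSE framing that the streaming loop above expects: events separated by blank lines, each carrying `event:`, `id:`, and `data:` lines. The sample stream below is made up:

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	// A made-up Replicate-style SSE stream: events are separated by "\n\n".
	stream := "event: output\nid: 1\ndata: Hello\n\nevent: done\nid: 2\ndata: {}\n\n"

	for _, chunk := range strings.Split(strings.TrimSpace(stream), "\n\n") {
		lines := strings.Split(strings.Trim(chunk, "\n"), "\n")
		event := strings.TrimPrefix(strings.TrimSpace(lines[0]), "event: ")
		id := strings.TrimPrefix(strings.TrimSpace(lines[1]), "id: ")
		data := strings.TrimPrefix(lines[2], "data: ")
		fmt.Printf("event=%q id=%q data=%q\n", event, id, data)
	}
}
```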
structure.go
@@ -22,6 +22,7 @@ type OPENAI_UPSTREAM struct {
	Timeout int64    `yaml:"timeout"`
	Allow   []string `yaml:"allow"`
	Deny    []string `yaml:"deny"`
	Type    string   `yaml:"type"`
	URL     *url.URL
}

@@ -61,7 +62,119 @@ func readConfig(filepath string) Config {
			log.Fatalf("Can't parse upstream endpoint URL '%s': %s", upstream.Endpoint, err)
		}
		config.Upstreams[i].URL = endpoint
		if upstream.Type == "" {
			// `upstream` is a copy of the slice element, so write through the index
			config.Upstreams[i].Type = "openai"
		}
		if t := config.Upstreams[i].Type; t != "openai" && t != "replicate" {
			log.Fatalf("Unsupported upstream type '%s'", t)
		}
	}

	return config
}

type OpenAIChatRequest struct {
	FrequencyPenalty float64                    `json:"frequency_penalty"`
	PresencePenalty  float64                    `json:"presence_penalty"`
	MaxTokens        int64                      `json:"max_tokens"`
	Model            string                     `json:"model"`
	Stream           bool                       `json:"stream"`
	Temperature      float64                    `json:"temperature"`
	Messages         []OpenAIChatRequestMessage `json:"messages"`
}

type OpenAIChatRequestMessage struct {
	Content string `json:"content"`
	User    string `json:"user"`
}

type ReplicateModelRequest struct {
	Stream bool                       `json:"stream"`
	Input  ReplicateModelRequestInput `json:"input"`
}

type ReplicateModelRequestInput struct {
	Prompt           string  `json:"prompt"`
	MaxNewTokens     int64   `json:"max_new_tokens"`
	Temperature      float64 `json:"temperature"`
	Top_p            float64 `json:"top_p"`
	Top_k            int64   `json:"top_k"`
	PresencePenalty  float64 `json:"presence_penalty"`
	FrequencyPenalty float64 `json:"frequency_penalty"`
	PromptTemplate   string  `json:"prompt_template"`
}

type ReplicateModelResponse struct {
	Model   string                     `json:"model"`
	Version string                     `json:"version"`
	Stream  bool                       `json:"stream"`
	Error   string                     `json:"error"`
	URLS    ReplicateModelResponseURLS `json:"urls"`
}

type ReplicateModelResponseURLS struct {
	Cancel string `json:"cancel"`
	Get    string `json:"get"`
	Stream string `json:"stream"`
}

type ReplicateModelResultGet struct {
	ID      string                      `json:"id"`
	Model   string                      `json:"model"`
	Version string                      `json:"version"`
	Output  []string                    `json:"output"`
	Error   string                      `json:"error"`
	Metrics ReplicateModelResultMetrics `json:"metrics"`
	Status  string                      `json:"status"`
}

type ReplicateModelResultMetrics struct {
	InputTokenCount  int64 `json:"input_token_count"`
	OutputTokenCount int64 `json:"output_token_count"`
}

type OpenAIChatResponse struct {
	ID      string                     `json:"id"`
	Object  string                     `json:"object"`
	Created int64                      `json:"created"`
	Model   string                     `json:"model"`
	Choices []OpenAIChatResponseChoice `json:"choices"`
	Usage   OpenAIChatResponseUsage    `json:"usage"`
}

type OpenAIChatResponseUsage struct {
	PromptTokens     int64 `json:"prompt_tokens"`
	CompletionTokens int64 `json:"completion_tokens"`
	TotalTokens      int64 `json:"total_tokens"`
}

type OpenAIChatResponseChoice struct {
	Index        int64             `json:"index"`
	Message      OpenAIChatMessage `json:"message"`
	FinishReason string            `json:"finish_reason"`
}

type OpenAIChatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type ReplicateModelResultChunk struct {
	Event string `json:"event"`
	ID    string `json:"id"`
	Data  string `json:"data"`
}

type OpenAIChatResponseChunk struct {
	ID      string                          `json:"id"`
	Object  string                          `json:"object"`
	Created int64                           `json:"created"`
	Model   string                          `json:"model"`
	Choices []OpenAIChatResponseChunkChoice `json:"choices"`
}

type OpenAIChatResponseChunkChoice struct {
	Index        int64             `json:"index"`
	Delta        OpenAIChatMessage `json:"delta"`
	FinishReason *string           `json:"finish_reason"`
}
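The `yaml:` tags on `OPENAI_UPSTREAM` above suggest the config is loaded with a YAML library. A minimal sketch of parsing the upstream list, assuming `gopkg.in/yaml.v3` and an `sk` tag on the `SK` field (both assumptions, not confirmed by this diff):

```go
package main

import (
	"fmt"
	"log"

	"gopkg.in/yaml.v3" // assumption: any YAML package honoring `yaml:` tags works the same way
)

// upstream mirrors the relevant fields of OPENAI_UPSTREAM for illustration.
type upstream struct {
	SK       string   `yaml:"sk"`
	Endpoint string   `yaml:"endpoint"`
	Type     string   `yaml:"type"`
	Allow    []string `yaml:"allow"`
}

func main() {
	raw := []byte(`
- sk: "key_for_replicate"
  type: replicate
  allow: ["mistralai/mixtral-8x7b-instruct-v0.1"]
- sk: "secret_key_1"
  endpoint: "https://api.openai.com/v1"
`)
	var ups []upstream
	if err := yaml.Unmarshal(raw, &ups); err != nil {
		log.Fatal(err)
	}
	for _, u := range ups {
		if u.Type == "" {
			u.Type = "openai" // same default that readConfig applies
		}
		fmt.Printf("%s upstream, allow=%v\n", u.Type, u.Allow)
	}
}
```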