diff --git a/README.md b/README.md
index f1ea269..16913d4 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,7 @@
 - Detects ChatCompletions stream requests and applies a 5-second timeout to them. For details, see the [Timeout Strategy](#timeout-strategy) section
 - Logs the full request body, the upstream used, the client IP address, the response time, and the GPT reply text
 - Sends a Feishu or Matrix notification when a request fails
+- Supports models hosted on the Replicate platform
 
 This document describes in detail the methods and endpoints of the load balancer and its APIs.
 
@@ -98,6 +99,9 @@
 dbaddr: ./db.sqlite
 # dbaddr: "host=127.0.0.1 port=5432 user=postgres dbname=openai_api_route sslmode=disable password=woshimima"
 upstreams:
+  - sk: "key_for_replicate"
+    type: replicate
+    allow: ["mistralai/mixtral-8x7b-instruct-v0.1"]
   - sk: "secret_key_1"
     endpoint: "https://api.openai.com/v2"
   - sk: "secret_key_2"
@@ -109,6 +113,12 @@ upstreams:
 
 You can simply run `./openai-api-route`; if the database does not exist, it is created automatically.
 
+## Model Allow and Deny Lists
+
+If an allow or deny list is configured for an upstream, the load balancer only permits (or blocks) requests for the listed models. The allow list is checked first, then the deny list.
+
+If you mix OpenAI and Replicate models, you will likely want to give the OpenAI and Replicate upstreams separate allow lists; otherwise a request for an OpenAI model may be routed to the Replicate platform. See the example below.
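+
+For example, a configuration that keeps the two platforms separate might look like this (a sketch; the keys and model names are placeholders):
+
+```yaml
+upstreams:
+  - sk: "key_for_openai"
+    endpoint: "https://api.openai.com/v1"
+    allow: ["gpt-3.5-turbo", "gpt-4"]
+  - sk: "key_for_replicate"
+    type: replicate
+    allow: ["mistralai/mixtral-8x7b-instruct-v0.1"]
+```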
+
 ## Timeout Strategy
 
 When handling upstream requests, the timeout strategy is key to keeping the service stable and responsive. This service defines multiple upstream servers in the `Upstreams` section of the configuration file. Each upstream has its own `Endpoint` and `SK` (a secret key or special identifier). The service tries each upstream in the order given in the configuration file until a request succeeds or every upstream has been tried.
diff --git a/config.sample.yaml b/config.sample.yaml
index fac7a98..7e598c1 100644
--- a/config.sample.yaml
+++ b/config.sample.yaml
@@ -11,10 +11,13 @@ dbaddr: ./db.sqlite
 upstreams:
   - sk: "secret_key_1"
     endpoint: "https://api.openai.com/v2"
-    allow: ["gpt-3.5-trubo"] # optional model allow list
+    allow: ["gpt-3.5-turbo"]  # optional model allow list
   - sk: "secret_key_2"
    endpoint: "https://api.openai.com/v1"
     timeout: 30
     allow: ["gpt-3.5-trubo"] # optional model allow list
     deny: ["gpt-4"] # optional model deny list
-    # if both lists are set, the allow list is checked first, then the deny list
\ No newline at end of file
+    # if both lists are set, the allow list is checked first, then the deny list
+  - sk: "key_for_replicate"
+    type: replicate
+    allow: ["mistralai/mixtral-8x7b-instruct-v0.1"]
diff --git a/main.go b/main.go
index d43e850..5b624cd 100644
--- a/main.go
+++ b/main.go
@@ -123,9 +123,9 @@ func main() {
 	}
 
 	for index, upstream := range config.Upstreams {
-		if upstream.Endpoint == "" || upstream.SK == "" {
+		if upstream.SK == "" {
 			sendCORSHeaders(c)
-			c.AbortWithError(500, fmt.Errorf("[processRequest.begin]: invaild upstream '%s' '%s'", upstream.SK, upstream.Endpoint))
+			c.AbortWithError(500, fmt.Errorf("[processRequest.begin]: invalid SK (secret key) '%s'", upstream.SK))
 			continue
 		}
 
@@ -135,7 +135,14 @@ func main() {
 			upstream.Timeout = 120
 		}
 
-		err = processRequest(c, &upstream, &record, shouldResponse)
+		if upstream.Type == "replicate" {
+			err = processReplicateRequest(c, &upstream, &record, shouldResponse)
+		} else if upstream.Type == "openai" {
+			err = processRequest(c, &upstream, &record, shouldResponse)
+		} else {
+			err = fmt.Errorf("[processRequest.begin]: unsupported upstream type '%s'", upstream.Type)
+		}
+
 		if err != nil {
 			if err == http.ErrAbortHandler {
 				abortErr := "[processRequest.done]: AbortHandler, client's connection lost?, no upstream will try, stop here"
diff --git a/replicate.go b/replicate.go
new file mode 100644
index 0000000..db14f97
--- /dev/null
+++ b/replicate.go
@@ -0,0 +1,323 @@
+package main
+
+import (
+	"bytes"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"log"
+	"net/http"
+	"strings"
+	"time"
+
+	"github.com/gin-gonic/gin"
+)
+
+// URL template for creating a prediction for a given model on Replicate.
+var replicateModelURLTemplate = "https://api.replicate.com/v1/models/%s/predictions"
+
+// processReplicateRequest wraps _processReplicateRequest and, when this
+// upstream is the one that should answer the client, reports failures as 502.
+func processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
+	err := _processReplicateRequest(c, upstream, record, shouldResponse)
+	if shouldResponse && err != nil {
+		c.AbortWithError(502, err)
+	}
+	return err
+}
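+
+// _processReplicateRequest translates an OpenAI-style ChatCompletions request
+// into a Replicate prediction: it parses the incoming body, enforces the
+// upstream's allow/deny lists, POSTs a prediction to
+// https://api.replicate.com/v1/models/{model}/predictions, and then either
+// relays the SSE stream or polls the prediction until it settles, converting
+// the result back into an OpenAI-style response.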
+func _processReplicateRequest(c *gin.Context, upstream *OPENAI_UPSTREAM, record *Record, shouldResponse bool) error {
+	// read request body
+	inBody, err := io.ReadAll(c.Request.Body)
+	if err != nil {
+		return errors.New("[processReplicateRequest]: failed to read request body " + err.Error())
+	}
+
+	// record request body
+	record.Body = string(inBody)
+
+	// parse request body
+	inRequest := &OpenAIChatRequest{}
+	err = json.Unmarshal(inBody, inRequest)
+	if err != nil {
+		c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
+		return errors.New("[processReplicateRequest]: failed to parse request body " + err.Error())
+	}
+
+	record.Model = inRequest.Model
+
+	// check the allow list
+	if len(upstream.Allow) > 0 {
+		isAllowed := false
+		for _, model := range upstream.Allow {
+			if model == inRequest.Model {
+				isAllowed = true
+				break
+			}
+		}
+		if !isAllowed {
+			c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
+			return errors.New("[processReplicateRequest]: model not allowed")
+		}
+	}
+	// check the deny list
+	if len(upstream.Deny) > 0 {
+		for _, model := range upstream.Deny {
+			if model == inRequest.Model {
+				c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
+				return errors.New("[processReplicateRequest]: model denied")
+			}
+		}
+	}
+
+	// build the prediction URL for the requested model
+	modelURL := fmt.Sprintf(replicateModelURLTemplate, inRequest.Model)
+	log.Println("[processReplicateRequest]: modelURL:", modelURL)
+
+	// create the outgoing request with default values
+	outRequest := &ReplicateModelRequest{
+		Stream: false,
+		Input: ReplicateModelRequestInput{
+			Prompt:           "",
+			MaxNewTokens:     1024,
+			Temperature:      0.6,
+			TopP:             0.9,
+			TopK:             50,
+			FrequencyPenalty: 0.0,
+			PresencePenalty:  0.0,
+			PromptTemplate:   "[INST] {prompt} [/INST] ",
+		},
+	}
+
+	// copy values from the incoming request
+	outRequest.Stream = inRequest.Stream
+	outRequest.Input.Temperature = inRequest.Temperature
+	outRequest.Input.FrequencyPenalty = inRequest.FrequencyPenalty
+	outRequest.Input.PresencePenalty = inRequest.PresencePenalty
+
+	// render the prompt by concatenating all message contents
+	for _, message := range inRequest.Messages {
+		outRequest.Input.Prompt += message.Content + "\n"
+	}
+
+	// marshal the outgoing request
+	outBody, err := json.Marshal(outRequest)
+	if err != nil {
+		c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
+		return errors.New("[processReplicateRequest]: failed to marshal request body " + err.Error())
+	}
+
+	// create the POST request and set headers
+	req, err := http.NewRequest("POST", modelURL, bytes.NewBuffer(outBody))
+	if err != nil {
+		c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
+		return errors.New("[processReplicateRequest]: failed to create post request " + err.Error())
+	}
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Authorization", "Token "+upstream.SK)
+	// send
+	resp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
+		return errors.New("[processReplicateRequest]: failed to post request " + err.Error())
+	}
+	defer resp.Body.Close()
+
+	// read response body
+	outBody, err = io.ReadAll(resp.Body)
+	if err != nil {
+		c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
+		return errors.New("[processReplicateRequest]: failed to read response body " + err.Error())
+	}
+
+	// parse response body
+	outResponse := &ReplicateModelResponse{}
+	err = json.Unmarshal(outBody, outResponse)
+	if err != nil {
+		c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
+		return errors.New("[processReplicateRequest]: failed to parse response body " + err.Error())
+	}
+	if outResponse.Error != "" {
+		c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
+		return errors.New("[processReplicateRequest]: upstream returned error " + outResponse.Error)
+	}
+
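+	// Stream mode below: the SSE stream is assumed (to match the parser that
+	// follows) to send blank-line-separated chunks of the form
+	//
+	//	event: output
+	//	id: <prediction id>:<sequence>
+	//	data: <text fragment>
+	//
+	// and to finish with an "event: done" chunk; single-line keep-alive
+	// chunks carry no data and are skipped.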
+	if outResponse.Stream {
+		// open the stream URL returned by the create-prediction response
+		log.Println("[processReplicateRequest]: outResponse.URLS.Stream:", outResponse.URLS.Stream)
+		req, err := http.NewRequest("GET", outResponse.URLS.Stream, nil)
+		if err != nil {
+			c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
+			return errors.New("[processReplicateRequest]: failed to create get request " + err.Error())
+		}
+		req.Header.Set("Authorization", "Token "+upstream.SK)
+		req.Header.Set("Accept", "text/event-stream")
+		// send
+		resp, err := http.DefaultClient.Do(req)
+		if err != nil {
+			c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
+			return errors.New("[processReplicateRequest]: failed to get request " + err.Error())
+		}
+		defer resp.Body.Close()
+
+		// read the result chunk by chunk
+		buffer := ""
+		var indexCount int64
+		for {
+			if !strings.Contains(buffer, "\n\n") {
+				// receive a chunk and append it to the buffer
+				chunk := make([]byte, 1024)
+				length, err := resp.Body.Read(chunk)
+				if length > 0 {
+					buffer += string(chunk[:length])
+				}
+				if err == io.EOF {
+					break
+				}
+				if err != nil {
+					c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
+					return errors.New("[processReplicateRequest]: failed to read response body " + err.Error())
+				}
+				if length == 0 {
+					break
+				}
+				continue
+			}
+
+			// cut the first complete chunk off the buffer
+			index := strings.Index(buffer, "\n\n")
+			chunk := buffer[:index]
+			buffer = buffer[index+2:]
+
+			// trim surrounding newlines
+			chunk = strings.Trim(chunk, "\n")
+
+			// skip single-line chunks such as keep-alive comments ("hi")
+			if !strings.Contains(chunk, "\n") {
+				continue
+			}
+
+			// parse the chunk into a ReplicateModelResultChunk
+			chunkObj := &ReplicateModelResultChunk{}
+			log.Println("[processReplicateRequest]: chunk:", chunk)
+			lines := strings.Split(chunk, "\n")
+			log.Println("[processReplicateRequest]: lines:", lines)
+			// first line is the event
+			chunkObj.Event = strings.TrimSpace(lines[0])
+			chunkObj.Event = strings.TrimPrefix(chunkObj.Event, "event: ")
+			log.Printf("[processReplicateRequest]: chunkObj.Event: '%s' (length %d)", chunkObj.Event, len(chunkObj.Event))
+			// second line is the id
+			chunkObj.ID = strings.TrimSpace(lines[1])
+			chunkObj.ID = strings.TrimPrefix(chunkObj.ID, "id: ")
+			chunkObj.ID = strings.SplitN(chunkObj.ID, ":", 2)[0]
+			// third line is the data
+			chunkObj.Data = lines[2]
+			chunkObj.Data = strings.TrimPrefix(chunkObj.Data, "data: ")
+
+			record.Response += chunkObj.Data
+
+			// done
+			if chunkObj.Event == "done" {
+				break
+			}
+
+			// emit an OpenAI-style response chunk
+			c.SSEvent("", &OpenAIChatResponse{
+				ID:    chunkObj.ID,
+				Model: outResponse.Model,
+				Choices: []OpenAIChatResponseChoice{
+					{
+						Index: indexCount,
+						Message: OpenAIChatMessage{
+							Role:    "assistant",
+							Content: chunkObj.Data,
+						},
+					},
+				},
+			})
+			c.Writer.Flush()
+			indexCount += 1
+		}
+		// emit a final empty chunk carrying the stop reason
+		c.SSEvent("", &OpenAIChatResponse{
+			ID:    "",
+			Model: outResponse.Model,
+			Choices: []OpenAIChatResponseChoice{
+				{
+					Index: indexCount,
+					Message: OpenAIChatMessage{
+						Role:    "assistant",
+						Content: "",
+					},
+					FinishReason: "stop",
+				},
+			},
+		})
+		c.Writer.Flush()
+		record.Status = 200
+		return nil
+
+	} else {
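+		// Non-stream mode: poll the prediction's "get" URL every 3 seconds
+		// while its status is "starting" or "processing"; any other status
+		// (for example a completed or failed prediction) ends the loop.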
+		var result *ReplicateModelResultGet
+
+		for {
+			// fetch the current prediction state
+			log.Println("[processReplicateRequest]: outResponse.URLS.Get:", outResponse.URLS.Get)
+			req, err := http.NewRequest("GET", outResponse.URLS.Get, nil)
+			if err != nil {
+				c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
+				return errors.New("[processReplicateRequest]: failed to create get request " + err.Error())
+			}
+			req.Header.Set("Authorization", "Token "+upstream.SK)
+			// send
+			resp, err := http.DefaultClient.Do(req)
+			if err != nil {
+				c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
+				return errors.New("[processReplicateRequest]: failed to get request " + err.Error())
+			}
+			// read the result
+			resultBody, err := io.ReadAll(resp.Body)
+			resp.Body.Close()
+			if err != nil {
+				c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
+				return errors.New("[processReplicateRequest]: failed to read response body " + err.Error())
+			}
+
+			log.Println("[processReplicateRequest]: resultBody:", string(resultBody))
+
+			// parse response body
+			result = &ReplicateModelResultGet{}
+			err = json.Unmarshal(resultBody, result)
+			if err != nil {
+				c.Request.Body = io.NopCloser(bytes.NewBuffer(inBody))
+				return errors.New("[processReplicateRequest]: failed to parse response body " + err.Error())
+			}
+			log.Println("[processReplicateRequest]: result:", result)
+
+			if result.Status == "processing" || result.Status == "starting" {
+				time.Sleep(3 * time.Second)
+				continue
+			}
+
+			break
+		}
+
+		// build the OpenAI-style response
+		openAIResult := &OpenAIChatResponse{
+			ID:      result.ID,
+			Model:   result.Model,
+			Choices: []OpenAIChatResponseChoice{},
+			Usage: OpenAIChatResponseUsage{
+				TotalTokens:      result.Metrics.InputTokenCount + result.Metrics.OutputTokenCount,
+				PromptTokens:     result.Metrics.InputTokenCount,
+				CompletionTokens: result.Metrics.OutputTokenCount,
+			},
+		}
+		openAIResult.Choices = append(openAIResult.Choices, OpenAIChatResponseChoice{
+			Index: 0,
+			Message: OpenAIChatMessage{
+				Role:    "assistant",
+				Content: strings.Join(result.Output, ""),
+			},
+			FinishReason: "stop",
+		})
+
+		// record the reply text (record.Body already holds the request body)
+		record.Response = strings.Join(result.Output, "")
+		record.Status = 200
+
+		// return the JSON response via gin
+		c.JSON(200, openAIResult)
+
+	}
+
+	return nil
+}
diff --git a/structure.go b/structure.go
index d1ffe6d..72d7a8e 100644
--- a/structure.go
+++ b/structure.go
@@ -22,6 +22,7 @@ type OPENAI_UPSTREAM struct {
 	Timeout int64    `yaml:"timeout"`
 	Allow   []string `yaml:"allow"`
 	Deny    []string `yaml:"deny"`
+	Type    string   `yaml:"type"`
 	URL     *url.URL
 }
@@ -61,7 +62,119 @@ func readConfig(filepath string) Config {
 			log.Fatalf("Can't parse upstream endpoint URL '%s': %s", upstream.Endpoint, err)
 		}
 		config.Upstreams[i].URL = endpoint
+		if upstream.Type == "" {
+			// default missing types to "openai"; write the value back to the
+			// slice element, since `upstream` is only a copy
+			upstream.Type = "openai"
+			config.Upstreams[i].Type = "openai"
+		}
+		if upstream.Type != "openai" && upstream.Type != "replicate" {
+			log.Fatalf("Unsupported upstream type '%s'", upstream.Type)
+		}
 	}
 
 	return config
 }
+
+type OpenAIChatRequest struct {
+	FrequencyPenalty float64                    `json:"frequency_penalty"`
+	PresencePenalty  float64                    `json:"presence_penalty"`
+	MaxTokens        int64                      `json:"max_tokens"`
+	Model            string                     `json:"model"`
+	Stream           bool                       `json:"stream"`
+	Temperature      float64                    `json:"temperature"`
+	Messages         []OpenAIChatRequestMessage `json:"messages"`
+}
+
+type OpenAIChatRequestMessage struct {
+	Content string `json:"content"`
+	Role    string `json:"role"`
+}
+
+type ReplicateModelRequest struct {
+	Stream bool                       `json:"stream"`
+	Input  ReplicateModelRequestInput `json:"input"`
+}
+
+type ReplicateModelRequestInput struct {
+	Prompt           string  `json:"prompt"`
+	MaxNewTokens     int64   `json:"max_new_tokens"`
+	Temperature      float64 `json:"temperature"`
+	TopP             float64 `json:"top_p"`
+	TopK             int64   `json:"top_k"`
+	PresencePenalty  float64 `json:"presence_penalty"`
+	FrequencyPenalty float64 `json:"frequency_penalty"`
+	PromptTemplate   string  `json:"prompt_template"`
+}
+
+type ReplicateModelResponse struct {
+	Model   string                     `json:"model"`
+	Version string                     `json:"version"`
+	Stream  bool                       `json:"stream"`
+	Error   string                     `json:"error"`
+	URLS    ReplicateModelResponseURLS `json:"urls"`
+}
+
+type ReplicateModelResponseURLS struct {
+	Cancel string `json:"cancel"`
+	Get    string `json:"get"`
+	Stream string `json:"stream"`
+}
+
+type ReplicateModelResultGet struct {
+	ID      string                      `json:"id"`
+	Model   string                      `json:"model"`
+	Version string                      `json:"version"`
+	Output  []string                    `json:"output"`
+	Error   string                      `json:"error"`
+	Metrics ReplicateModelResultMetrics `json:"metrics"`
+	Status  string                      `json:"status"`
+}
+
+type ReplicateModelResultMetrics struct {
+	InputTokenCount  int64 `json:"input_token_count"`
+	OutputTokenCount int64 `json:"output_token_count"`
+}
+
+type OpenAIChatResponse struct {
+	ID      string                     `json:"id"`
+	Object  string                     `json:"object"`
+	Created int64                      `json:"created"`
+	Model   string                     `json:"model"`
+	Choices []OpenAIChatResponseChoice `json:"choices"`
+	Usage   OpenAIChatResponseUsage    `json:"usage"`
+}
+
+type OpenAIChatResponseUsage struct {
+	PromptTokens     int64 `json:"prompt_tokens"`
+	CompletionTokens int64 `json:"completion_tokens"`
+	TotalTokens      int64 `json:"total_tokens"`
+}
+
+type OpenAIChatResponseChoice struct {
+	Index        int64             `json:"index"`
+	Message      OpenAIChatMessage `json:"message"`
+	FinishReason string            `json:"finish_reason"`
+}
+
+type OpenAIChatMessage struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+// ReplicateModelResultChunk holds one parsed SSE chunk (event/id/data lines).
+type ReplicateModelResultChunk struct {
+	Event string `json:"event"`
+	ID    string `json:"id"`
+	Data  string `json:"data"`
+}
+
+type OpenAIChatResponseChunk struct {
+	ID      string                          `json:"id"`
+	Object  string                          `json:"object"`
+	Created int64                           `json:"created"`
+	Model   string                          `json:"model"`
+	Choices []OpenAIChatResponseChunkChoice `json:"choices"`
+}
+
+type OpenAIChatResponseChunkChoice struct {
+	Index        int64             `json:"index"`
+	Delta        OpenAIChatMessage `json:"delta"`
+	FinishReason *string           `json:"finish_reason"`
+}
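+
+// For reference, an OpenAIChatResponseChunk marshals to JSON of this shape
+// (an illustrative example, not captured output):
+//
+//	{"id":"...","object":"chat.completion.chunk","created":0,"model":"...",
+//	 "choices":[{"index":0,"delta":{"role":"assistant","content":"..."},"finish_reason":null}]}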