From 70a08a9ed1c4e445930bced2d83c832ee9b2c376 Mon Sep 17 00:00:00 2001 From: heimoshuiyu Date: Mon, 15 Jan 2024 14:21:01 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20embedding=20=E4=B8=8E=20model=5Fname=20?= =?UTF-8?q?=E7=B1=BB=E5=9E=8B=E6=A3=80=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- acge_embedding.py | 2 +- app.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/acge_embedding.py b/acge_embedding.py index de8ffab..13ce3d0 100644 --- a/acge_embedding.py +++ b/acge_embedding.py @@ -48,4 +48,4 @@ def acge_embedding(text: list[str]) -> list[list[float]]: norm="l2", axis=1, ) - return vector + return vector.tolist() diff --git a/app.py b/app.py index cbb1686..fd9525a 100644 --- a/app.py +++ b/app.py @@ -32,12 +32,14 @@ async def embedding_api(req: EmbeddingAPIRequest) -> EmbeddingAPIResposne: req.input = [req.input] # 进行 embedding 计算 - embeddings: list[float] = [] + embeddings: list[list[float]] = [] if req.model == "acge-large-zh": - embeddings = acge_embedding(req.input).tolist() + embeddings = acge_embedding(req.input) elif req.model == "text-embedding-ada-002": # [TODO]: Implement text-embedding-ada-002 raise NotImplementedError("text-embedding-ada-002 not implemented yet!") + else: + raise ValueError("Unknown model name!") # 与 OpenAI 接口返回格式一致 # https://platform.openai.com/docs/api-reference/embeddings/create