diff --git a/acge_embedding.py b/acge_embedding.py index de8ffab..13ce3d0 100644 --- a/acge_embedding.py +++ b/acge_embedding.py @@ -48,4 +48,4 @@ def acge_embedding(text: list[str]) -> list[list[float]]: norm="l2", axis=1, ) - return vector + return vector.tolist() diff --git a/app.py b/app.py index cbb1686..fd9525a 100644 --- a/app.py +++ b/app.py @@ -32,12 +32,14 @@ async def embedding_api(req: EmbeddingAPIRequest) -> EmbeddingAPIResposne: req.input = [req.input] # 进行 embedding 计算 - embeddings: list[float] = [] + embeddings: list[list[float]] = [] if req.model == "acge-large-zh": - embeddings = acge_embedding(req.input).tolist() + embeddings = acge_embedding(req.input) elif req.model == "text-embedding-ada-002": # [TODO]: Implement text-embedding-ada-002 raise NotImplementedError("text-embedding-ada-002 not implemented yet!") + else: + raise ValueError("Unknown model name!") # 与 OpenAI 接口返回格式一致 # https://platform.openai.com/docs/api-reference/embeddings/create