This commit is contained in:
2024-01-15 12:36:42 +08:00
commit dabdbb42de
11 changed files with 306 additions and 0 deletions

58
app.py Normal file
View File

@@ -0,0 +1,58 @@
import fastapi
import pydantic
from typing import Literal
from acge_embedding import acge_embedding
# FastAPI application instance; the /v1/embeddings route in this file is
# registered on it through the @app.post decorator.
app = fastapi.FastAPI()
# Request body accepted by POST /v1/embeddings.
#
# Documented with comments rather than a docstring so the JSON schema that
# pydantic generates for this model stays exactly as it was.
class EmbeddingAPIRequest(pydantic.BaseModel):
    # Text to embed: either one string or a batch of strings.
    input: str | list[str]

    # Name of the embedding model to use; only these two values validate.
    model: Literal["acge-large-zh", "text-embedding-ada-002"]
# Response body returned by POST /v1/embeddings.
#
# The class was originally named "EmbeddingAPIResposne" (typo). It now carries
# the correctly spelled name, and the misspelled name is kept as an alias
# right after the class so existing references keep working.
class EmbeddingAPIResponse(pydantic.BaseModel):
    # One embedding vector together with its position in the batch.
    class Data(pydantic.BaseModel):
        # Constant tag identifying this item as an embedding.
        object: Literal["embedding"]

        # The vector itself. The schema description (runtime text, kept
        # verbatim) says it has 1024 or 1536 dimensions depending on the model.
        embedding: list[float] = pydantic.Field(
            description="1024 或 1536 维度的向量,不同模型维度不同"
        )

        # Zero-based position of this vector in the returned list.
        index: int

    # One entry per embedding produced for the request.
    data: list[Data]

    # Constant tag identifying the payload as a list.
    object: Literal["list"]

    # Model name echoed back from the request.
    model: Literal["acge-large-zh", "text-embedding-ada-002"]

    # Named integer usage counters; defaults to an empty mapping.
    usage: dict[str, int] = {}


# Backward-compatible alias for the original, misspelled class name.
EmbeddingAPIResposne = EmbeddingAPIResponse
@app.post("/v1/embeddings")
async def embedding_api(req: EmbeddingAPIRequest) -> EmbeddingAPIResposne:
    """Compute embeddings for the request input and wrap them in a list payload.

    Args:
        req: Validated request body with the text(s) to embed and the model
            name.

    Returns:
        A response whose ``data`` list holds one item per embedding returned
        by the backend, each tagged with its index, plus the echoed model
        name and an empty ``usage`` mapping.

    Raises:
        fastapi.HTTPException: With status 501 when the
            ``text-embedding-ada-002`` model is requested, because that
            backend is not implemented yet.
    """
    # Normalise a single string into a one-element batch so the backend always
    # receives a list. A local is used so the request model is not mutated.
    texts = [req.input] if isinstance(req.input, str) else req.input

    if req.model == "text-embedding-ada-002":
        # TODO: implement text-embedding-ada-002.
        # Report this as 501 Not Implemented instead of letting a bare
        # NotImplementedError surface as an opaque 500: the request itself
        # is valid, the server simply lacks this backend.
        raise fastapi.HTTPException(
            status_code=501,
            detail="text-embedding-ada-002 not implemented yet!",
        )

    # A list of vectors (presumably one per input text -- confirm against
    # acge_embedding), not a single flat vector.
    embeddings: list[list[float]] = []
    if req.model == "acge-large-zh":
        # NOTE(review): this call runs synchronously inside an async handler;
        # if acge_embedding is compute-heavy it will block the event loop
        # while it runs. Consider offloading once its thread-safety is known.
        embeddings = acge_embedding(texts).tolist()

    # Response layout follows the OpenAI embeddings endpoint:
    # https://platform.openai.com/docs/api-reference/embeddings/create
    return EmbeddingAPIResposne.model_validate(
        {
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "embedding": vector,
                    "index": position,
                }
                for position, vector in enumerate(embeddings)
            ],
            "model": req.model,
            "usage": {},
        }
    )