"""OpenAI-compatible ``/v1/embeddings`` HTTP endpoint built on FastAPI."""

from typing import Literal

import fastapi
import pydantic

from acge_embedding import acge_embedding

app = fastapi.FastAPI()


class EmbeddingAPIRequest(pydantic.BaseModel):
    """Request body accepted by ``POST /v1/embeddings``."""

    # Either a single text or a batch of texts to embed.
    input: str | list[str]
    # Name of the embedding model that should process the input.
    model: Literal["acge-large-zh", "text-embedding-ada-002"]


# NOTE: the class name is misspelled ("Resposne"); it is kept as-is because it
# is the public identifier other code may import.
class EmbeddingAPIResposne(pydantic.BaseModel):
    """Response body returned by ``POST /v1/embeddings``."""

    class Data(pydantic.BaseModel):
        """One embedding vector together with its position in the batch."""

        object: Literal["embedding"]
        # The description below (in Chinese) says the vector has 1024 or 1536
        # dimensions depending on the model. It is a runtime string that ends
        # up in the generated schema, so it is left untranslated.
        embedding: list[float] = pydantic.Field(
            description="1024 或 1536 维度的向量,不同模型维度不同"
        )
        # Zero-based position of the corresponding text in the request batch.
        index: int

    data: list[Data]
    object: Literal["list"]
    model: Literal["acge-large-zh", "text-embedding-ada-002"]
    # Always returned empty by the handler in this file.
    usage: dict[str, int] = {}


@app.post("/v1/embeddings")
async def embedding_api(req: EmbeddingAPIRequest) -> EmbeddingAPIResposne:
    """Compute embeddings for the requested texts.

    Args:
        req: Validated request carrying the text(s) and the model name.

    Returns:
        A response containing one ``Data`` entry per input text, indexed in
        the order the embeddings were produced.

    Raises:
        fastapi.HTTPException: With status 501 when the requested model is
            declared in the schema but has no implementation yet.
        ValueError: If the model name is not one of the known values.
    """
    # Normalise a single string into a one-element batch. A local variable is
    # used so the validated request object is not mutated.
    texts: list[str] = [req.input] if isinstance(req.input, str) else req.input

    if req.model == "acge-large-zh":
        # NOTE(review): this is a synchronous call inside an async handler, so
        # it runs on the event-loop thread. Consider offloading it if
        # acge_embedding turns out to be slow; confirm before changing.
        embeddings: list[list[float]] = acge_embedding(texts)
    elif req.model == "text-embedding-ada-002":
        # TODO: implement text-embedding-ada-002.
        # An HTTPException gives the client an explicit 501 with a readable
        # message instead of an opaque 500 from an unhandled exception.
        raise fastapi.HTTPException(
            status_code=fastapi.status.HTTP_501_NOT_IMPLEMENTED,
            detail="text-embedding-ada-002 not implemented yet!",
        )
    else:
        # Defensive guard: the Literal type on ``model`` rejects other names
        # during request validation, so this is not expected to be reached.
        raise ValueError("Unknown model name!")

    # Keep the payload shape aligned with the OpenAI embeddings API:
    # https://platform.openai.com/docs/api-reference/embeddings/create
    return EmbeddingAPIResposne.model_validate(
        {
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "embedding": vector,
                    "index": position,
                }
                for position, vector in enumerate(embeddings)
            ],
            "model": req.model,
            "usage": {},
        }
    )