This commit is contained in:
2024-01-15 12:36:42 +08:00
commit dabdbb42de
11 changed files with 306 additions and 0 deletions

2
.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
__pycache__
/venv

11
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,11 @@
{
"python.testing.unittestArgs": [
"-v",
"-s",
"./tests",
"-p",
"*_test.py"
],
"python.testing.pytestEnabled": false,
"python.testing.unittestEnabled": true
}

53
README.md Normal file
View File

@@ -0,0 +1,53 @@
# Embedding API 后端服务
独立 API 服务,为分析提供 embedding 支持
## 配置虚拟环境
```
# 创建虚拟环境
python -m venv venv
# 激活虚拟环境
./vent/bin/activate
# 安装依赖
pip install -r requirements_version.txt
# (如果没有代理,使用国内镜像安装依赖)
pip install -r requirements_version.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
```
## 运行
```bash
python main.py --host 0.0.0.0 --port 7999
```
## 使用服务
支持的 model: `acge-large-zh``text-embedding-ada-002`
curl 示例
```bashg
curl http://localhost:7999/v1/embeddings \
-H "Content-Type: application/json" \
-d '{
"input": "The food was delicious and the waiter...",
"model": "acge-large-zh"
}'
```
python 示例
```python
from openai import OpenAI
client = OpenAI(base_url="http://localhost:7999/v1", api_key='whatever')
client.embeddings.create(
model="acge-large-zh",
input="The food was delicious and the waiter..."
)
```
详细 API 文档位于 <http://localhost:7999/docs>

51
acge_embedding.py Normal file
View File

@@ -0,0 +1,51 @@
from transformers import AutoModel, AutoTokenizer
from sklearn.preprocessing import normalize
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)
model_name = "aspire/acge-large-zh"
print("Loading model", model_name)
model = (
AutoModel.from_pretrained(model_name, torch_dtype=torch.float16).eval().to(device)
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("Model", model_name, "loaded!")
@torch.no_grad()
def acge_embedding(text: list[str]) -> list[list[float]]:
# [TODO]: 对于 acge 模型暂定使用 1000 条文本作为上限
if len(text) > 1000:
raise ValueError("Input text too long!", len(text))
batch_data = tokenizer(
text=text,
padding="longest",
return_tensors="pt",
# max_length=1024,
truncation=False,
)
# 检查是否有超长的文本
if batch_data["input_ids"].shape[1] > 1024:
raise ValueError("Input text too long!", batch_data["input_ids"][0].shape[0])
# [TODO]: 批次数量太大时,可能会导致显存不足,需要拆分批次处理
# 测试结果10000 条文本,显存占用 3.5G,速度 3s显存可能不会自动回收
batch_data = batch_data.to(device)
attention_mask = batch_data["attention_mask"]
model_output = model(**batch_data)
last_hidden = model_output.last_hidden_state.masked_fill(
~attention_mask[..., None].bool(), 0.0
)
vector = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
vector = normalize(
vector.cpu().detach().numpy(),
norm="l2",
axis=1,
)
return vector

58
app.py Normal file
View File

@@ -0,0 +1,58 @@
import fastapi
import pydantic
from typing import Literal
from acge_embedding import acge_embedding
app = fastapi.FastAPI()
class EmbeddingAPIRequest(pydantic.BaseModel):
input: str | list[str]
model: Literal["acge-large-zh", "text-embedding-ada-002"]
class EmbeddingAPIResposne(pydantic.BaseModel):
class Data(pydantic.BaseModel):
object: Literal["embedding"]
embedding: list[float] = pydantic.Field(
description="1024 或 1536 维度的向量,不同模型维度不同"
)
index: int
data: list[Data]
object: Literal["list"]
model: Literal["acge-large-zh", "text-embedding-ada-002"]
usage: dict[str, int] = {}
@app.post("/v1/embeddings")
async def embedding_api(req: EmbeddingAPIRequest) -> EmbeddingAPIResposne:
# 将字符串统一转换成列表后续进行 batch 处理
if isinstance(req.input, str):
req.input = [req.input]
# 进行 embedding 计算
embeddings: list[float] = []
if req.model == "acge-large-zh":
embeddings = acge_embedding(req.input).tolist()
elif req.model == "text-embedding-ada-002":
# [TODO]: Implement text-embedding-ada-002
raise NotImplementedError("text-embedding-ada-002 not implemented yet!")
# 与 OpenAI 接口返回格式一致
# https://platform.openai.com/docs/api-reference/embeddings/create
return EmbeddingAPIResposne.model_validate(
{
"object": "list",
"data": [
{
"object": "embedding",
"embedding": e,
"index": i,
}
for i, e in enumerate(embeddings)
],
"model": req.model,
"usage": {},
}
)

4
deno.json Normal file
View File

@@ -0,0 +1,4 @@
{
"tasks": {
}
}

14
main.py Normal file
View File

@@ -0,0 +1,14 @@
import argparse
import uvicorn
args = argparse.ArgumentParser()
args.add_argument("--port", type=int, default=7999)
args.add_argument("--host", type=str, default="0.0.0.0")
args.add_argument("--reload", action="store_true")
args = args.parse_args()
if __name__ == "__main__":
print("Start serving on", args.host, args.port)
uvicorn.run("app:app", host=args.host, port=args.port, reload=args.reload)

7
requirements.txt Normal file
View File

@@ -0,0 +1,7 @@
openai
uvicorn[standard]
fastapi
pydantic
scikit-learn
torch
transformers

59
requirements_version.txt Normal file
View File

@@ -0,0 +1,59 @@
annotated-types==0.6.0
anyio==4.2.0
certifi==2023.11.17
charset-normalizer==3.3.2
click==8.1.7
distro==1.9.0
fastapi==0.109.0
filelock==3.13.1
fsspec==2023.12.2
h11==0.14.0
httpcore==1.0.2
httptools==0.6.1
httpx==0.26.0
huggingface-hub==0.20.2
idna==3.6
Jinja2==3.1.3
joblib==1.3.2
MarkupSafe==2.1.3
mpmath==1.3.0
networkx==3.2.1
numpy==1.26.3
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
openai==1.7.2
packaging==23.2
pydantic==2.5.3
pydantic_core==2.14.6
python-dotenv==1.0.0
PyYAML==6.0.1
regex==2023.12.25
requests==2.31.0
safetensors==0.4.1
scikit-learn==1.3.2
scipy==1.11.4
sniffio==1.3.0
starlette==0.35.1
sympy==1.12
threadpoolctl==3.2.0
tokenizers==0.15.0
torch==2.1.2
tqdm==4.66.1
transformers==4.36.2
triton==2.1.0
typing_extensions==4.9.0
urllib3==2.1.0
uvicorn==0.25.0
uvloop==0.19.0
watchfiles==0.21.0
websockets==12.0

21
tests/deno.ts Normal file
View File

@@ -0,0 +1,21 @@
const url = "http://10.39.39.9:7999/v1/embeddings";
const input: string[] = [];
for (let i = 0; i < 1000; i++) {
input.push("我是一名大学生");
}
const resp = await fetch(url, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
model: "acge-large-zh",
input,
}),
});
const result = await resp.json();
console.log(result);

View File

@@ -0,0 +1,26 @@
"""
测试 embedding 接口与 OpenAI 模块兼容性
需要将 embedding 接口提前部署在 localhost:7999/v1/embeddings
"""
import requests
import unittest
import openai
url = "http://localhost:7999/v1/embeddings"
class TestOpenAI(unittest.IsolatedAsyncioTestCase):
async def testOpenAIEmbedding(self):
client = openai.OpenAI(
api_key="mikumikumi", base_url="http://localhost:7999/v1"
)
result = client.embeddings.create(
model="acge-large-zh",
input=["今天天气不错", "明天天气也不错"],
)
for i, data in enumerate(result.data):
# acge 模型的 embedding 与 OpenAI 模型的 embedding 有一定差异
# acge 向量维度为 1024OpenAI 向量维度为 1536
self.assertEqual(len(data.embedding), 1024)