init
This commit is contained in:
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
__pycache__
|
||||
/venv
|
||||
11
.vscode/settings.json
vendored
Normal file
11
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
{
|
||||
"python.testing.unittestArgs": [
|
||||
"-v",
|
||||
"-s",
|
||||
"./tests",
|
||||
"-p",
|
||||
"*_test.py"
|
||||
],
|
||||
"python.testing.pytestEnabled": false,
|
||||
"python.testing.unittestEnabled": true
|
||||
}
|
||||
53
README.md
Normal file
53
README.md
Normal file
@@ -0,0 +1,53 @@
|
||||
# Embedding API 后端服务
|
||||
|
||||
独立 API 服务,为分析提供 embedding 支持
|
||||
|
||||
## 配置虚拟环境
|
||||
|
||||
```
|
||||
# 创建虚拟环境
|
||||
python -m venv venv
|
||||
|
||||
# 激活虚拟环境
|
||||
source ./venv/bin/activate
|
||||
|
||||
# 安装依赖
|
||||
pip install -r requirements_version.txt
|
||||
# (如果没有代理,使用国内镜像安装依赖)
|
||||
pip install -r requirements_version.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
```
|
||||
|
||||
## 运行
|
||||
|
||||
```bash
|
||||
python main.py --host 0.0.0.0 --port 7999
|
||||
```
|
||||
|
||||
## 使用服务
|
||||
|
||||
支持的 model: `acge-large-zh` 与 `text-embedding-ada-002`(后者尚未实现,调用会返回错误)
|
||||
|
||||
curl 示例
|
||||
|
||||
```bash
|
||||
curl http://localhost:7999/v1/embeddings \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"input": "The food was delicious and the waiter...",
|
||||
"model": "acge-large-zh"
|
||||
}'
|
||||
```
|
||||
|
||||
python 示例
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
client = OpenAI(base_url="http://localhost:7999/v1", api_key='whatever')
|
||||
|
||||
client.embeddings.create(
|
||||
model="acge-large-zh",
|
||||
input="The food was delicious and the waiter..."
|
||||
)
|
||||
```
|
||||
|
||||
详细 API 文档位于 <http://localhost:7999/docs>
|
||||
51
acge_embedding.py
Normal file
51
acge_embedding.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
from sklearn.preprocessing import normalize
|
||||
import torch
|
||||
|
||||
|
||||
# Pick the compute device once at import time: CUDA when available, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

model_name = "aspire/acge-large-zh"
print("Loading model", model_name)

# float16 is kept on CUDA, as before. On CPU, PyTorch may lack half-precision
# kernels for several ops, so the forward pass can fail there; fall back to
# float32 when no GPU is present.
model_dtype = torch.float16 if device == "cuda" else torch.float32

# eval() puts the model in inference mode before it is moved to the device.
model = (
    AutoModel.from_pretrained(model_name, torch_dtype=model_dtype).eval().to(device)
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("Model", model_name, "loaded!")
|
||||
|
||||
|
||||
@torch.no_grad()
def acge_embedding(text: list[str]) -> list[list[float]]:
    """Embed a batch of texts with the acge model.

    Each text is tokenized, run through the model, mean-pooled over the
    positions selected by the attention mask, and L2-normalised per row.

    Args:
        text: Batch of input strings; between 1 and 1000 items.

    Returns:
        One L2-normalised vector per input text, in input order.
        NOTE(review): the value actually returned is the 2-D array produced
        by ``sklearn.preprocessing.normalize`` (callers use ``.tolist()`` on
        it), not a plain ``list[list[float]]`` as the annotation says.

    Raises:
        ValueError: If the batch is empty, holds more than 1000 texts, or
            the longest text tokenizes to more than 1024 tokens.
    """
    # An empty batch has nothing to tokenize or pool; fail with a clear message.
    if not text:
        raise ValueError("Input text is empty!", 0)

    # TODO: 1000 texts per call is a provisional upper bound for the acge model.
    if len(text) > 1000:
        raise ValueError("Input text too long!", len(text))

    batch_data = tokenizer(
        text=text,
        padding="longest",
        return_tensors="pt",
        truncation=False,
    )

    # Truncation is off and padding is "longest", so dimension 1 is the token
    # count of the longest text in the batch. Reject over-long input instead
    # of silently truncating it.
    token_count = batch_data["input_ids"].shape[1]
    if token_count > 1024:
        raise ValueError("Input text too long!", token_count)

    # TODO: very large batches may exhaust GPU memory and should be split into
    # smaller batches. Author's measurement: 10000 texts used about 3.5 GB of
    # GPU memory and took about 3 s; the memory may not be reclaimed
    # automatically.

    batch_data = batch_data.to(device)
    attention_mask = batch_data["attention_mask"]
    model_output = model(**batch_data)

    # Pool in float32: the model may run in float16, and summing up to 1024
    # half-precision token vectors (and then normalising them) loses precision
    # and can overflow.
    keep = attention_mask[..., None].bool()
    last_hidden = model_output.last_hidden_state.float().masked_fill(~keep, 0.0)

    # Mean over the attended tokens only: masked sum divided by mask count.
    tokens_per_text = attention_mask.sum(dim=1)[..., None]
    vector = last_hidden.sum(dim=1) / tokens_per_text

    vector = normalize(
        vector.cpu().detach().numpy(),
        norm="l2",
        axis=1,
    )
    return vector
|
||||
58
app.py
Normal file
58
app.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import fastapi
|
||||
import pydantic
|
||||
from typing import Literal
|
||||
from acge_embedding import acge_embedding
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
|
||||
|
||||
# Request body of the POST /v1/embeddings endpoint.
class EmbeddingAPIRequest(pydantic.BaseModel):
    # Either a single text or a batch of texts to embed.
    input: str | list[str]
    # Requested embedding model; validation only admits these two names.
    model: Literal["acge-large-zh", "text-embedding-ada-002"]
|
||||
|
||||
|
||||
# Response body of the POST /v1/embeddings endpoint.
# NOTE(review): the class name is misspelled ("Resposne"); it is left as-is
# because the endpoint handler refers to it by this name.
class EmbeddingAPIResposne(pydantic.BaseModel):
    # One embedding result; `index` is the position of the item in the batch.
    class Data(pydantic.BaseModel):
        object: Literal["embedding"]
        embedding: list[float] = pydantic.Field(
            description="1024 或 1536 维度的向量,不同模型维度不同"
        )
        index: int

    # One entry per embedded input.
    data: list[Data]
    object: Literal["list"]
    # Model name, restricted to the same values as the request's `model`.
    model: Literal["acge-large-zh", "text-embedding-ada-002"]
    # Usage counters; defaults to an empty mapping.
    usage: dict[str, int] = {}
|
||||
|
||||
|
||||
@app.post("/v1/embeddings")
async def embedding_api(req: EmbeddingAPIRequest) -> EmbeddingAPIResposne:
    """Compute embeddings for the requested texts.

    The response mirrors the OpenAI embeddings format:
    https://platform.openai.com/docs/api-reference/embeddings/create

    Args:
        req: Validated request holding the input text(s) and the model name.

    Returns:
        A list-type response with one embedding entry per input, indexed in
        input order.

    Raises:
        fastapi.HTTPException: 501 when the requested model is not
            implemented yet; 400 when the embedding backend rejects the
            input with a ``ValueError``.
    """
    # Treat a single string as a batch of one so both shapes share one path.
    texts = [req.input] if isinstance(req.input, str) else req.input

    embeddings: list[list[float]] = []
    if req.model == "acge-large-zh":
        try:
            embeddings = acge_embedding(texts).tolist()
        except ValueError as err:
            # Input-validation failures are the client's fault; report them as
            # 400 with the reason instead of an opaque 500.
            raise fastapi.HTTPException(status_code=400, detail=str(err)) from err
    elif req.model == "text-embedding-ada-002":
        # TODO: implement text-embedding-ada-002
        # 501 tells the client the model is recognised but not available yet.
        raise fastapi.HTTPException(
            status_code=501,
            detail="text-embedding-ada-002 not implemented yet!",
        )

    return EmbeddingAPIResposne.model_validate(
        {
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "embedding": e,
                    "index": i,
                }
                for i, e in enumerate(embeddings)
            ],
            "model": req.model,
            "usage": {},
        }
    )
|
||||
14
main.py
Normal file
14
main.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import argparse
|
||||
import uvicorn
|
||||
|
||||
# Command-line options of the server launcher; parsed when the module loads.
cli = argparse.ArgumentParser()
cli.add_argument("--port", default=7999, type=int)
cli.add_argument("--host", default="0.0.0.0", type=str)
cli.add_argument("--reload", action="store_true")

args = cli.parse_args()


if __name__ == "__main__":
    host, port = args.host, args.port
    print("Start serving on", host, port)
    # The app is given as an import string ("app:app") rather than an object.
    uvicorn.run("app:app", host=host, port=port, reload=args.reload)
|
||||
7
requirements.txt
Normal file
7
requirements.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
openai
|
||||
uvicorn[standard]
|
||||
fastapi
|
||||
pydantic
|
||||
scikit-learn
|
||||
torch
|
||||
transformers
|
||||
59
requirements_version.txt
Normal file
59
requirements_version.txt
Normal file
@@ -0,0 +1,59 @@
|
||||
annotated-types==0.6.0
|
||||
anyio==4.2.0
|
||||
certifi==2023.11.17
|
||||
charset-normalizer==3.3.2
|
||||
click==8.1.7
|
||||
distro==1.9.0
|
||||
fastapi==0.109.0
|
||||
filelock==3.13.1
|
||||
fsspec==2023.12.2
|
||||
h11==0.14.0
|
||||
httpcore==1.0.2
|
||||
httptools==0.6.1
|
||||
httpx==0.26.0
|
||||
huggingface-hub==0.20.2
|
||||
idna==3.6
|
||||
Jinja2==3.1.3
|
||||
joblib==1.3.2
|
||||
MarkupSafe==2.1.3
|
||||
mpmath==1.3.0
|
||||
networkx==3.2.1
|
||||
numpy==1.26.3
|
||||
nvidia-cublas-cu12==12.1.3.1
|
||||
nvidia-cuda-cupti-cu12==12.1.105
|
||||
nvidia-cuda-nvrtc-cu12==12.1.105
|
||||
nvidia-cuda-runtime-cu12==12.1.105
|
||||
nvidia-cudnn-cu12==8.9.2.26
|
||||
nvidia-cufft-cu12==11.0.2.54
|
||||
nvidia-curand-cu12==10.3.2.106
|
||||
nvidia-cusolver-cu12==11.4.5.107
|
||||
nvidia-cusparse-cu12==12.1.0.106
|
||||
nvidia-nccl-cu12==2.18.1
|
||||
nvidia-nvjitlink-cu12==12.3.101
|
||||
nvidia-nvtx-cu12==12.1.105
|
||||
openai==1.7.2
|
||||
packaging==23.2
|
||||
pydantic==2.5.3
|
||||
pydantic_core==2.14.6
|
||||
python-dotenv==1.0.0
|
||||
PyYAML==6.0.1
|
||||
regex==2023.12.25
|
||||
requests==2.31.0
|
||||
safetensors==0.4.1
|
||||
scikit-learn==1.3.2
|
||||
scipy==1.11.4
|
||||
sniffio==1.3.0
|
||||
starlette==0.35.1
|
||||
sympy==1.12
|
||||
threadpoolctl==3.2.0
|
||||
tokenizers==0.15.0
|
||||
torch==2.1.2
|
||||
tqdm==4.66.1
|
||||
transformers==4.36.2
|
||||
triton==2.1.0
|
||||
typing_extensions==4.9.0
|
||||
urllib3==2.1.0
|
||||
uvicorn==0.25.0
|
||||
uvloop==0.19.0
|
||||
watchfiles==0.21.0
|
||||
websockets==12.0
|
||||
21
tests/deno.ts
Normal file
21
tests/deno.ts
Normal file
@@ -0,0 +1,21 @@
|
||||
// Manual smoke test: POST a batch of 1000 identical sentences to the
// embeddings endpoint and print the JSON reply.
const url = "http://10.39.39.9:7999/v1/embeddings";

const input: string[] = Array.from({ length: 1000 }, () => "我是一名大学生");

const payload = JSON.stringify({ model: "acge-large-zh", input });

const resp = await fetch(url, {
  method: "POST",
  headers: { "Content-Type": "application/json" },
  body: payload,
});

const result = await resp.json();

console.log(result);
|
||||
26
tests/openai_embedding_test.py
Normal file
26
tests/openai_embedding_test.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
测试 embedding 接口与 OpenAI 模块兼容性
|
||||
|
||||
需要将 embedding 接口提前部署在 localhost:7999/v1/embeddings
|
||||
"""
|
||||
|
||||
import requests
|
||||
import unittest
|
||||
import openai
|
||||
|
||||
url = "http://localhost:7999/v1/embeddings"
|
||||
|
||||
|
||||
class TestOpenAI(unittest.IsolatedAsyncioTestCase):
    """Check that the embeddings endpoint works with the OpenAI client.

    Requires the service to be running at localhost:7999 beforehand.
    """

    async def testOpenAIEmbedding(self):
        texts = ["今天天气不错", "明天天气也不错"]
        client = openai.OpenAI(
            api_key="mikumikumi", base_url="http://localhost:7999/v1"
        )
        result = client.embeddings.create(
            model="acge-large-zh",
            input=texts,
        )

        # Without this check an empty `data` list would skip the loop below
        # and the test would pass vacuously.
        self.assertEqual(len(result.data), len(texts))

        for i, data in enumerate(result.data):
            # Each entry must report its position in the input batch.
            self.assertEqual(data.index, i)
            # acge embeddings differ from OpenAI's: acge vectors have 1024
            # dimensions, OpenAI vectors have 1536.
            self.assertEqual(len(data.embedding), 1024)
|
||||
Reference in New Issue
Block a user