init
This commit is contained in:
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
__pycache__
|
||||||
|
/venv
|
||||||
11
.vscode/settings.json
vendored
Normal file
11
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
{
|
||||||
|
"python.testing.unittestArgs": [
|
||||||
|
"-v",
|
||||||
|
"-s",
|
||||||
|
"./tests",
|
||||||
|
"-p",
|
||||||
|
"*_test.py"
|
||||||
|
],
|
||||||
|
"python.testing.pytestEnabled": false,
|
||||||
|
"python.testing.unittestEnabled": true
|
||||||
|
}
|
||||||
53
README.md
Normal file
53
README.md
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
# Embedding API 后端服务
|
||||||
|
|
||||||
|
独立 API 服务,为分析提供 embedding 支持
|
||||||
|
|
||||||
|
## 配置虚拟环境
|
||||||
|
|
||||||
|
```
|
||||||
|
# 创建虚拟环境
|
||||||
|
python -m venv venv
|
||||||
|
|
||||||
|
# 激活虚拟环境
|
||||||
|
./vent/bin/activate
|
||||||
|
|
||||||
|
# 安装依赖
|
||||||
|
pip install -r requirements_version.txt
|
||||||
|
# (如果没有代理,使用国内镜像安装依赖)
|
||||||
|
pip install -r requirements_version.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||||
|
```
|
||||||
|
|
||||||
|
## 运行
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python main.py --host 0.0.0.0 --port 7999
|
||||||
|
```
|
||||||
|
|
||||||
|
## 使用服务
|
||||||
|
|
||||||
|
支持的 model: `acge-large-zh` 与 `text-embedding-ada-002`
|
||||||
|
|
||||||
|
curl 示例
|
||||||
|
|
||||||
|
```bashg
|
||||||
|
curl http://localhost:7999/v1/embeddings \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"input": "The food was delicious and the waiter...",
|
||||||
|
"model": "acge-large-zh"
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
python 示例
|
||||||
|
|
||||||
|
```python
|
||||||
|
from openai import OpenAI
|
||||||
|
client = OpenAI(base_url="http://localhost:7999/v1", api_key='whatever')
|
||||||
|
|
||||||
|
client.embeddings.create(
|
||||||
|
model="acge-large-zh",
|
||||||
|
input="The food was delicious and the waiter..."
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
详细 API 文档位于 <http://localhost:7999/docs>
|
||||||
51
acge_embedding.py
Normal file
51
acge_embedding.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
from transformers import AutoModel, AutoTokenizer
|
||||||
|
from sklearn.preprocessing import normalize
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
print("Using device:", device)
|
||||||
|
|
||||||
|
model_name = "aspire/acge-large-zh"
|
||||||
|
print("Loading model", model_name)
|
||||||
|
model = (
|
||||||
|
AutoModel.from_pretrained(model_name, torch_dtype=torch.float16).eval().to(device)
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
print("Model", model_name, "loaded!")
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def acge_embedding(text: list[str]) -> list[list[float]]:
|
||||||
|
# [TODO]: 对于 acge 模型暂定使用 1000 条文本作为上限
|
||||||
|
if len(text) > 1000:
|
||||||
|
raise ValueError("Input text too long!", len(text))
|
||||||
|
|
||||||
|
batch_data = tokenizer(
|
||||||
|
text=text,
|
||||||
|
padding="longest",
|
||||||
|
return_tensors="pt",
|
||||||
|
# max_length=1024,
|
||||||
|
truncation=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查是否有超长的文本
|
||||||
|
if batch_data["input_ids"].shape[1] > 1024:
|
||||||
|
raise ValueError("Input text too long!", batch_data["input_ids"][0].shape[0])
|
||||||
|
|
||||||
|
# [TODO]: 批次数量太大时,可能会导致显存不足,需要拆分批次处理
|
||||||
|
# 测试结果:10000 条文本,显存占用 3.5G,速度 3s,显存可能不会自动回收
|
||||||
|
|
||||||
|
batch_data = batch_data.to(device)
|
||||||
|
attention_mask = batch_data["attention_mask"]
|
||||||
|
model_output = model(**batch_data)
|
||||||
|
last_hidden = model_output.last_hidden_state.masked_fill(
|
||||||
|
~attention_mask[..., None].bool(), 0.0
|
||||||
|
)
|
||||||
|
vector = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
|
||||||
|
vector = normalize(
|
||||||
|
vector.cpu().detach().numpy(),
|
||||||
|
norm="l2",
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
|
return vector
|
||||||
58
app.py
Normal file
58
app.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
import fastapi
|
||||||
|
import pydantic
|
||||||
|
from typing import Literal
|
||||||
|
from acge_embedding import acge_embedding
|
||||||
|
|
||||||
|
app = fastapi.FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingAPIRequest(pydantic.BaseModel):
|
||||||
|
input: str | list[str]
|
||||||
|
model: Literal["acge-large-zh", "text-embedding-ada-002"]
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingAPIResposne(pydantic.BaseModel):
|
||||||
|
class Data(pydantic.BaseModel):
|
||||||
|
object: Literal["embedding"]
|
||||||
|
embedding: list[float] = pydantic.Field(
|
||||||
|
description="1024 或 1536 维度的向量,不同模型维度不同"
|
||||||
|
)
|
||||||
|
index: int
|
||||||
|
|
||||||
|
data: list[Data]
|
||||||
|
object: Literal["list"]
|
||||||
|
model: Literal["acge-large-zh", "text-embedding-ada-002"]
|
||||||
|
usage: dict[str, int] = {}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/embeddings")
|
||||||
|
async def embedding_api(req: EmbeddingAPIRequest) -> EmbeddingAPIResposne:
|
||||||
|
# 将字符串统一转换成列表后续进行 batch 处理
|
||||||
|
if isinstance(req.input, str):
|
||||||
|
req.input = [req.input]
|
||||||
|
|
||||||
|
# 进行 embedding 计算
|
||||||
|
embeddings: list[float] = []
|
||||||
|
if req.model == "acge-large-zh":
|
||||||
|
embeddings = acge_embedding(req.input).tolist()
|
||||||
|
elif req.model == "text-embedding-ada-002":
|
||||||
|
# [TODO]: Implement text-embedding-ada-002
|
||||||
|
raise NotImplementedError("text-embedding-ada-002 not implemented yet!")
|
||||||
|
|
||||||
|
# 与 OpenAI 接口返回格式一致
|
||||||
|
# https://platform.openai.com/docs/api-reference/embeddings/create
|
||||||
|
return EmbeddingAPIResposne.model_validate(
|
||||||
|
{
|
||||||
|
"object": "list",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"object": "embedding",
|
||||||
|
"embedding": e,
|
||||||
|
"index": i,
|
||||||
|
}
|
||||||
|
for i, e in enumerate(embeddings)
|
||||||
|
],
|
||||||
|
"model": req.model,
|
||||||
|
"usage": {},
|
||||||
|
}
|
||||||
|
)
|
||||||
14
main.py
Normal file
14
main.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
import argparse
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
args = argparse.ArgumentParser()
|
||||||
|
args.add_argument("--port", type=int, default=7999)
|
||||||
|
args.add_argument("--host", type=str, default="0.0.0.0")
|
||||||
|
args.add_argument("--reload", action="store_true")
|
||||||
|
|
||||||
|
args = args.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("Start serving on", args.host, args.port)
|
||||||
|
uvicorn.run("app:app", host=args.host, port=args.port, reload=args.reload)
|
||||||
7
requirements.txt
Normal file
7
requirements.txt
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
openai
|
||||||
|
uvicorn[standard]
|
||||||
|
fastapi
|
||||||
|
pydantic
|
||||||
|
scikit-learn
|
||||||
|
torch
|
||||||
|
transformers
|
||||||
59
requirements_version.txt
Normal file
59
requirements_version.txt
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
annotated-types==0.6.0
|
||||||
|
anyio==4.2.0
|
||||||
|
certifi==2023.11.17
|
||||||
|
charset-normalizer==3.3.2
|
||||||
|
click==8.1.7
|
||||||
|
distro==1.9.0
|
||||||
|
fastapi==0.109.0
|
||||||
|
filelock==3.13.1
|
||||||
|
fsspec==2023.12.2
|
||||||
|
h11==0.14.0
|
||||||
|
httpcore==1.0.2
|
||||||
|
httptools==0.6.1
|
||||||
|
httpx==0.26.0
|
||||||
|
huggingface-hub==0.20.2
|
||||||
|
idna==3.6
|
||||||
|
Jinja2==3.1.3
|
||||||
|
joblib==1.3.2
|
||||||
|
MarkupSafe==2.1.3
|
||||||
|
mpmath==1.3.0
|
||||||
|
networkx==3.2.1
|
||||||
|
numpy==1.26.3
|
||||||
|
nvidia-cublas-cu12==12.1.3.1
|
||||||
|
nvidia-cuda-cupti-cu12==12.1.105
|
||||||
|
nvidia-cuda-nvrtc-cu12==12.1.105
|
||||||
|
nvidia-cuda-runtime-cu12==12.1.105
|
||||||
|
nvidia-cudnn-cu12==8.9.2.26
|
||||||
|
nvidia-cufft-cu12==11.0.2.54
|
||||||
|
nvidia-curand-cu12==10.3.2.106
|
||||||
|
nvidia-cusolver-cu12==11.4.5.107
|
||||||
|
nvidia-cusparse-cu12==12.1.0.106
|
||||||
|
nvidia-nccl-cu12==2.18.1
|
||||||
|
nvidia-nvjitlink-cu12==12.3.101
|
||||||
|
nvidia-nvtx-cu12==12.1.105
|
||||||
|
openai==1.7.2
|
||||||
|
packaging==23.2
|
||||||
|
pydantic==2.5.3
|
||||||
|
pydantic_core==2.14.6
|
||||||
|
python-dotenv==1.0.0
|
||||||
|
PyYAML==6.0.1
|
||||||
|
regex==2023.12.25
|
||||||
|
requests==2.31.0
|
||||||
|
safetensors==0.4.1
|
||||||
|
scikit-learn==1.3.2
|
||||||
|
scipy==1.11.4
|
||||||
|
sniffio==1.3.0
|
||||||
|
starlette==0.35.1
|
||||||
|
sympy==1.12
|
||||||
|
threadpoolctl==3.2.0
|
||||||
|
tokenizers==0.15.0
|
||||||
|
torch==2.1.2
|
||||||
|
tqdm==4.66.1
|
||||||
|
transformers==4.36.2
|
||||||
|
triton==2.1.0
|
||||||
|
typing_extensions==4.9.0
|
||||||
|
urllib3==2.1.0
|
||||||
|
uvicorn==0.25.0
|
||||||
|
uvloop==0.19.0
|
||||||
|
watchfiles==0.21.0
|
||||||
|
websockets==12.0
|
||||||
21
tests/deno.ts
Normal file
21
tests/deno.ts
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
const url = "http://10.39.39.9:7999/v1/embeddings";
|
||||||
|
|
||||||
|
const input: string[] = [];
|
||||||
|
for (let i = 0; i < 1000; i++) {
|
||||||
|
input.push("我是一名大学生");
|
||||||
|
}
|
||||||
|
|
||||||
|
const resp = await fetch(url, {
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
model: "acge-large-zh",
|
||||||
|
input,
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = await resp.json();
|
||||||
|
|
||||||
|
console.log(result);
|
||||||
26
tests/openai_embedding_test.py
Normal file
26
tests/openai_embedding_test.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
"""
|
||||||
|
测试 embedding 接口与 OpenAI 模块兼容性
|
||||||
|
|
||||||
|
需要将 embedding 接口提前部署在 localhost:7999/v1/embeddings
|
||||||
|
"""
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import unittest
|
||||||
|
import openai
|
||||||
|
|
||||||
|
url = "http://localhost:7999/v1/embeddings"
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenAI(unittest.IsolatedAsyncioTestCase):
|
||||||
|
async def testOpenAIEmbedding(self):
|
||||||
|
client = openai.OpenAI(
|
||||||
|
api_key="mikumikumi", base_url="http://localhost:7999/v1"
|
||||||
|
)
|
||||||
|
result = client.embeddings.create(
|
||||||
|
model="acge-large-zh",
|
||||||
|
input=["今天天气不错", "明天天气也不错"],
|
||||||
|
)
|
||||||
|
for i, data in enumerate(result.data):
|
||||||
|
# acge 模型的 embedding 与 OpenAI 模型的 embedding 有一定差异
|
||||||
|
# acge 向量维度为 1024,OpenAI 向量维度为 1536
|
||||||
|
self.assertEqual(len(data.embedding), 1024)
|
||||||
Reference in New Issue
Block a user