Files
local-embedding-api/acge_embedding.py
2024-06-04 14:06:40 +08:00

52 lines
1.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from transformers import AutoModel, AutoTokenizer
from sklearn.preprocessing import normalize
import torch
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# Checkpoint identifier; resolved with local_files_only=True, so nothing is
# downloaded at load time.
model_name = "models/aspire--acge-large-text"
print("Loading model", model_name)

# Half precision halves memory use on the GPU, but float16 kernels on the CPU
# are slow and, on older PyTorch builds, not implemented for every op. The CPU
# fallback therefore keeps full float32 precision.
_model_dtype = torch.float16 if device == "cuda" else torch.float32
model = (
    AutoModel.from_pretrained(
        model_name,
        local_files_only=True,
        torch_dtype=_model_dtype,
    )
    .eval()
    .to(device)
)
tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)
print("Model", model_name, "loaded!")


@torch.no_grad()
def acge_embedding(text: list[str]) -> list[list[float]]:
    """Embed a batch of texts with the acge model.

    Each text is tokenized, run through the model, mean-pooled over its
    non-padding tokens, and L2-normalized.

    Args:
        text: Texts to embed; at most 1000 per call.

    Returns:
        One L2-normalized embedding per input text, in input order. An empty
        input yields an empty list.

    Raises:
        ValueError: If more than 1000 texts are given, or if the longest
            tokenized text exceeds 1024 tokens.
    """
    num_texts = len(text)
    # TODO: 1000 texts per call is a provisional upper bound for the acge model.
    if num_texts > 1000:
        raise ValueError("Input text too long!", num_texts)
    # Nothing to tokenize or pool: return early instead of handing the
    # tokenizer and model a zero-row batch.
    if num_texts == 0:
        return []

    batch_data = tokenizer(
        text=text,
        padding="longest",
        return_tensors="pt",
        truncation=False,
    )
    # Truncation is off, so reject over-long inputs explicitly. With
    # padding="longest" the padded width equals the longest tokenized text.
    padded_length = batch_data["input_ids"].shape[1]
    if padded_length > 1024:
        raise ValueError("Input text too long!", padded_length)

    # TODO: very large batches may exhaust GPU memory and would need to be
    # split into smaller batches. Measured: 10000 texts used about 3.5G of GPU
    # memory and took about 3s; the memory may not be reclaimed automatically.
    batch_data = batch_data.to(device)
    attention_mask = batch_data["attention_mask"]
    model_output = model(**batch_data)

    # Pool in float32: if the model runs in float16, summing many token
    # vectors and computing the norm in half precision loses accuracy.
    token_mask = attention_mask[..., None].bool()
    hidden = model_output.last_hidden_state.float().masked_fill(~token_mask, 0.0)
    token_counts = attention_mask.sum(dim=1)[..., None]
    vector = hidden.sum(dim=1) / token_counts

    # Gradients are disabled by the decorator, so no detach() is needed
    # before converting to NumPy.
    vector = normalize(vector.cpu().numpy(), norm="l2", axis=1)
    return vector.tolist()