commit dabdbb42de
2024-01-15 12:36:42 +08:00
11 changed files with 306 additions and 0 deletions

acge_embedding.py Normal file

@@ -0,0 +1,51 @@
from transformers import AutoModel, AutoTokenizer
from sklearn.preprocessing import normalize
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)
model_name = "aspire/acge-large-zh"
print("Loading model", model_name)
model = (
AutoModel.from_pretrained(model_name, torch_dtype=torch.float16).eval().to(device)
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("Model", model_name, "loaded!")
@torch.no_grad()
def acge_embedding(text: list[str]) -> list[list[float]]:
    # [TODO]: For the acge model, tentatively cap input at 1000 texts.
    if len(text) > 1000:
        raise ValueError("Too many input texts!", len(text))
    batch_data = tokenizer(
        text=text,
        padding="longest",
        return_tensors="pt",
        # max_length=1024,
        truncation=False,  # no truncation, so overly long texts are detected below
    )
    # Check for overly long texts.
    if batch_data["input_ids"].shape[1] > 1024:
        raise ValueError("Input text too long!", batch_data["input_ids"].shape[1])
    # [TODO]: A very large batch may exhaust GPU memory; split it into smaller
    # batches (see the chunked helper sketched after this function).
    # Test result: 10,000 texts use ~3.5 GB of GPU memory and take ~3 s;
    # the memory may not be released automatically.
    batch_data = batch_data.to(device)
    attention_mask = batch_data["attention_mask"]
    model_output = model(**batch_data)
    # Zero out padding positions, then mean-pool over non-padding tokens only.
    last_hidden = model_output.last_hidden_state.masked_fill(
        ~attention_mask[..., None].bool(), 0.0
    )
    vector = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    # L2-normalize each embedding and convert to plain Python lists,
    # matching the declared return type.
    vector = normalize(
        vector.cpu().detach().numpy(),
        norm="l2",
        axis=1,
    )
    return vector.tolist()
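

# --- Hypothetical batching sketch (illustrative, not part of the original commit) ---
# A minimal sketch of the batch-splitting TODO above: call acge_embedding() on
# fixed-size chunks so GPU memory stays bounded. `acge_embedding_chunked` and
# `chunk_size` are assumed names, not an API this file defines.
def acge_embedding_chunked(
    texts: list[str], chunk_size: int = 256
) -> list[list[float]]:
    vectors: list[list[float]] = []
    for start in range(0, len(texts), chunk_size):
        vectors.extend(acge_embedding(texts[start : start + chunk_size]))
        if device == "cuda":
            # Hint PyTorch to release cached GPU memory between chunks.
            torch.cuda.empty_cache()
    return vectors


# Example usage (hypothetical sentences):
#   embeddings = acge_embedding_chunked(["今天天气不错", "便宜的酒店"])
#   print(len(embeddings), len(embeddings[0]))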