from transformers import AutoModel, AutoTokenizer
from sklearn.preprocessing import normalize
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

model_name = "models/aspire--acge-large-zh"
print("Loading model", model_name)
model = (
    AutoModel.from_pretrained(model_name, local_files_only=True).eval().to(device)
)
tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)
print("Model", model_name, "loaded!")


@torch.no_grad()
def acge_embedding(text: list[str]) -> list[list[float]]:
    # [TODO]: For the acge model, tentatively cap the input at 1000 texts.
    if len(text) > 1000:
        raise ValueError("Too many input texts!", len(text))

    batch_data = tokenizer(
        text=text,
        padding="longest",
        return_tensors="pt",
        # max_length=1024,
        truncation=False,
    )

    # Check for over-long texts.
    if batch_data["input_ids"].shape[1] > 1024:
        raise ValueError("Input text too long!", batch_data["input_ids"].shape[1])

    # [TODO]: A very large batch may exhaust GPU memory; split it into
    # sub-batches (see the sketch after this function).
    # Measured: 10000 texts used ~3.5 GB of GPU memory and took ~3 s; the
    # memory may not be reclaimed automatically.

    batch_data = batch_data.to(device)
    attention_mask = batch_data["attention_mask"]
    model_output = model(**batch_data)

    # Mean pooling: zero out the padded positions, then average the remaining
    # token embeddings for each text.
    last_hidden = model_output.last_hidden_state.masked_fill(
        ~attention_mask[..., None].bool(), 0.0
    )
    vector = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    # L2-normalize so that a dot product between vectors equals their cosine
    # similarity.
    vector = normalize(
        vector.cpu().detach().numpy(),
        norm="l2",
        axis=1,
    )
    return vector.tolist()
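

# Sketch for the sub-batching TODO above (an assumption, not part of the
# original file): split large inputs into chunks so that a single forward
# pass never exceeds GPU memory. `acge_embedding_chunked` and `chunk_size`
# are hypothetical names introduced here for illustration.
def acge_embedding_chunked(
    texts: list[str], chunk_size: int = 256
) -> list[list[float]]:
    vectors: list[list[float]] = []
    for start in range(0, len(texts), chunk_size):
        vectors.extend(acge_embedding(texts[start : start + chunk_size]))
        # Free cached GPU memory between chunks, since (per the note above)
        # it may not be reclaimed automatically.
        if device == "cuda":
            torch.cuda.empty_cache()
    return vectors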
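

# Minimal usage sketch (an assumption; the sample sentences are illustrative
# only, not part of the original file).
if __name__ == "__main__":
    embeddings = acge_embedding(["今天天气很好。", "The weather is nice today."])
    print(len(embeddings), "embeddings of dimension", len(embeddings[0]))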