init
This commit is contained in:
51
acge_embedding.py
Normal file
51
acge_embedding.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
from sklearn.preprocessing import normalize
|
||||
import torch
|
||||
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print("Using device:", device)
|
||||
|
||||
model_name = "aspire/acge-large-zh"
|
||||
print("Loading model", model_name)
|
||||
model = (
|
||||
AutoModel.from_pretrained(model_name, torch_dtype=torch.float16).eval().to(device)
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
print("Model", model_name, "loaded!")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def acge_embedding(text: list[str]) -> list[list[float]]:
|
||||
# [TODO]: 对于 acge 模型暂定使用 1000 条文本作为上限
|
||||
if len(text) > 1000:
|
||||
raise ValueError("Input text too long!", len(text))
|
||||
|
||||
batch_data = tokenizer(
|
||||
text=text,
|
||||
padding="longest",
|
||||
return_tensors="pt",
|
||||
# max_length=1024,
|
||||
truncation=False,
|
||||
)
|
||||
|
||||
# 检查是否有超长的文本
|
||||
if batch_data["input_ids"].shape[1] > 1024:
|
||||
raise ValueError("Input text too long!", batch_data["input_ids"][0].shape[0])
|
||||
|
||||
# [TODO]: 批次数量太大时,可能会导致显存不足,需要拆分批次处理
|
||||
# 测试结果:10000 条文本,显存占用 3.5G,速度 3s,显存可能不会自动回收
|
||||
|
||||
batch_data = batch_data.to(device)
|
||||
attention_mask = batch_data["attention_mask"]
|
||||
model_output = model(**batch_data)
|
||||
last_hidden = model_output.last_hidden_state.masked_fill(
|
||||
~attention_mask[..., None].bool(), 0.0
|
||||
)
|
||||
vector = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
|
||||
vector = normalize(
|
||||
vector.cpu().detach().numpy(),
|
||||
norm="l2",
|
||||
axis=1,
|
||||
)
|
||||
return vector
|
||||
Reference in New Issue
Block a user