diff --git a/acge_embedding.py b/acge_embedding.py index 39ef787..d191b87 100644 --- a/acge_embedding.py +++ b/acge_embedding.py @@ -1,6 +1,7 @@ from transformers import AutoModel, AutoTokenizer from sklearn.preprocessing import normalize import torch +import torch.nn.functional as F device = "cuda" if torch.cuda.is_available() else "cpu" @@ -43,9 +44,7 @@ def acge_embedding(text: list[str]) -> list[list[float]]: ~attention_mask[..., None].bool(), 0.0 ) vector = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] - vector = normalize( - vector.cpu().detach().numpy(), - norm="l2", - axis=1, - ) - return vector.tolist() + # Normalize the output vectors + normalized_vector = F.normalize(vector, p=2, dim=1) + return normalized_vector.tolist() +