diff --git a/acge_embedding.py b/acge_embedding.py index 52bfaa6..39ef787 100644 --- a/acge_embedding.py +++ b/acge_embedding.py @@ -6,10 +6,10 @@ import torch device = "cuda" if torch.cuda.is_available() else "cpu" print("Using device:", device) -model_name = "models/aspire--acge-large-text" +model_name = "models/aspire--acge-large-zh" print("Loading model", model_name) model = ( - AutoModel.from_pretrained(model_name, local_files_only=True, torch_dtype=torch.float16).eval().to(device) + AutoModel.from_pretrained(model_name, local_files_only=True).eval().to(device) ) tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True) print("Model", model_name, "loaded!")