From 41c1205dfbc6fb4ccca5bfe5b6a0a15203da0c8f Mon Sep 17 00:00:00 2001 From: heimoshuiyu Date: Tue, 4 Jun 2024 17:57:15 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20cpu=E6=A8=A1=E5=BC=8F=E4=B8=8D=E6=94=AF?= =?UTF-8?q?=E6=8C=81=20float16?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- acge_embedding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/acge_embedding.py b/acge_embedding.py index 52bfaa6..39ef787 100644 --- a/acge_embedding.py +++ b/acge_embedding.py @@ -6,10 +6,10 @@ import torch device = "cuda" if torch.cuda.is_available() else "cpu" print("Using device:", device) -model_name = "models/aspire--acge-large-text" +model_name = "models/aspire--acge-large-zh" print("Loading model", model_name) model = ( - AutoModel.from_pretrained(model_name, local_files_only=True, torch_dtype=torch.float16).eval().to(device) + AutoModel.from_pretrained(model_name, local_files_only=True).eval().to(device) ) tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True) print("Model", model_name, "loaded!")