fix: cpu模式不支持 float16

This commit is contained in:
2024-06-04 17:57:15 +08:00
parent 947cc86ad8
commit 41c1205dfb

View File

@@ -6,10 +6,10 @@ import torch
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device) print("Using device:", device)
model_name = "models/aspire--acge-large-text" model_name = "models/aspire--acge-large-zh"
print("Loading model", model_name) print("Loading model", model_name)
model = ( model = (
AutoModel.from_pretrained(model_name, local_files_only=True, torch_dtype=torch.float16).eval().to(device) AutoModel.from_pretrained(model_name, local_files_only=True).eval().to(device)
) )
tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True) tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)
print("Model", model_name, "loaded!") print("Model", model_name, "loaded!")