fix: CPU mode does not support float16
@@ -6,10 +6,10 @@ import torch
 device = "cuda" if torch.cuda.is_available() else "cpu"
 print("Using device:", device)
 
-model_name = "models/aspire--acge-large-text"
+model_name = "models/aspire--acge-large-zh"
 print("Loading model", model_name)
 model = (
-    AutoModel.from_pretrained(model_name, local_files_only=True, torch_dtype=torch.float16).eval().to(device)
+    AutoModel.from_pretrained(model_name, local_files_only=True).eval().to(device)
 )
 tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)
 print("Model", model_name, "loaded!")
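The fix above drops torch_dtype=torch.float16 entirely, so the model now loads in the default float32 on every device. A minimal sketch of an alternative approach, assuming fp16 is still wanted when a GPU is present: select the dtype from the detected device instead of hard-coding it. The model path is taken from the diff; the conditional-dtype pattern is an assumption, not part of this commit.

# Sketch: keep float16 on CUDA, fall back to the default dtype on CPU,
# where half precision is not supported for this model.
import torch
from transformers import AutoModel, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
# Use float16 only on CUDA; None lets from_pretrained keep its default (float32).
dtype = torch.float16 if device == "cuda" else None

model_name = "models/aspire--acge-large-zh"
model = (
    AutoModel.from_pretrained(model_name, local_files_only=True, torch_dtype=dtype)
    .eval()
    .to(device)
)
tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)

With this variant, CPU runs should behave exactly as after the committed fix, while CUDA runs would still get the reduced memory footprint of half precision.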