Also catch client-side network exceptions when synchronizing models (#228)
This commit is contained in:
@@ -4,6 +4,7 @@ import os
|
||||
from typing import Optional
|
||||
|
||||
import huggingface_hub
|
||||
import requests
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
@@ -84,16 +85,23 @@ def download_model(
|
||||
kwargs["cache_dir"] = cache_dir
|
||||
|
||||
try:
|
||||
return huggingface_hub.snapshot_download(
|
||||
return huggingface_hub.snapshot_download(repo_id, **kwargs)
|
||||
except (
|
||||
huggingface_hub.utils.HfHubHTTPError,
|
||||
requests.exceptions.ConnectionError,
|
||||
) as exception:
|
||||
logger = get_logger()
|
||||
logger.warning(
|
||||
"An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s",
|
||||
repo_id,
|
||||
**kwargs,
|
||||
exception,
|
||||
)
|
||||
except huggingface_hub.utils.HfHubHTTPError:
|
||||
logger.warning(
|
||||
"Trying to load the model directly from the local cache, if it exists."
|
||||
)
|
||||
|
||||
kwargs["local_files_only"] = True
|
||||
return huggingface_hub.snapshot_download(
|
||||
repo_id,
|
||||
**kwargs,
|
||||
)
|
||||
return huggingface_hub.snapshot_download(repo_id, **kwargs)
|
||||
|
||||
|
||||
def format_timestamp(
|
||||
|
||||
Reference in New Issue
Block a user