diff --git a/llmfoundry/utils/model_download_utils.py b/llmfoundry/utils/model_download_utils.py
index d268cb78b7..2104455e0f 100644
--- a/llmfoundry/utils/model_download_utils.py
+++ b/llmfoundry/utils/model_download_utils.py
@@ -6,6 +6,7 @@
 import logging
 import os
 import time
+import warnings
 from http import HTTPStatus
 from typing import Optional
 from urllib.parse import urljoin
@@ -14,6 +15,7 @@
 import requests
 import tenacity
 from bs4 import BeautifulSoup
+from requests.packages.urllib3.exceptions import InsecureRequestWarning
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
 from transformers.utils import WEIGHTS_INDEX_NAME as PYTORCH_WEIGHTS_INDEX_NAME
 from transformers.utils import WEIGHTS_NAME as PYTORCH_WEIGHTS_NAME
@@ -212,16 +214,21 @@ def download_from_cache_server(
 
     download_start = time.time()
 
-    # Only downloads the blobs in order to avoid downloading model files twice due to the
-    # symlnks in the Hugging Face cache structure:
-    _recursive_download(
-        session,
-        cache_base_url,
-        # Trailing slash to indicate directory
-        f'{formatted_model_name}/blobs/',
-        save_dir,
-        ignore_cert=ignore_cert,
-    )
+    # Temporarily suppress noisy SSL certificate verification warnings if ignore_cert is set to True
+    with warnings.catch_warnings():
+        if ignore_cert:
+            warnings.simplefilter('ignore', category=InsecureRequestWarning)
+
+        # Only downloads the blobs in order to avoid downloading model files twice due to the
+        # symlinks in the Hugging Face cache structure:
+        _recursive_download(
+            session,
+            cache_base_url,
+            # Trailing slash to indicate directory
+            f'{formatted_model_name}/blobs/',
+            save_dir,
+            ignore_cert=ignore_cert,
+        )
     download_duration = time.time() - download_start
     log.info(
         f'Downloaded model {model_name} from cache server in {download_duration} seconds'
diff --git a/scripts/misc/download_hf_model.py b/scripts/misc/download_hf_model.py
index 6465a552c2..58c3445e7d 100644
--- a/scripts/misc/download_hf_model.py
+++ b/scripts/misc/download_hf_model.py
@@ -14,6 +14,8 @@
 
 HF_TOKEN_ENV_VAR = 'HUGGING_FACE_HUB_TOKEN'
 
+logging.basicConfig(format=f'%(asctime)s: %(levelname)s: %(name)s: %(message)s',
+                    level=logging.INFO)
 log = logging.getLogger(__name__)
 
 if __name__ == '__main__':
@@ -34,7 +36,7 @@
     argparser.add_argument(
         '--fallback',
         action='store_true',
-        default=False,
+        default=True,
         help=
         'Whether to fallback to downloading from Hugging Face if download from cache fails',
     )
@@ -53,11 +55,25 @@
             token=args.token,
             ignore_cert=args.ignore_cert,
         )
+
+        # A little hacky: run the Hugging Face download just to repair the symlinks in the HF cache file structure.
+        # This shouldn't actually download any files if the cache server download was successful, but should address
+        # a non-deterministic bug where the symlinks aren't repaired properly by the time the model is initialized.
+        log.info('Repairing Hugging Face cache symlinks')
+
+        # Hide some noisy logs that aren't important for just the symlink repair.
+        old_level = logging.getLogger().level
+        logging.getLogger().setLevel(logging.ERROR)
+        download_from_hf_hub(args.model,
+                             save_dir=args.save_dir,
+                             token=args.token)
+        logging.getLogger().setLevel(old_level)
+
     except PermissionError:
         log.error(f'Not authorized to download {args.model}.')
     except Exception as e:
         if args.fallback:
-            log.warn(
+            log.warning(
                 f'Failed to download {args.model} from cache server. Falling back to Hugging Face Hub. Error: {e}'
             )
             download_from_hf_hub(args.model,
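
Note: the first hunk scopes warning suppression with `warnings.catch_warnings()` so that `InsecureRequestWarning` is ignored only while an `ignore_cert` download runs. Below is a minimal standalone sketch of that pattern; `fetch_insecure` and the `verify=not ignore_cert` request are illustrative assumptions, not code from the patch, and the warning class is imported from `urllib3.exceptions` directly (the same class re-exported via `requests.packages.urllib3`).

```python
# Sketch of the scoped warning-suppression pattern; fetch_insecure is a hypothetical helper.
import warnings

import requests
from urllib3.exceptions import InsecureRequestWarning


def fetch_insecure(url: str, ignore_cert: bool = False) -> requests.Response:
    """GET a URL, optionally skipping TLS verification without emitting InsecureRequestWarning."""
    with warnings.catch_warnings():
        if ignore_cert:
            # Scoped filter: only this warning class, and only inside the with-block.
            warnings.simplefilter('ignore', category=InsecureRequestWarning)
        return requests.get(url, verify=not ignore_cert)
```

Because `catch_warnings()` restores the previous filter state on exit, the suppression does not leak into the rest of the process.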
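
Note: the script change brackets the symlink-repair download between saving the root logger level and restoring it. A sketch of that save/raise/restore pattern follows; the `run_quietly` helper name is hypothetical, and the `try`/`finally` is an assumption added here so the level is restored even if the wrapped call raises, whereas the patch restores it inline after the call.

```python
# Hypothetical helper sketching the temporary log-level pattern from download_hf_model.py.
import logging


def run_quietly(fn, *args, **kwargs):
    """Run fn with root logging raised to ERROR, restoring the previous level afterwards."""
    root = logging.getLogger()
    old_level = root.level
    root.setLevel(logging.ERROR)  # hide INFO/WARNING chatter during the call
    try:
        return fn(*args, **kwargs)
    finally:
        root.setLevel(old_level)  # put the original level back even on error


# Usage (mirrors the patch's intent):
#   run_quietly(download_from_hf_hub, model_name, save_dir=save_dir, token=token)
```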