Skip to content

Commit

Permalink
Fix logging verbosity in HF model download script and repair symlinks (#727)
Browse files Browse the repository at this point in the history

* Make logs appear and disable InsecureRequestWarning for ignore_cert

* Clean up

* Repair symlinks after cache download

* Clean up logging

---------

Co-authored-by: Daniel King <[email protected]>
  • Loading branch information
jerrychen109 and dakinggg authored Nov 10, 2023
1 parent 2f91a64 commit 7c4d24a
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 12 deletions.
27 changes: 17 additions & 10 deletions llmfoundry/utils/model_download_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
import os
import time
import warnings
from http import HTTPStatus
from typing import Optional
from urllib.parse import urljoin
Expand All @@ -14,6 +15,7 @@
import requests
import tenacity
from bs4 import BeautifulSoup
from requests.packages.urllib3.exceptions import InsecureRequestWarning
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
from transformers.utils import WEIGHTS_INDEX_NAME as PYTORCH_WEIGHTS_INDEX_NAME
from transformers.utils import WEIGHTS_NAME as PYTORCH_WEIGHTS_NAME
Expand Down Expand Up @@ -212,16 +214,21 @@ def download_from_cache_server(

download_start = time.time()

# Only downloads the blobs in order to avoid downloading model files twice due to the
# symlinks in the Hugging Face cache structure:
_recursive_download(
session,
cache_base_url,
# Trailing slash to indicate directory
f'{formatted_model_name}/blobs/',
save_dir,
ignore_cert=ignore_cert,
)
# Temporarily suppress noisy SSL certificate verification warnings if ignore_cert is set to True
with warnings.catch_warnings():
if ignore_cert:
warnings.simplefilter('ignore', category=InsecureRequestWarning)

# Only downloads the blobs in order to avoid downloading model files twice due to the
# symlinks in the Hugging Face cache structure:
_recursive_download(
session,
cache_base_url,
# Trailing slash to indicate directory
f'{formatted_model_name}/blobs/',
save_dir,
ignore_cert=ignore_cert,
)
download_duration = time.time() - download_start
log.info(
f'Downloaded model {model_name} from cache server in {download_duration} seconds'
Expand Down
20 changes: 18 additions & 2 deletions scripts/misc/download_hf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

HF_TOKEN_ENV_VAR = 'HUGGING_FACE_HUB_TOKEN'

logging.basicConfig(format=f'%(asctime)s: %(levelname)s: %(name)s: %(message)s',
level=logging.INFO)
log = logging.getLogger(__name__)

if __name__ == '__main__':
Expand All @@ -34,7 +36,7 @@
argparser.add_argument(
'--fallback',
action='store_true',
default=False,
default=True,
help=
'Whether to fallback to downloading from Hugging Face if download from cache fails',
)
Expand All @@ -53,11 +55,25 @@
token=args.token,
ignore_cert=args.ignore_cert,
)

# A little hacky: run the Hugging Face download just to repair the symlinks in the HF cache file structure.
# This shouldn't actually download any files if the cache server download was successful, but should address
# a non-deterministic bug where the symlinks aren't repaired properly by the time the model is initialized.
log.info('Repairing Hugging Face cache symlinks')

# Hide some noisy logs that aren't important for just the symlink repair.
old_level = logging.getLogger().level
logging.getLogger().setLevel(logging.ERROR)
download_from_hf_hub(args.model,
save_dir=args.save_dir,
token=args.token)
logging.getLogger().setLevel(old_level)

except PermissionError:
log.error(f'Not authorized to download {args.model}.')
except Exception as e:
if args.fallback:
log.warn(
log.warning(
f'Failed to download {args.model} from cache server. Falling back to Hugging Face Hub. Error: {e}'
)
download_from_hf_hub(args.model,
Expand Down

0 comments on commit 7c4d24a

Please sign in to comment.