Skip to content

Commit

Permalink
Update model download utils to support ORAS (#881)
Browse files Browse the repository at this point in the history
* wip

* wip

* Accept registry file for hostname

* Make sure no sensitive info is surfaced in subprocess error

* Refactor model downloading

* Save HF hub files to local dir

* fallback

* Remove commented code

* Update logging

* Update HTP download args

* Use files for ORAS

* Update llmfoundry/utils/model_download_utils.py

Co-authored-by: Irene Dea <[email protected]>

---------

Co-authored-by: Irene Dea <[email protected]>
  • Loading branch information
jerrychen109 and irenedea authored Jan 18, 2024
1 parent 19ee086 commit 35bb339
Show file tree
Hide file tree
Showing 5 changed files with 249 additions and 168 deletions.
4 changes: 2 additions & 2 deletions llmfoundry/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
log_config, pop_config,
update_batch_size_info)
from llmfoundry.utils.model_download_utils import (
download_from_cache_server, download_from_hf_hub)
download_from_hf_hub, download_from_http_fileserver)
except ImportError as e:
raise ImportError(
'Please make sure to pip install . to get requirements for llm-foundry.'
Expand All @@ -28,7 +28,7 @@
'build_tokenizer',
'calculate_batch_size_info',
'convert_and_save_ft_weights',
'download_from_cache_server',
'download_from_http_fileserver',
'download_from_hf_hub',
'get_hf_tokenizer_from_composer_state_dict',
'update_batch_size_info',
Expand Down
144 changes: 98 additions & 46 deletions llmfoundry/utils/model_download_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import copy
import logging
import os
import shutil
import subprocess
import time
import warnings
from http import HTTPStatus
Expand All @@ -14,6 +16,7 @@
import huggingface_hub as hf_hub
import requests
import tenacity
import yaml
from bs4 import BeautifulSoup
from requests.packages.urllib3.exceptions import InsecureRequestWarning
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
Expand All @@ -28,6 +31,9 @@
PYTORCH_WEIGHTS_PATTERN = 'pytorch_model*.bin*'
SAFE_WEIGHTS_PATTERN = 'model*.safetensors*'

ORAS_PASSWD_PLACEHOLDER = '<placeholder_for_passwd>'
ORAS_CLI = 'oras'

log = logging.getLogger(__name__)


Expand All @@ -36,8 +42,8 @@
stop=tenacity.stop_after_attempt(3),
wait=tenacity.wait_exponential(min=1, max=10))
def download_from_hf_hub(
repo_id: str,
save_dir: Optional[str] = None,
model: str,
save_dir: str,
prefer_safetensors: bool = True,
token: Optional[str] = None,
):
Expand All @@ -48,8 +54,7 @@ def download_from_hf_hub(
Args:
repo_id (str): The Hugging Face Hub repo ID.
save_dir (str, optional): The path to the directory where the model files will be downloaded. If `None`, reads
from the `HUGGINGFACE_HUB_CACHE` environment variable or uses the default Hugging Face Hub cache directory.
save_dir (str, optional): The local path to the directory where the model files will be downloaded.
prefer_safetensors (bool): Whether to prefer Safetensors weights over PyTorch weights if both are
available. Defaults to True.
token (str, optional): The HuggingFace API token. If not provided, the token will be read from the
Expand All @@ -59,7 +64,7 @@ def download_from_hf_hub(
RepositoryNotFoundError: If the model repo doesn't exist or the token is unauthorized.
ValueError: If the model repo doesn't contain any supported model weights.
"""
repo_files = set(hf_hub.list_repo_files(repo_id))
repo_files = set(hf_hub.list_repo_files(model))

# Ignore TensorFlow, TensorFlow 2, and Flax weights as they are not supported by Composer.
ignore_patterns = copy.deepcopy(DEFAULT_IGNORE_PATTERNS)
Expand All @@ -86,18 +91,18 @@ def download_from_hf_hub(
log.info('Only pytorch available. Ignoring weights preference.')
else:
raise ValueError(
f'No supported model weights found in repo {repo_id}.' +
f'No supported model weights found in repo {model}.' +
' Please make sure the repo contains either safetensors or pytorch weights.'
)

download_start = time.time()
hf_hub.snapshot_download(repo_id,
cache_dir=save_dir,
hf_hub.snapshot_download(model,
local_dir=save_dir,
ignore_patterns=ignore_patterns,
token=token)
download_duration = time.time() - download_start
log.info(
f'Downloaded model {repo_id} from Hugging Face Hub in {download_duration} seconds'
f'Downloaded model {model} from Hugging Face Hub in {download_duration} seconds'
)


Expand Down Expand Up @@ -140,6 +145,7 @@ def _recursive_download(
RuntimeError: If the remote server returns a status code other than 200 OK or 401 Unauthorized.
"""
url = urljoin(base_url, path)
print(url)
response = session.get(url, verify=(not ignore_cert))

if response.status_code == HTTPStatus.UNAUTHORIZED:
Expand All @@ -156,7 +162,7 @@ def _recursive_download(
)

# Assume that the URL points to a file if it does not end with a slash.
if not path.endswith('/'):
if not url.endswith('/'):
save_path = os.path.join(save_dir, path)
parent_dir = os.path.dirname(save_path)
if not os.path.exists(parent_dir):
Expand All @@ -171,6 +177,7 @@ def _recursive_download(
# If the URL is a directory, the response should be an HTML directory listing that we can parse for additional links
# to download.
child_links = _extract_links_from_html(response.content.decode())
print(child_links)
for child_link in child_links:
_recursive_download(session,
base_url,
Expand All @@ -183,53 +190,98 @@ def _recursive_download(
(PermissionError, ValueError)),
stop=tenacity.stop_after_attempt(3),
wait=tenacity.wait_exponential(min=1, max=10))
def download_from_cache_server(
model_name: str,
cache_base_url: str,
def download_from_http_fileserver(
url: str,
save_dir: str,
token: Optional[str] = None,
ignore_cert: bool = False,
):
"""Downloads Hugging Face models from a mirror file server.
The file server is expected to store the files in the same structure as the Hugging Face cache
structure. See https://huggingface.co/docs/huggingface_hub/guides/manage-cache.
"""Downloads files from a remote HTTP file server.
Args:
model_name: The name of the model to download. This should be the same as the repository ID in the Hugging Face
Hub.
cache_base_url: The base URL of the cache file server. This function will attempt to download all of the blob
files from `<cache_base_url>/<formatted_model_name>/blobs/`, where `formatted_model_name` is equal to
`models/<model_name>` with all slashes replaced with `--`.
save_dir: The directory to save the downloaded files to.
token: The Hugging Face API token. If not provided, the token will be read from the `HUGGING_FACE_HUB_TOKEN`
environment variable.
ignore_cert: Whether or not to ignore the validity of the SSL certificate of the remote server. Defaults to
False.
url (str): The base URL where the files are located.
save_dir (str): The directory to save downloaded files to.
ignore_cert (bool): Whether or not to ignore the validity of the SSL certificate of the remote server.
Defaults to False.
WARNING: Setting this to true is *not* secure, as no certificate verification will be performed.
"""
formatted_model_name = f'models/{model_name}'.replace('/', '--')
with requests.Session() as session:
session.headers.update({'Authorization': f'Bearer {token}'})

download_start = time.time()

# Temporarily suppress noisy SSL certificate verification warnings if ignore_cert is set to True
with warnings.catch_warnings():
if ignore_cert:
warnings.simplefilter('ignore', category=InsecureRequestWarning)

# Only downloads the blobs in order to avoid downloading model files twice due to the
# symlnks in the Hugging Face cache structure:
_recursive_download(
session,
cache_base_url,
# Trailing slash to indicate directory
f'{formatted_model_name}/blobs/',
save_dir,
ignore_cert=ignore_cert,
)
download_duration = time.time() - download_start
log.info(
f'Downloaded model {model_name} from cache server in {download_duration} seconds'
_recursive_download(session,
url,
'',
save_dir,
ignore_cert=ignore_cert)


def download_from_oras(model: str,
config_file: str,
credentials_dir: str,
save_dir: str,
concurrency: int = 10):
"""Download from an OCI-compliant registry using oras.
Args:
model: The name of the model to download.
config_file: Path to a YAML config file that maps model names to registry paths.
credentials_dir: Path to a directory containing credentials for the registry. It is expected to contain three
files: `username`, `password`, and `registry`, each of which contains the corresponding credential.
save_dir: Path to the directory where files will be downloaded.
concurrency: The number of concurrent downloads to run.
"""
if shutil.which(ORAS_CLI) is None:
raise Exception(
f'oras cli command `{ORAS_CLI}` is not found. Please install oras: https://oras.land/docs/installation '
)

def _read_secrets_file(secret_file_path: str,):
try:
with open(secret_file_path, encoding='utf-8') as f:
return f.read().strip()
except Exception as error:
raise ValueError(
f'secrets file {secret_file_path} failed to be read') from error

secrets = {}
for secret in ['username', 'password', 'registry']:
secrets[secret] = _read_secrets_file(
os.path.join(credentials_dir, secret))

with open(config_file, 'r', encoding='utf-8') as f:
configs = yaml.safe_load(f.read())

path = configs['models'][model]
registry = secrets['registry']

def get_oras_cmd(username: Optional[str] = None,
password: Optional[str] = None):
cmd = [
ORAS_CLI,
'pull',
f'{registry}/{path}',
'-o',
save_dir,
'--verbose',
'--concurrency',
str(concurrency),
]
if username is not None:
cmd.extend(['--username', username])
if password is not None:
cmd.extend(['--password', password])

return cmd

cmd_without_creds = get_oras_cmd()
log.info(f'CMD for oras cli to run: {" ".join(cmd_without_creds)}')
cmd_to_run = get_oras_cmd(username=secrets['username'],
password=secrets['password'])
try:
subprocess.run(cmd_to_run, check=True)
except subprocess.CalledProcessError as e:
# Intercept the error and replace the cmd, which may have sensitive info.
raise subprocess.CalledProcessError(e.returncode, cmd_without_creds,
e.output, e.stderr)
83 changes: 0 additions & 83 deletions scripts/misc/download_hf_model.py

This file was deleted.

Loading

0 comments on commit 35bb339

Please sign in to comment.