From 4959025d9b880130982cc153bcfce3727a5cedb9 Mon Sep 17 00:00:00 2001 From: Raphael Glon Date: Wed, 13 Dec 2023 14:26:45 +0100 Subject: [PATCH] Text generation inference, fix offline Signed-off-by: Raphael Glon --- server/tests/utils/test_hub.py | 60 +++++++++++- server/text_generation_server/utils/hub.py | 105 ++++++++++++++------- 2 files changed, 132 insertions(+), 33 deletions(-) diff --git a/server/tests/utils/test_hub.py b/server/tests/utils/test_hub.py index fac9a64d565..08a9c3a7516 100644 --- a/server/tests/utils/test_hub.py +++ b/server/tests/utils/test_hub.py @@ -1,5 +1,13 @@ +import os +import requests +import tempfile + import pytest +import huggingface_hub.constants +from huggingface_hub import hf_api + +import text_generation_server.utils.hub from text_generation_server.utils.hub import ( weight_hub_files, download_weights, @@ -10,6 +18,53 @@ ) +@pytest.fixture() +def offline(): + current_value = text_generation_server.utils.hub.HF_HUB_OFFLINE + text_generation_server.utils.hub.HF_HUB_OFFLINE = True + yield "offline" + text_generation_server.utils.hub.HF_HUB_OFFLINE = current_value + + +@pytest.fixture() +def fresh_cache(): + with tempfile.TemporaryDirectory() as d: + current_value = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d + text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = d + os.environ['HUGGINGFACE_HUB_CACHE'] = d + yield + huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = current_value + os.environ['HUGGINGFACE_HUB_CACHE'] = current_value + text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = current_value + + +@pytest.fixture() +def prefetched(): + model_id = "bert-base-uncased" + huggingface_hub.snapshot_download( + repo_id=model_id, + revision="main", + local_files_only=False, + repo_type="model", + allow_patterns=["*.safetensors"] + ) + yield model_id + + +def test_weight_hub_files_offline_error(offline, fresh_cache): + # If the model is not prefetched then it will raise an error + with pytest.raises(EntryNotFoundError): + weight_hub_files("gpt2") + + +# Note: order of fixtures matters here, we need to define the cache location before fetching the model +def test_weight_hub_files_offline_ok(prefetched, offline): + # If the model is prefetched then we should be able to get the weight files from local cache + filenames = weight_hub_files(prefetched) + assert filenames == ['model.safetensors'] + + def test_weight_hub_files(): filenames = weight_hub_files("bigscience/bloom-560m") assert filenames == ["model.safetensors"] @@ -33,8 +88,11 @@ def test_download_weights(): assert files == local_files -def test_weight_files_error(): +def test_weight_files_revision_error(): with pytest.raises(RevisionNotFoundError): weight_files("bigscience/bloom-560m", revision="error") + + +def test_weight_files_not_cached_error(fresh_cache): with pytest.raises(LocalEntryNotFoundError): weight_files("bert-base-uncased") diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py index 23743c9be0a..b93eeeac174 100644 --- a/server/text_generation_server/utils/hub.py +++ b/server/text_generation_server/utils/hub.py @@ -6,24 +6,29 @@ from pathlib import Path from typing import Optional, List -from huggingface_hub import HfApi, hf_hub_download +from huggingface_hub import file_download, hf_api, HfApi, hf_hub_download from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from huggingface_hub.utils import ( LocalEntryNotFoundError, EntryNotFoundError, - RevisionNotFoundError, # Import here to ease try/except in other part of the lib + RevisionNotFoundError, # noqa # Import here to ease try/except in other part of the lib ) WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None) +HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"] -def weight_hub_files( - model_id: str, revision: Optional[str] = None, extension: str = ".safetensors" -) -> List[str]: - """Get the weights filenames on the hub""" - api = HfApi() - info = api.model_info(model_id, revision=revision) - filenames = [ +def _cached_weight_files(model_id: str, revision: Optional[str], extension: str) -> List[str]: + """Guess weight files from the cached revision snapshot directory""" + d = _get_cached_revision_directory(model_id, revision) + if not d: + return [] + filenames = _weight_files_from_dir(d, extension) + return filenames + + +def _weight_hub_files_from_model_info(info: hf_api.ModelInfo, extension: str) -> List[str]: + return [ s.rfilename for s in info.siblings if s.rfilename.endswith(extension) @@ -33,24 +38,26 @@ def weight_hub_files( and "training" not in s.rfilename ] - if not filenames: - raise EntryNotFoundError( - f"No {extension} weights found for model {model_id} and revision {revision}.", - None, - ) +def _weight_files_from_dir(d: Path, extension: str) -> List[str]: + # os.walk: do not iterate, just scan for depth 1, not recursively + # see _weight_hub_files_from_model_info, that's also what is + # done there with the len(s.rfilename.split("/")) == 1 condition + root, _, files = next(os.walk(str(d))) + filenames = [f for f in files + if f.endswith(extension) + and "arguments" not in f + and "args" not in f + and "training" not in f] return filenames -def try_to_load_from_cache( - model_id: str, revision: Optional[str], filename: str -) -> Optional[Path]: - """Try to load a file from the Hugging Face cache""" +def _get_cached_revision_directory(model_id: str, revision: Optional[str]) -> Optional[Path]: if revision is None: revision = "main" - object_id = model_id.replace("/", "--") - repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}" + repo_cache = Path(HUGGINGFACE_HUB_CACHE) / Path( + file_download.repo_folder_name(repo_id=model_id, repo_type="model")) if not repo_cache.is_dir(): # No cache for this model @@ -74,8 +81,42 @@ def try_to_load_from_cache( # No cache for this revision and we won't try to return a random revision return None + return snapshots_dir / revision + + +def weight_hub_files( + model_id: str, revision: Optional[str] = None, extension: str = ".safetensors" +) -> List[str]: + """Get the weights filenames on the hub""" + api = HfApi() + + if HF_HUB_OFFLINE: + filenames = _cached_weight_files(model_id, revision, extension) + else: + # Online case, fetch model info from the Hub + info = api.model_info(model_id, revision=revision) + filenames = _weight_hub_files_from_model_info(info, extension) + + if not filenames: + raise EntryNotFoundError( + f"No {extension} weights found for model {model_id} and revision {revision}.", + None, + ) + + return filenames + + +def try_to_load_from_cache( + model_id: str, revision: Optional[str], filename: str +) -> Optional[Path]: + """Try to load a file from the Hugging Face cache""" + + d = _get_cached_revision_directory(model_id, revision) + if not d: + return None + # Check if file exists in cache - cached_file = snapshots_dir / revision / filename + cached_file = d / filename return cached_file if cached_file.is_file() else None @@ -138,33 +179,33 @@ def download_weights( ) -> List[Path]: """Download the safetensors files from the hub""" - def download_file(filename, tries=5, backoff: int = 5): - local_file = try_to_load_from_cache(model_id, revision, filename) + def download_file(fname, tries=5, backoff: int = 5): + local_file = try_to_load_from_cache(model_id, revision, fname) if local_file is not None: - logger.info(f"File {filename} already present in cache.") + logger.info(f"File {fname} already present in cache.") return Path(local_file) - for i in range(tries): + for idx in range(tries): try: - logger.info(f"Download file: {filename}") - start_time = time.time() + logger.info(f"Download file: {fname}") + stime = time.time() local_file = hf_hub_download( - filename=filename, + filename=fname, repo_id=model_id, revision=revision, - local_files_only=False, + local_files_only=HF_HUB_OFFLINE, ) logger.info( - f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - start_time))}." + f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - stime))}." ) return Path(local_file) except Exception as e: - if i + 1 == tries: + if idx + 1 == tries: raise e logger.error(e) logger.info(f"Retrying in {backoff} seconds") time.sleep(backoff) - logger.info(f"Retry {i + 1}/{tries - 1}") + logger.info(f"Retry {idx + 1}/{tries - 1}") # We do this instead of using tqdm because we want to parse the logs with the launcher start_time = time.time()