fix: fix offline (#1341)
# What does this PR do?

Allows Text Generation Inference to succeed in loading prefetched and
cached private models when no token is provided at the time the
text-generation-inference service is launched.
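
In practice the intended flow looks like the following minimal sketch (the model id is illustrative and the snippet is a hypothetical helper, not part of this PR): prefetch the safetensors weights while online and authenticated, then launch the service with `HF_HUB_OFFLINE` set and no token.

    # prefetch_weights.py -- hypothetical helper, not part of this PR
    import os

    from huggingface_hub import snapshot_download

    # 1. While online (and authenticated if the repo is private),
    #    prefetch only the safetensors weights into the local HF cache.
    snapshot_download(
        repo_id="bert-base-uncased",  # illustrative model id
        revision="main",
        repo_type="model",
        allow_patterns=["*.safetensors"],
    )

    # 2. Later, before launching text-generation-inference without a token,
    #    force offline mode; with this fix the weights are resolved from
    #    the local cache instead of the Hub.
    os.environ["HF_HUB_OFFLINE"] = "1"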

Signed-off-by: Raphael Glon <[email protected]>
Co-authored-by: Raphael Glon <[email protected]>
oOraph authored Dec 14, 2023
1 parent 44b267a commit 47cd67e
Showing 2 changed files with 131 additions and 33 deletions.
59 changes: 58 additions & 1 deletion server/tests/utils/test_hub.py
@@ -1,5 +1,13 @@
+import os
+import requests
+import tempfile
+
 import pytest
 
+import huggingface_hub.constants
+from huggingface_hub import hf_api
+
+import text_generation_server.utils.hub
 from text_generation_server.utils.hub import (
     weight_hub_files,
     download_weights,
@@ -10,6 +18,52 @@
 )
 
 
+@pytest.fixture()
+def offline():
+    current_value = text_generation_server.utils.hub.HF_HUB_OFFLINE
+    text_generation_server.utils.hub.HF_HUB_OFFLINE = True
+    yield "offline"
+    text_generation_server.utils.hub.HF_HUB_OFFLINE = current_value
+
+
+@pytest.fixture()
+def fresh_cache():
+    with tempfile.TemporaryDirectory() as d:
+        current_value = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE
+        huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d
+        text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = d
+        os.environ['HUGGINGFACE_HUB_CACHE'] = d
+        yield
+        huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = current_value
+        os.environ['HUGGINGFACE_HUB_CACHE'] = current_value
+        text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = current_value
+
+
+@pytest.fixture()
+def prefetched():
+    model_id = "bert-base-uncased"
+    huggingface_hub.snapshot_download(
+        repo_id=model_id,
+        revision="main",
+        local_files_only=False,
+        repo_type="model",
+        allow_patterns=["*.safetensors"]
+    )
+    yield model_id
+
+
+def test_weight_hub_files_offline_error(offline, fresh_cache):
+    # If the model is not prefetched then it will raise an error
+    with pytest.raises(EntryNotFoundError):
+        weight_hub_files("gpt2")
+
+
+def test_weight_hub_files_offline_ok(prefetched, offline):
+    # If the model is prefetched then we should be able to get the weight files from local cache
+    filenames = weight_hub_files(prefetched)
+    assert filenames == ['model.safetensors']
+
+
 def test_weight_hub_files():
     filenames = weight_hub_files("bigscience/bloom-560m")
     assert filenames == ["model.safetensors"]
@@ -33,8 +87,11 @@ def test_download_weights():
     assert files == local_files
 
 
-def test_weight_files_error():
+def test_weight_files_revision_error():
    with pytest.raises(RevisionNotFoundError):
        weight_files("bigscience/bloom-560m", revision="error")
+
+
+def test_weight_files_not_cached_error(fresh_cache):
+    with pytest.raises(LocalEntryNotFoundError):
+        weight_files("bert-base-uncased")
105 changes: 73 additions & 32 deletions server/text_generation_server/utils/hub.py
@@ -6,24 +6,29 @@
 from pathlib import Path
 from typing import Optional, List
 
-from huggingface_hub import HfApi, hf_hub_download
+from huggingface_hub import file_download, hf_api, HfApi, hf_hub_download
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from huggingface_hub.utils import (
     LocalEntryNotFoundError,
     EntryNotFoundError,
-    RevisionNotFoundError,  # Import here to ease try/except in other part of the lib
+    RevisionNotFoundError,  # noqa # Import here to ease try/except in other part of the lib
 )
 
 WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
+HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]
 
 
-def weight_hub_files(
-    model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
-) -> List[str]:
-    """Get the weights filenames on the hub"""
-    api = HfApi()
-    info = api.model_info(model_id, revision=revision)
-    filenames = [
+def _cached_weight_files(model_id: str, revision: Optional[str], extension: str) -> List[str]:
+    """Guess weight files from the cached revision snapshot directory"""
+    d = _get_cached_revision_directory(model_id, revision)
+    if not d:
+        return []
+    filenames = _weight_files_from_dir(d, extension)
+    return filenames
+
+
+def _weight_hub_files_from_model_info(info: hf_api.ModelInfo, extension: str) -> List[str]:
+    return [
         s.rfilename
         for s in info.siblings
         if s.rfilename.endswith(extension)
@@ -33,24 +38,26 @@ def weight_hub_files(
         and "training" not in s.rfilename
     ]
 
-    if not filenames:
-        raise EntryNotFoundError(
-            f"No {extension} weights found for model {model_id} and revision {revision}.",
-            None,
-        )
 
+def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
+    # os.walk: do not iterate, just scan for depth 1, not recursively
+    # see _weight_hub_files_from_model_info, that's also what is
+    # done there with the len(s.rfilename.split("/")) == 1 condition
+    root, _, files = next(os.walk(str(d)))
+    filenames = [f for f in files
+                 if f.endswith(extension)
+                 and "arguments" not in f
+                 and "args" not in f
+                 and "training" not in f]
+    return filenames
 
 
-def try_to_load_from_cache(
-    model_id: str, revision: Optional[str], filename: str
-) -> Optional[Path]:
-    """Try to load a file from the Hugging Face cache"""
+def _get_cached_revision_directory(model_id: str, revision: Optional[str]) -> Optional[Path]:
     if revision is None:
         revision = "main"
 
-    object_id = model_id.replace("/", "--")
-    repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"
+    repo_cache = Path(HUGGINGFACE_HUB_CACHE) / Path(
+        file_download.repo_folder_name(repo_id=model_id, repo_type="model"))
 
     if not repo_cache.is_dir():
         # No cache for this model
@@ -74,8 +81,42 @@
     # No cache for this revision and we won't try to return a random revision
     return None
 
+    return snapshots_dir / revision
+
+
+def weight_hub_files(
+    model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
+) -> List[str]:
+    """Get the weights filenames on the hub"""
+    api = HfApi()
+
+    if HF_HUB_OFFLINE:
+        filenames = _cached_weight_files(model_id, revision, extension)
+    else:
+        # Online case, fetch model info from the Hub
+        info = api.model_info(model_id, revision=revision)
+        filenames = _weight_hub_files_from_model_info(info, extension)
+
+    if not filenames:
+        raise EntryNotFoundError(
+            f"No {extension} weights found for model {model_id} and revision {revision}.",
+            None,
+        )
+
+    return filenames
+
+
+def try_to_load_from_cache(
+    model_id: str, revision: Optional[str], filename: str
+) -> Optional[Path]:
+    """Try to load a file from the Hugging Face cache"""
+
+    d = _get_cached_revision_directory(model_id, revision)
+    if not d:
+        return None
+
     # Check if file exists in cache
-    cached_file = snapshots_dir / revision / filename
+    cached_file = d / filename
     return cached_file if cached_file.is_file() else None
 
 
@@ -138,33 +179,33 @@ def download_weights(
 ) -> List[Path]:
     """Download the safetensors files from the hub"""
 
-    def download_file(filename, tries=5, backoff: int = 5):
-        local_file = try_to_load_from_cache(model_id, revision, filename)
+    def download_file(fname, tries=5, backoff: int = 5):
+        local_file = try_to_load_from_cache(model_id, revision, fname)
         if local_file is not None:
-            logger.info(f"File {filename} already present in cache.")
+            logger.info(f"File {fname} already present in cache.")
             return Path(local_file)
 
-        for i in range(tries):
+        for idx in range(tries):
             try:
-                logger.info(f"Download file: {filename}")
-                start_time = time.time()
+                logger.info(f"Download file: {fname}")
+                stime = time.time()
                 local_file = hf_hub_download(
-                    filename=filename,
+                    filename=fname,
                     repo_id=model_id,
                     revision=revision,
-                    local_files_only=False,
+                    local_files_only=HF_HUB_OFFLINE,
                 )
                 logger.info(
-                    f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
+                    f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - stime))}."
                )
                return Path(local_file)
             except Exception as e:
-                if i + 1 == tries:
+                if idx + 1 == tries:
                     raise e
                 logger.error(e)
                 logger.info(f"Retrying in {backoff} seconds")
                 time.sleep(backoff)
-                logger.info(f"Retry {i + 1}/{tries - 1}")
+                logger.info(f"Retry {idx + 1}/{tries - 1}")
 
     # We do this instead of using tqdm because we want to parse the logs with the launcher
     start_time = time.time()
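
For context, the helpers above rely on the standard huggingface_hub cache layout. A minimal sketch of that layout and of resolving a revision to a snapshot directory follows; it is illustrative only (the ref-to-commit resolution inside `_get_cached_revision_directory` is elided from this diff), not code from this PR:

    # cache_layout_sketch.py -- illustrative, not code from this PR
    from pathlib import Path

    from huggingface_hub import file_download
    from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE

    # e.g. ~/.cache/huggingface/hub/models--bert-base-uncased
    repo_cache = Path(HUGGINGFACE_HUB_CACHE) / file_download.repo_folder_name(
        repo_id="bert-base-uncased", repo_type="model"
    )

    # Standard layout:
    #   refs/main          -> text file containing a commit hash
    #   snapshots/<hash>/  -> files (symlinks) for that revision
    ref = repo_cache / "refs" / "main"
    if ref.is_file():
        snapshot_dir = repo_cache / "snapshots" / ref.read_text().strip()
        print(sorted(p.name for p in snapshot_dir.iterdir()))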