From c3044ec2f3416bdec19ea66504b7911549bd3b16 Mon Sep 17 00:00:00 2001 From: Lucain Date: Wed, 29 May 2024 12:55:43 +0200 Subject: [PATCH] Use `HF_HUB_OFFLINE` + fix has_file in offline mode (#31016) * Fix has_file in offline mode * harmonize env variable for offline mode * Switch to HF_HUB_OFFLINE * fix test * revert test_offline to test TRANSFORMERS_OFFLINE * Add new offline test * merge conflicts * docs --- docs/source/de/installation.md | 4 +- docs/source/en/installation.md | 4 +- docs/source/es/installation.md | 4 +- docs/source/fr/installation.md | 4 +- docs/source/it/installation.md | 4 +- docs/source/ja/installation.md | 4 +- docs/source/ko/installation.md | 4 +- docs/source/pt/installation.md | 4 +- docs/source/zh/installation.md | 4 +- src/transformers/modeling_flax_utils.py | 2 + src/transformers/modeling_tf_utils.py | 2 + src/transformers/modeling_utils.py | 5 ++ src/transformers/utils/hub.py | 56 ++++++++++--- tests/test_configuration_utils.py | 1 - tests/utils/test_hub_utils.py | 19 ++++- tests/utils/test_offline.py | 103 +++++++++++++----------- 16 files changed, 148 insertions(+), 76 deletions(-) diff --git a/docs/source/de/installation.md b/docs/source/de/installation.md index 55d0f2d8512d47..1bd34f73302b27 100644 --- a/docs/source/de/installation.md +++ b/docs/source/de/installation.md @@ -162,7 +162,7 @@ Transformers verwendet die Shell-Umgebungsvariablen `PYTORCH_TRANSFORMERS_CACHE` ## Offline Modus -Transformers ist in der Lage, in einer Firewall- oder Offline-Umgebung zu laufen, indem es nur lokale Dateien verwendet. Setzen Sie die Umgebungsvariable `TRANSFORMERS_OFFLINE=1`, um dieses Verhalten zu aktivieren. +Transformers ist in der Lage, in einer Firewall- oder Offline-Umgebung zu laufen, indem es nur lokale Dateien verwendet. Setzen Sie die Umgebungsvariable `HF_HUB_OFFLINE=1`, um dieses Verhalten zu aktivieren. @@ -179,7 +179,7 @@ python examples/pytorch/translation/run_translation.py --model_name_or_path goog Führen Sie das gleiche Programm in einer Offline-Instanz mit aus: ```bash -HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 \ +HF_DATASETS_OFFLINE=1 HF_HUB_OFFLINE=1 \ python examples/pytorch/translation/run_translation.py --model_name_or_path google-t5/t5-small --dataset_name wmt16 --dataset_config ro-en ... ``` diff --git a/docs/source/en/installation.md b/docs/source/en/installation.md index 7ece8eae44cabd..3ed4edf3d8ec5c 100644 --- a/docs/source/en/installation.md +++ b/docs/source/en/installation.md @@ -169,7 +169,7 @@ Pretrained models are downloaded and locally cached at: `~/.cache/huggingface/hu ## Offline mode -Run 🤗 Transformers in a firewalled or offline environment with locally cached files by setting the environment variable `TRANSFORMERS_OFFLINE=1`. +Run 🤗 Transformers in a firewalled or offline environment with locally cached files by setting the environment variable `HF_HUB_OFFLINE=1`. @@ -178,7 +178,7 @@ Add [🤗 Datasets](https://huggingface.co/docs/datasets/) to your offline train ```bash -HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 \ +HF_DATASETS_OFFLINE=1 HF_HUB_OFFLINE=1 \ python examples/pytorch/translation/run_translation.py --model_name_or_path google-t5/t5-small --dataset_name wmt16 --dataset_config ro-en ... ``` diff --git a/docs/source/es/installation.md b/docs/source/es/installation.md index b79d0af4a46436..714c3b195ebcc0 100644 --- a/docs/source/es/installation.md +++ b/docs/source/es/installation.md @@ -154,7 +154,7 @@ Los modelos preentrenados se descargan y almacenan en caché localmente en: `~/. ## Modo Offline -🤗 Transformers puede ejecutarse en un entorno con firewall o fuera de línea (offline) usando solo archivos locales. Configura la variable de entorno `TRANSFORMERS_OFFLINE=1` para habilitar este comportamiento. +🤗 Transformers puede ejecutarse en un entorno con firewall o fuera de línea (offline) usando solo archivos locales. Configura la variable de entorno `HF_HUB_OFFLINE=1` para habilitar este comportamiento. @@ -171,7 +171,7 @@ python examples/pytorch/translation/run_translation.py --model_name_or_path goog Ejecuta este mismo programa en una instancia offline con el siguiente comando: ```bash -HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 \ +HF_DATASETS_OFFLINE=1 HF_HUB_OFFLINE=1 \ python examples/pytorch/translation/run_translation.py --model_name_or_path google-t5/t5-small --dataset_name wmt16 --dataset_config ro-en ... ``` diff --git a/docs/source/fr/installation.md b/docs/source/fr/installation.md index cd68911bc3564d..bbc93d810f0df1 100644 --- a/docs/source/fr/installation.md +++ b/docs/source/fr/installation.md @@ -171,7 +171,7 @@ Les modèles pré-entraînés sont téléchargés et mis en cache localement dan ## Mode hors ligne -🤗 Transformers peut fonctionner dans un environnement cloisonné ou hors ligne en n'utilisant que des fichiers locaux. Définissez la variable d'environnement `TRANSFORMERS_OFFLINE=1` pour activer ce mode. +🤗 Transformers peut fonctionner dans un environnement cloisonné ou hors ligne en n'utilisant que des fichiers locaux. Définissez la variable d'environnement `HF_HUB_OFFLINE=1` pour activer ce mode. @@ -180,7 +180,7 @@ Ajoutez [🤗 Datasets](https://huggingface.co/docs/datasets/) à votre processu ```bash -HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 \ +HF_DATASETS_OFFLINE=1 HF_HUB_OFFLINE=1 \ python examples/pytorch/translation/run_translation.py --model_name_or_path google-t5/t5-small --dataset_name wmt16 --dataset_config ro-en ... ``` diff --git a/docs/source/it/installation.md b/docs/source/it/installation.md index 2f45f4182d24c9..a4f444c1eb0c4c 100644 --- a/docs/source/it/installation.md +++ b/docs/source/it/installation.md @@ -152,7 +152,7 @@ I modelli pre-allenati sono scaricati e memorizzati localmente nella cache in: ` ## Modalità Offline -🤗 Transformers può essere eseguita in un ambiente firewalled o offline utilizzando solo file locali. Imposta la variabile d'ambiente `TRANSFORMERS_OFFLINE=1` per abilitare questo comportamento. +🤗 Transformers può essere eseguita in un ambiente firewalled o offline utilizzando solo file locali. Imposta la variabile d'ambiente `HF_HUB_OFFLINE=1` per abilitare questo comportamento. @@ -169,7 +169,7 @@ python examples/pytorch/translation/run_translation.py --model_name_or_path goog Esegui lo stesso programma in un'istanza offline con: ```bash -HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 \ +HF_DATASETS_OFFLINE=1 HF_HUB_OFFLINE=1 \ python examples/pytorch/translation/run_translation.py --model_name_or_path google-t5/t5-small --dataset_name wmt16 --dataset_config ro-en ... ``` diff --git a/docs/source/ja/installation.md b/docs/source/ja/installation.md index 915984a91c860e..a0b9dfe3bdbd7a 100644 --- a/docs/source/ja/installation.md +++ b/docs/source/ja/installation.md @@ -157,7 +157,7 @@ conda install conda-forge::transformers ## オフラインモード -🤗 Transformersはローカルファイルのみを使用することでファイアウォールやオフラインの環境でも動作させることができます。この動作を有効にするためには、環境変数`TRANSFORMERS_OFFLINE=1`を設定します。 +🤗 Transformersはローカルファイルのみを使用することでファイアウォールやオフラインの環境でも動作させることができます。この動作を有効にするためには、環境変数`HF_HUB_OFFLINE=1`を設定します。 @@ -174,7 +174,7 @@ python examples/pytorch/translation/run_translation.py --model_name_or_path goog オフラインインスタンスでこの同じプログラムを実行します: ```bash -HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 \ +HF_DATASETS_OFFLINE=1 HF_HUB_OFFLINE=1 \ python examples/pytorch/translation/run_translation.py --model_name_or_path google-t5/t5-small --dataset_name wmt16 --dataset_config ro-en ... ``` diff --git a/docs/source/ko/installation.md b/docs/source/ko/installation.md index 062184e5b3ba6c..1583e994d6afe3 100644 --- a/docs/source/ko/installation.md +++ b/docs/source/ko/installation.md @@ -157,7 +157,7 @@ conda install conda-forge::transformers ## 오프라인 모드[[offline-mode]] -🤗 Transformers를 로컬 파일만 사용하도록 해서 방화벽 또는 오프라인 환경에서 실행할 수 있습니다. 활성화하려면 `TRANSFORMERS_OFFLINE=1` 환경 변수를 설정하세요. +🤗 Transformers를 로컬 파일만 사용하도록 해서 방화벽 또는 오프라인 환경에서 실행할 수 있습니다. 활성화하려면 `HF_HUB_OFFLINE=1` 환경 변수를 설정하세요. @@ -174,7 +174,7 @@ python examples/pytorch/translation/run_translation.py --model_name_or_path goog 오프라인 기기에서 동일한 프로그램을 다음과 같이 실행할 수 있습니다. ```bash -HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 \ +HF_DATASETS_OFFLINE=1 HF_HUB_OFFLINE=1 \ python examples/pytorch/translation/run_translation.py --model_name_or_path google-t5/t5-small --dataset_name wmt16 --dataset_config ro-en ... ``` diff --git a/docs/source/pt/installation.md b/docs/source/pt/installation.md index 7eeefd883d6ec3..f548736589ac0d 100644 --- a/docs/source/pt/installation.md +++ b/docs/source/pt/installation.md @@ -173,7 +173,7 @@ No Windows, este diretório pré-definido é dado por `C:\Users\username\.cache\ ## Modo Offline O 🤗 Transformers também pode ser executado num ambiente de firewall ou fora da rede (offline) usando arquivos locais. -Para tal, configure a variável de ambiente de modo que `TRANSFORMERS_OFFLINE=1`. +Para tal, configure a variável de ambiente de modo que `HF_HUB_OFFLINE=1`. @@ -191,7 +191,7 @@ python examples/pytorch/translation/run_translation.py --model_name_or_path goog Execute esse mesmo programa numa instância offline com o seguinte comando: ```bash -HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 \ +HF_DATASETS_OFFLINE=1 HF_HUB_OFFLINE=1 \ python examples/pytorch/translation/run_translation.py --model_name_or_path google-t5/t5-small --dataset_name wmt16 --dataset_config ro-en ... ``` diff --git a/docs/source/zh/installation.md b/docs/source/zh/installation.md index 91e09dc904bd7e..f87eaa5fc132cf 100644 --- a/docs/source/zh/installation.md +++ b/docs/source/zh/installation.md @@ -169,7 +169,7 @@ conda install conda-forge::transformers ## 离线模式 -🤗 Transformers 可以仅使用本地文件在防火墙或离线环境中运行。设置环境变量 `TRANSFORMERS_OFFLINE=1` 以启用该行为。 +🤗 Transformers 可以仅使用本地文件在防火墙或离线环境中运行。设置环境变量 `HF_HUB_OFFLINE=1` 以启用该行为。 @@ -186,7 +186,7 @@ python examples/pytorch/translation/run_translation.py --model_name_or_path goog 在离线环境中运行相同的程序: ```bash -HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 \ +HF_DATASETS_OFFLINE=1 HF_HUB_OFFLINE=1 \ python examples/pytorch/translation/run_translation.py --model_name_or_path google-t5/t5-small --dataset_name wmt16 --dataset_config ro-en ... ``` diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index f669329ac01bda..61077cf7c30938 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -823,6 +823,8 @@ def from_pretrained( "revision": revision, "proxies": proxies, "token": token, + "cache_dir": cache_dir, + "local_files_only": local_files_only, } if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs): is_sharded = True diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index f6b9b00117d0a3..0ad5dd0396194a 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -2864,6 +2864,8 @@ def from_pretrained( "revision": revision, "proxies": proxies, "token": token, + "cache_dir": cache_dir, + "local_files_only": local_files_only, } if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs): is_sharded = True diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 27f26e42a84a3b..a613fee62c42ab 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3405,6 +3405,8 @@ def from_pretrained( "revision": revision, "proxies": proxies, "token": token, + "cache_dir": cache_dir, + "local_files_only": local_files_only, } cached_file_kwargs = { "cache_dir": cache_dir, @@ -3432,6 +3434,8 @@ def from_pretrained( "revision": revision, "proxies": proxies, "token": token, + "cache_dir": cache_dir, + "local_files_only": local_files_only, } if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs): raise EnvironmentError( @@ -3459,6 +3463,7 @@ def from_pretrained( f" {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)}," f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}." ) + except EnvironmentError: # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted # to the original exception. diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index 173fcb352d5f74..efe40f0e21dced 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -51,9 +51,11 @@ GatedRepoError, HFValidationError, LocalEntryNotFoundError, + OfflineModeIsEnabled, RepositoryNotFoundError, RevisionNotFoundError, build_hf_headers, + get_session, hf_raise_for_status, send_telemetry, ) @@ -75,7 +77,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -_is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False +_is_offline_mode = huggingface_hub.constants.HF_HUB_OFFLINE def is_offline_mode(): @@ -599,11 +601,17 @@ def has_file( revision: Optional[str] = None, proxies: Optional[Dict[str, str]] = None, token: Optional[Union[bool, str]] = None, + *, + local_files_only: bool = False, + cache_dir: Union[str, Path, None] = None, + repo_type: Optional[str] = None, **deprecated_kwargs, ): """ Checks if a repo contains a given file without downloading it. Works for remote repos and local folders. + If offline mode is enabled, checks if the file exists in the cache. + This function will raise an error if the repository `path_or_repo` is not valid or if `revision` does not exist for @@ -621,15 +629,41 @@ def has_file( raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.") token = use_auth_token + # If path to local directory, check if the file exists if os.path.isdir(path_or_repo): return os.path.isfile(os.path.join(path_or_repo, filename)) - url = hf_hub_url(path_or_repo, filename=filename, revision=revision) - headers = build_hf_headers(token=token, user_agent=http_user_agent()) + # Else it's a repo => let's check if the file exists in local cache or on the Hub + + # Check if file exists in cache + # This information might be outdated so it's best to also make a HEAD call (if allowed). + cached_path = try_to_load_from_cache( + repo_id=path_or_repo, + filename=filename, + revision=revision, + repo_type=repo_type, + cache_dir=cache_dir, + ) + has_file_in_cache = isinstance(cached_path, str) + + # If local_files_only, don't try the HEAD call + if local_files_only: + return has_file_in_cache - r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=10) + # Check if the file exists try: - hf_raise_for_status(r) + response = get_session().head( + hf_hub_url(path_or_repo, filename=filename, revision=revision, repo_type=repo_type), + headers=build_hf_headers(token=token, user_agent=http_user_agent()), + allow_redirects=False, + proxies=proxies, + timeout=10, + ) + except OfflineModeIsEnabled: + return has_file_in_cache + + try: + hf_raise_for_status(response) return True except GatedRepoError as e: logger.error(e) @@ -640,16 +674,20 @@ def has_file( ) from e except RepositoryNotFoundError as e: logger.error(e) - raise EnvironmentError(f"{path_or_repo} is not a local folder or a valid repository name on 'https://hf.co'.") + raise EnvironmentError( + f"{path_or_repo} is not a local folder or a valid repository name on 'https://hf.co'." + ) from e except RevisionNotFoundError as e: logger.error(e) raise EnvironmentError( f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this " f"model name. Check the model page at 'https://huggingface.co/{path_or_repo}' for available revisions." - ) + ) from e + except EntryNotFoundError: + return False # File does not exist except requests.HTTPError: - # We return false for EntryNotFoundError (logical) as well as any connection error. - return False + # Any authentication/authorization error will be caught here => default to cache + return has_file_in_cache class PushToHubMixin: diff --git a/tests/test_configuration_utils.py b/tests/test_configuration_utils.py index a5322a176ec06c..b9f090e061fa72 100644 --- a/tests/test_configuration_utils.py +++ b/tests/test_configuration_utils.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import json import os import shutil diff --git a/tests/utils/test_hub_utils.py b/tests/utils/test_hub_utils.py index c1320baaddaff3..aae9bd63cf7c4d 100644 --- a/tests/utils/test_hub_utils.py +++ b/tests/utils/test_hub_utils.py @@ -18,6 +18,7 @@ import unittest.mock as mock from pathlib import Path +from huggingface_hub import hf_hub_download from requests.exceptions import HTTPError from transformers.utils import ( @@ -33,6 +34,7 @@ RANDOM_BERT = "hf-internal-testing/tiny-random-bert" +TINY_BERT_PT_ONLY = "hf-internal-testing/tiny-bert-pt-only" CACHE_DIR = os.path.join(TRANSFORMERS_CACHE, "models--hf-internal-testing--tiny-random-bert") FULL_COMMIT_HASH = "9b8c223d42b2188cb49d29af482996f9d0f3e5a6" @@ -99,9 +101,20 @@ def test_non_existence_is_cached(self): mock_head.assert_called() def test_has_file(self): - self.assertTrue(has_file("hf-internal-testing/tiny-bert-pt-only", WEIGHTS_NAME)) - self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", TF2_WEIGHTS_NAME)) - self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", FLAX_WEIGHTS_NAME)) + self.assertTrue(has_file(TINY_BERT_PT_ONLY, WEIGHTS_NAME)) + self.assertFalse(has_file(TINY_BERT_PT_ONLY, TF2_WEIGHTS_NAME)) + self.assertFalse(has_file(TINY_BERT_PT_ONLY, FLAX_WEIGHTS_NAME)) + + def test_has_file_in_cache(self): + with tempfile.TemporaryDirectory() as tmp_dir: + # Empty cache dir + offline mode => return False + assert not has_file(TINY_BERT_PT_ONLY, WEIGHTS_NAME, local_files_only=True, cache_dir=tmp_dir) + + # Populate cache dir + hf_hub_download(TINY_BERT_PT_ONLY, WEIGHTS_NAME, cache_dir=tmp_dir) + + # Cache dir + offline mode => return True + assert has_file(TINY_BERT_PT_ONLY, WEIGHTS_NAME, local_files_only=True, cache_dir=tmp_dir) def test_get_file_from_repo_distant(self): # `get_file_from_repo` returns None if the file does not exist diff --git a/tests/utils/test_offline.py b/tests/utils/test_offline.py index ecc7938bf3802e..59ed034201a64e 100644 --- a/tests/utils/test_offline.py +++ b/tests/utils/test_offline.py @@ -14,6 +14,7 @@ import subprocess import sys +from typing import Tuple from transformers import BertConfig, BertModel, BertTokenizer, pipeline from transformers.testing_utils import TestCasePlus, require_torch @@ -56,15 +57,9 @@ def offline_socket(*args, **kwargs): raise RuntimeError("Offline mode is enabled pipeline(task="fill-mask", model=mname) # baseline - just load from_pretrained with normal network - cmd = [sys.executable, "-c", "\n".join([load, run, mock])] - - # should succeed - env = self.get_env() # should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files - env["TRANSFORMERS_OFFLINE"] = "1" - result = subprocess.run(cmd, env=env, check=False, capture_output=True) - self.assertEqual(result.returncode, 0, result.stderr) - self.assertIn("success", result.stdout.decode()) + stdout, _ = self._execute_with_env(load, run, mock, TRANSFORMERS_OFFLINE="1") + self.assertIn("success", stdout) @require_torch def test_offline_mode_no_internet(self): @@ -97,13 +92,9 @@ def offline_socket(*args, **kwargs): raise socket.error("Faking flaky internet") pipeline(task="fill-mask", model=mname) # baseline - just load from_pretrained with normal network - cmd = [sys.executable, "-c", "\n".join([load, run, mock])] - # should succeed - env = self.get_env() - result = subprocess.run(cmd, env=env, check=False, capture_output=True) - self.assertEqual(result.returncode, 0, result.stderr) - self.assertIn("success", result.stdout.decode()) + stdout, _ = self._execute_with_env(load, run, mock) + self.assertIn("success", stdout) @require_torch def test_offline_mode_sharded_checkpoint(self): @@ -132,27 +123,17 @@ def offline_socket(*args, **kwargs): raise ValueError("Offline mode is enabled") """ # baseline - just load from_pretrained with normal network - cmd = [sys.executable, "-c", "\n".join([load, run])] - # should succeed - env = self.get_env() - result = subprocess.run(cmd, env=env, check=False, capture_output=True) - self.assertEqual(result.returncode, 0, result.stderr) - self.assertIn("success", result.stdout.decode()) + stdout, _ = self._execute_with_env(load, run) + self.assertIn("success", stdout) # next emulate no network - cmd = [sys.executable, "-c", "\n".join([load, mock, run])] - # Doesn't fail anymore since the model is in the cache due to other tests, so commenting this. - # env["TRANSFORMERS_OFFLINE"] = "0" - # result = subprocess.run(cmd, env=env, check=False, capture_output=True) - # self.assertEqual(result.returncode, 1, result.stderr) + # self._execute_with_env(load, mock, run, should_fail=True, TRANSFORMERS_OFFLINE="0") # should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files - env["TRANSFORMERS_OFFLINE"] = "1" - result = subprocess.run(cmd, env=env, check=False, capture_output=True) - self.assertEqual(result.returncode, 0, result.stderr) - self.assertIn("success", result.stdout.decode()) + stdout, _ = self._execute_with_env(load, mock, run, TRANSFORMERS_OFFLINE="1") + self.assertIn("success", stdout) @require_torch def test_offline_mode_pipeline_exception(self): @@ -169,14 +150,11 @@ def test_offline_mode_pipeline_exception(self): def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled") socket.socket = offline_socket """ - env = self.get_env() - env["TRANSFORMERS_OFFLINE"] = "1" - cmd = [sys.executable, "-c", "\n".join([load, mock, run])] - result = subprocess.run(cmd, env=env, check=False, capture_output=True) - self.assertEqual(result.returncode, 1, result.stderr) + + _, stderr = self._execute_with_env(load, mock, run, should_fail=True, TRANSFORMERS_OFFLINE="1") self.assertIn( "You cannot infer task automatically within `pipeline` when using offline mode", - result.stderr.decode().replace("\n", ""), + stderr.replace("\n", ""), ) @require_torch @@ -191,16 +169,51 @@ def test_offline_model_dynamic_model(self): """ # baseline - just load from_pretrained with normal network - cmd = [sys.executable, "-c", "\n".join([load, run])] - # should succeed - env = self.get_env() - result = subprocess.run(cmd, env=env, check=False, capture_output=True) - self.assertEqual(result.returncode, 0, result.stderr) - self.assertIn("success", result.stdout.decode()) + stdout, _ = self._execute_with_env(load, run) + self.assertIn("success", stdout) # should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files - env["TRANSFORMERS_OFFLINE"] = "1" - result = subprocess.run(cmd, env=env, check=False, capture_output=True) - self.assertEqual(result.returncode, 0, result.stderr) - self.assertIn("success", result.stdout.decode()) + stdout, _ = self._execute_with_env(load, run, TRANSFORMERS_OFFLINE="1") + self.assertIn("success", stdout) + + def test_is_offline_mode(self): + """ + Test `_is_offline_mode` helper (should respect both HF_HUB_OFFLINE and legacy TRANSFORMERS_OFFLINE env vars) + """ + load = "from transformers.utils import is_offline_mode" + run = "print(is_offline_mode())" + + stdout, _ = self._execute_with_env(load, run) + self.assertIn("False", stdout) + + stdout, _ = self._execute_with_env(load, run, TRANSFORMERS_OFFLINE="1") + self.assertIn("True", stdout) + + stdout, _ = self._execute_with_env(load, run, HF_HUB_OFFLINE="1") + self.assertIn("True", stdout) + + def _execute_with_env(self, *commands: Tuple[str, ...], should_fail: bool = False, **env) -> Tuple[str, str]: + """Execute Python code with a given environment and return the stdout/stderr as strings. + + If `should_fail=True`, the command is expected to fail. Otherwise, it should succeed. + Environment variables can be passed as keyword arguments. + """ + # Build command + cmd = [sys.executable, "-c", "\n".join(commands)] + + # Configure env + new_env = self.get_env() + new_env.update(env) + + # Run command + result = subprocess.run(cmd, env=new_env, check=False, capture_output=True) + + # Check execution + if should_fail: + self.assertNotEqual(result.returncode, 0, result.stderr) + else: + self.assertEqual(result.returncode, 0, result.stderr) + + # Return output + return result.stdout.decode(), result.stderr.decode()