diff --git a/tests/__init__.py b/tests/__init__.py index b32609cd..4547b94e 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,5 +1,18 @@ from pathlib import Path +from datasets import DownloadMode, load_dataset + TESTS_ROOT = Path(__file__).parent FIXTURES_ROOT = TESTS_ROOT / "fixtures" DATASET_BUILDERS_ROOT = Path("dataset_builders") + + +def _check_hf_conll2003_is_available(): + try: + load_dataset("conll2003", download_mode=DownloadMode.FORCE_REDOWNLOAD) + return True + except ConnectionError: + return False + + +_HF_CONLL2003_IS_AVAILABLE = _check_hf_conll2003_is_available() diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index 7db0f3b8..2d390e2b 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -18,7 +18,7 @@ from pytorch_ie.data.dataset import get_pie_dataset_type from pytorch_ie.documents import TextDocument from pytorch_ie.taskmodules import TransformerSpanClassificationTaskModule -from tests import DATASET_BUILDERS_ROOT +from tests import _HF_CONLL2003_IS_AVAILABLE, DATASET_BUILDERS_ROOT from tests.conftest import TestDocument @@ -265,6 +265,10 @@ def test_dataset_with_taskmodule( assert not document["entities"].predictions +@pytest.mark.skipif( + not _HF_CONLL2003_IS_AVAILABLE, + reason="the Huggingface conll2003 dataset is not reachable and the local PIE-variant depends on it", +) def test_load_with_hf_datasets(): dataset_path = DATASET_BUILDERS_ROOT / "conll2003" @@ -279,6 +283,10 @@ def test_load_with_hf_datasets(): assert len(dataset["test"]) == 3453 +@pytest.mark.skipif( + not _HF_CONLL2003_IS_AVAILABLE, + reason="the Huggingface conll2003 dataset is not reachable and the remote PIE-variant depends on it", +) def test_load_with_hf_datasets_from_hub(): dataset = datasets.load_dataset( path="pie/conll2003",