Skip to content

Commit

Permalink
divide datasets tests
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Nov 5, 2024
1 parent 2facb94 commit fb01ab2
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 2 deletions.
13 changes: 11 additions & 2 deletions .github/workflows/test_utils.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,18 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install .[tests] datasets
pip install .[tests]
- name: Test with pytest
working-directory: tests
run: |
python -m pytest -s -vvvv utils
pytest utils -s -n auto -m "not datasets_test" --durations=0
- name: Install datasets
run: |
pip install datasets
- name: Tests needing datasets
working-directory: tests
run: |
pytest utils -s -n auto -m "datasets_test" --durations=0
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ markers = [
"rocm_ep_test",
"tensorflow_test",
"timm_test",
"datasets_test",
"run_in_series",
"run_slow",
"accelerate_test",
Expand Down
10 changes: 10 additions & 0 deletions tests/utils/test_task_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import TYPE_CHECKING, Any, Dict, Tuple, Union
from unittest import TestCase

import pytest
from transformers import AutoConfig, AutoFeatureExtractor, AutoTokenizer


# `datasets` is optional for this module: the CI workflow runs the tests that
# are NOT marked `datasets_test` before `datasets` is installed. An
# unconditional import here would make collection of the whole module fail in
# that first pass (unless `datasets` happens to be installed transitively).
# Fall back to None so module-level references still resolve; the tests that
# actually need the library carry `@require_datasets` / `datasets_test`.
try:
    from datasets import DatasetDict
except ImportError:
    DatasetDict = None

Expand Down Expand Up @@ -124,6 +125,7 @@ def test_create_defaults_and_kwargs_from_preprocessor_kwargs_does_not_mutate_pre
self.assertDictEqual(preprocessor_kwargs, clone)

@require_datasets
@pytest.mark.datasets_test
def test_load_dataset_unallowed_data_keys(self):
task_processor = TaskProcessorsManager.get_task_processor_class_for_task(self.TASK_NAME)(
self.CONFIG, self.PREPROCESSOR
Expand Down Expand Up @@ -191,18 +193,22 @@ def _test_load_dataset(
return dataset

@require_datasets
@pytest.mark.datasets_test
def test_load_dataset(self):
    """Run the shared dataset-loading check with every positional switch off.

    Delegates to ``self._test_load_dataset(False, False, False)`` and returns
    whatever that helper returns. The ``datasets_test`` marker lets the test
    be selected or deselected with ``pytest -m``.
    """
    loaded = self._test_load_dataset(False, False, False)
    return loaded

@require_datasets
@pytest.mark.datasets_test
def test_load_dataset_by_guessing_data_keys(self):
    """Run the shared dataset-loading check with only the second switch on.

    Delegates to ``self._test_load_dataset(False, True, False)`` and returns
    its result.
    """
    # NOTE(review): the second positional flag presumably enables data-key
    # guessing (going by this test's name) - confirm against the helper.
    loaded = self._test_load_dataset(False, True, False)
    return loaded

@require_datasets
@pytest.mark.datasets_test
def test_load_dataset_and_only_keep_necessary_columns(self):
    """Run the shared dataset-loading check with only the third switch on.

    Delegates to ``self._test_load_dataset(False, False, True)`` and returns
    its result.
    """
    # NOTE(review): the third positional flag presumably restricts the output
    # to the necessary columns (going by this test's name) - confirm against
    # the helper.
    loaded = self._test_load_dataset(False, False, True)
    return loaded

@require_datasets
@pytest.mark.datasets_test
def test_load_default_dataset(self):
    """Run the shared dataset-loading check with only the first switch on.

    Delegates to ``self._test_load_dataset(True, False, False)`` and returns
    its result.
    """
    # NOTE(review): the first positional flag presumably selects the task's
    # default dataset (going by this test's name) - confirm against the helper.
    loaded = self._test_load_dataset(True, False, False)
    return loaded

Expand All @@ -214,6 +220,7 @@ class TextClassificationProcessorTest(TestCase, TaskProcessorTestBase):
WRONG_PREPROCESSOR = IMAGE_PROCESSOR

@require_datasets
@pytest.mark.datasets_test
def test_load_dataset_with_max_length(self):
max_length = random.randint(4, 16)
dataset = self._test_load_dataset(False, False, True, max_length=max_length)
Expand All @@ -231,6 +238,7 @@ class TokenClassificationProcessorTest(TestCase, TaskProcessorTestBase):
WRONG_PREPROCESSOR = IMAGE_PROCESSOR

@require_datasets
@pytest.mark.datasets_test
def test_load_dataset_with_max_length(self):
max_length = random.randint(4, 16)
dataset = self._test_load_dataset(False, False, True, max_length=max_length)
Expand All @@ -241,6 +249,7 @@ def test_load_dataset_with_max_length(self):
self.assertEqual(len(input_ids), max_length)

@require_datasets
@pytest.mark.datasets_test
def test_load_default_dataset(self):
self.skipTest(
"Skipping so as not to execute conll2003 remote code (test would require trust_remote_code=True)"
Expand All @@ -254,6 +263,7 @@ class QuestionAnsweringProcessorTest(TestCase, TaskProcessorTestBase):
WRONG_PREPROCESSOR = IMAGE_PROCESSOR

@require_datasets
@pytest.mark.datasets_test
def test_load_dataset_with_max_length(self):
max_length = 384
dataset = self._test_load_dataset(False, False, True, max_length=max_length)
Expand Down

0 comments on commit fb01ab2

Please sign in to comment.