Skip to content

Commit

Permalink
Change test flag to agents
Browse files Browse the repository at this point in the history
  • Loading branch information
aymeric-roucher committed May 6, 2024
1 parent 3fc9c56 commit f27439c
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def pytest_configure(config):
config.addinivalue_line("markers", "is_pipeline_test: mark test to run only when pipelines are tested")
config.addinivalue_line("markers", "is_staging_test: mark test to run only in the staging environment")
config.addinivalue_line("markers", "accelerate_tests: mark test that require accelerate")
config.addinivalue_line("markers", "tool_tests: mark the tool tests that are run on their specific schedule")
config.addinivalue_line("markers", "agent_tests: mark the agent tests that are run on their specific schedule")
config.addinivalue_line("markers", "not_device_test: mark the tests always running on cpu")


Expand Down
12 changes: 6 additions & 6 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def parse_int_from_env(key, default=None):
_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False)
_tf_gpu_memory_limit = parse_int_from_env("TF_GPU_MEMORY_LIMIT", default=None)
_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=True)
_run_tool_tests = parse_flag_from_env("RUN_TOOL_TESTS", default=False)
_run_agent_tests = parse_flag_from_env("RUN_AGENT_TESTS", default=False)
_run_third_party_device_tests = parse_flag_from_env("RUN_THIRD_PARTY_DEVICE_TESTS", default=False)


Expand Down Expand Up @@ -276,19 +276,19 @@ def is_pipeline_test(test_case):
return pytest.mark.is_pipeline_test()(test_case)


def is_tool_test(test_case):
def is_agent_test(test_case):
"""
Decorator marking a test as a tool test. If RUN_TOOL_TESTS is set to a falsy value, those tests will be skipped.
Decorator marking a test as an agent test. If RUN_TOOL_TESTS is set to a falsy value, those tests will be skipped.
"""
if not _run_tool_tests:
return unittest.skip("test is a tool test")(test_case)
if not _run_agent_tests:
return unittest.skip("test is an agent test")(test_case)
else:
try:
import pytest # We don't need a hard dependency on pytest in the main library
except ImportError:
return test_case
else:
return pytest.mark.is_tool_test()(test_case)
return pytest.mark.is_agent_test()(test_case)


def slow(test_case):
Expand Down
4 changes: 2 additions & 2 deletions tests/agents/test_tools_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from transformers import is_torch_available, is_vision_available
from transformers.agents.agent_types import AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText
from transformers.testing_utils import get_tests_dir, is_tool_test
from transformers.testing_utils import get_tests_dir, is_agent_test


if is_torch_available():
Expand Down Expand Up @@ -63,7 +63,7 @@ def output_type(output):
raise ValueError(f"Invalid output: {output}")


@is_tool_test
@is_agent_test
class ToolTesterMixin:
def test_inputs_output(self):
self.assertTrue(hasattr(self.tool, "inputs"))
Expand Down

0 comments on commit f27439c

Please sign in to comment.