From 3641e41742e06a2bcec9df1421274428ecf8ee3d Mon Sep 17 00:00:00 2001
From: Vladimir Brkic <161027113+vbrkicTT@users.noreply.github.com>
Date: Tue, 29 Oct 2024 12:54:16 +0100
Subject: [PATCH] Test suite (#527)

---
 forge/test/operators/pytorch/test_all.py | 414 ++++++++++++++++++++++
 forge/test/operators/utils/__init__.py   |   8 +-
 forge/test/operators/utils/plan.py       | 432 ++++++++++++++++-------
 3 files changed, 730 insertions(+), 124 deletions(-)
 create mode 100644 forge/test/operators/pytorch/test_all.py

diff --git a/forge/test/operators/pytorch/test_all.py b/forge/test/operators/pytorch/test_all.py
new file mode 100644
index 000000000..cc65a30cf
--- /dev/null
+++ b/forge/test/operators/pytorch/test_all.py
@@ -0,0 +1,414 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
+# SPDX-License-Identifier: Apache-2.0
+
+# Tests for all PyTorch operators
+
+
+# Examples
+# pytest -svv forge/test/operators/pytorch/test_all.py::test_unique --collect-only
+# TEST_ID='no_device-ge-FROM_HOST-None-(1, 2, 3, 4)-Float16_b-HiFi4' pytest -svv forge/test/operators/pytorch/test_all.py::test_single
+# OPERATORS=add,div FILTERS=HAS_DATA_FORMAT,QUICK DEV_DATA_FORMATS=Float16_b,Int8 MATH_FIDELITIES=HiFi4,HiFi3 RANGE=5 pytest -svv forge/test/operators/pytorch/test_all.py::test_query --collect-only
+# OPERATORS=add,div FAILING_REASONS=DATA_MISMATCH,UNSUPPORTED_DATA_FORMAT SKIP_REASONS=FATAL_ERROR RANGE=5 pytest -svv forge/test/operators/pytorch/test_all.py::test_query --collect-only
+# FAILING_REASONS=NOT_IMPLEMENTED INPUT_SOURCES=FROM_HOST pytest -svv forge/test/operators/pytorch/test_all.py::test_query --collect-only
+
+# pytest -svv forge/test/operators/pytorch/test_all.py::test_plan
+# pytest -svv forge/test/operators/pytorch/test_all.py::test_failed
+# pytest -svv forge/test/operators/pytorch/test_all.py::test_skipped
+# pytest -svv forge/test/operators/pytorch/test_all.py::test_fatal
+# pytest -svv forge/test/operators/pytorch/test_all.py::test_not_implemented
+# pytest -svv forge/test/operators/pytorch/test_all.py::test_data_mismatch
+# pytest -svv forge/test/operators/pytorch/test_all.py::test_unsupported_df
+# pytest -svv forge/test/operators/pytorch/test_all.py::test_custom
+
+
+import os
+import pytest
+import forge
+
+from loguru import logger
+
+from test.operators.utils import DeviceUtils
+from test.operators.utils import InputSource
+from test.operators.utils import TestVector
+from test.operators.utils import TestCollection
+from test.operators.utils import TestCollectionCommon
+from test.operators.utils import TestPlanScanner
+from test.operators.utils import FailingReasons
+
+
+class TestVerification:
+    """Helper class for performing test verification. It allows running tests in dry-run mode."""
+
+    DRY_RUN = False
+    # DRY_RUN = True
+
+    @classmethod
+    def verify(cls, test_vector: TestVector, test_device):
+        if cls.DRY_RUN:
+            # pytest.skip("Dry run")
+            return
+        test_vector.verify(test_device)
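+
+# A quick way to walk the generated parameter space without touching a device
+# is to flip TestVerification.DRY_RUN above to True: every vector is still
+# collected and parametrized, but verify() returns immediately. Illustrative
+# session (not part of this commit):
+#
+#   DRY_RUN = True  # in TestVerification
+#   pytest -svv forge/test/operators/pytorch/test_all.py::test_plan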
+
+
+class TestParamsData:
+    """
+    Helper class for providing data for test parameters.
+    This is a parameter manager that collects querying criteria from environment variables to determine which tests should run. It helps build the filtering test collection and the filtering lambdas.
+    """
+
+    __test__ = False  # Avoid collecting TestParamsData as a pytest test
+
+    test_suite = TestPlanScanner.build_test_suite(current_directory=os.path.dirname(__file__))
+
+    @classmethod
+    def get_single_list(cls) -> list[str]:
+        """Provide a list of test ids to run for the test_single method"""
+        test_id_single = os.getenv("TEST_ID", None)
+        return [test_id_single] if test_id_single else []
+
+    @classmethod
+    def build_filtered_collection(cls) -> TestCollection:
+        """
+        Builds a filtering test collection based on environment variables.
+        Query criteria are defined by the following environment variables:
+        - OPERATORS: List of operators to filter
+        - INPUT_SOURCES: List of input sources to filter
+        - DEV_DATA_FORMATS: List of data formats to filter
+        - MATH_FIDELITIES: List of math fidelities to filter
+        """
+        operators = os.getenv("OPERATORS", None)
+        if operators:
+            operators = operators.split(",")
+        else:
+            operators = None
+
+        input_sources = os.getenv("INPUT_SOURCES", None)
+        if input_sources:
+            input_sources = input_sources.split(",")
+            input_sources = [getattr(InputSource, input_source) for input_source in input_sources]
+
+        # TODO INPUT_SHAPES
+
+        dev_data_formats = os.getenv("DEV_DATA_FORMATS", None)
+        if dev_data_formats:
+            dev_data_formats = dev_data_formats.split(",")
+            dev_data_formats = [getattr(forge.DataFormat, dev_data_format) for dev_data_format in dev_data_formats]
+
+        math_fidelities = os.getenv("MATH_FIDELITIES", None)
+        if math_fidelities:
+            math_fidelities = math_fidelities.split(",")
+            math_fidelities = [getattr(forge.MathFidelity, math_fidelity) for math_fidelity in math_fidelities]
+
+        # TODO KWARGS
+
+        filtered_collection = TestCollection(
+            operators=operators,
+            input_sources=input_sources,
+            dev_data_formats=dev_data_formats,
+            math_fidelities=math_fidelities,
+        )
+
+        return filtered_collection
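+
+    # For example (hypothetical values), running with
+    #   DEV_DATA_FORMATS=Float16_b,Int8 MATH_FIDELITIES=HiFi4
+    # makes build_filtered_collection return the equivalent of
+    #
+    #   TestCollection(
+    #       operators=None,
+    #       input_sources=None,
+    #       dev_data_formats=[forge.DataFormat.Float16_b, forge.DataFormat.Int8],
+    #       math_fidelities=[forge.MathFidelity.HiFi4],
+    #   )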
+
+    def build_filter_lambdas():
+        """
+        Builds a list of lambdas for filtering test vectors based on environment variables.
+        The lambdas are built based on the following environment variables:
+        - FILTERS: List of lambdas defined in VectorLambdas to filter
+        - FAILING_REASONS: List of failing reasons to filter
+        - SKIP_REASONS: List of skip reasons to filter
+        """
+        lambdas = []
+
+        # Include selected filters from VectorLambdas
+        filters = os.getenv("FILTERS", None)
+        if filters:
+            filters = filters.split(",")
+            filters = [getattr(VectorLambdas, filter) for filter in filters]
+            lambdas = lambdas + filters
+
+        # TODO: Extend TestCollection with list of failing reasons and skip reasons and move this logic to build_filtered_collection
+        failing_reasons = os.getenv("FAILING_REASONS", None)
+        if failing_reasons:
+            failing_reasons = failing_reasons.split(",")
+            failing_reasons = [getattr(FailingReasons, failing_reason) for failing_reason in failing_reasons]
+
+        skip_reasons = os.getenv("SKIP_REASONS", None)
+        if skip_reasons:
+            skip_reasons = skip_reasons.split(",")
+            skip_reasons = [getattr(FailingReasons, skip_reason) for skip_reason in skip_reasons]
+
+        if failing_reasons:
+            lambdas.append(
+                lambda test_vector: test_vector.failing_result is not None
+                and test_vector.failing_result.failing_reason in failing_reasons
+            )
+
+        if skip_reasons:
+            lambdas.append(
+                lambda test_vector: test_vector.failing_result is not None
+                and test_vector.failing_result.skip_reason in skip_reasons
+            )
+
+        return lambdas
+
+    @classmethod
+    def get_filter_range(cls) -> tuple[int, int]:
+        """Provide a range of test vectors to run based on environment variables"""
+
+        range = os.getenv("RANGE", None)
+        if range:
+            range = range.split(",")
+            if len(range) == 1:
+                return 0, int(range[0])
+            else:
+                return int(range[0]), int(range[1])
+
+        return 0, 100000
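+
+    # RANGE accepts one or two comma-separated integers (illustrative values):
+    #   RANGE=10    -> (0, 10)   keep the first ten collected vectors
+    #   RANGE=5,20  -> (5, 20)   keep vectors with indices 5 through 19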
+
+
+class TestCollectionData:
+    """Helper test collections"""
+
+    __test__ = False  # Avoid collecting TestCollectionData as a pytest test
+
+    # Test collections for query criteria from environment variables
+    filtered = TestParamsData.build_filtered_collection()
+
+    # All available test vectors
+    all = TestCollection(
+        # operators=None,  # All available operators
+        operators=filtered.operators,  # Operators selected by filter
+    )
+
+    # Quick test collection for faster testing, consisting of a subset of input shapes and data formats
+    quick = TestCollection(
+        # 2 examples for each dimension and microbatch size
+        input_shapes=[]
+        + [shape for shape in TestCollectionCommon.all.input_shapes if len(shape) in (2,) and shape[0] == 1][:2]
+        + [shape for shape in TestCollectionCommon.all.input_shapes if len(shape) in (2,) and shape[0] != 1][:2]
+        + [shape for shape in TestCollectionCommon.all.input_shapes if len(shape) in (3,) and shape[0] == 1][:2]
+        + [shape for shape in TestCollectionCommon.all.input_shapes if len(shape) in (3,) and shape[0] != 1][:2]
+        + [shape for shape in TestCollectionCommon.all.input_shapes if len(shape) in (4,) and shape[0] == 1][:2]
+        + [shape for shape in TestCollectionCommon.all.input_shapes if len(shape) in (4,) and shape[0] != 1][:2],
+        # One example each for float and int data formats
+        dev_data_formats=[
+            None,
+            forge.DataFormat.Float16_b,
+            forge.DataFormat.Int8,
+        ],
+    )
+
+
+class VectorLambdas:
+    """Helper lambdas for filtering test vectors"""
+
+    ALL_OPERATORS = lambda test_vector: test_vector in TestCollectionData.all
+    NONE = lambda test_vector: False
+
+    QUICK = lambda test_vector: test_vector in TestCollectionData.quick
+    FILTERED = lambda test_vector: test_vector in TestCollectionData.filtered
+
+    SINGLE_SHAPE = lambda test_vector: test_vector.input_shape in TestCollectionCommon.single.input_shapes
+
+    SHAPES_2D = lambda test_vector: len(test_vector.input_shape) == 2
+    SHAPES_3D = lambda test_vector: len(test_vector.input_shape) == 3
+    SHAPES_4D = lambda test_vector: len(test_vector.input_shape) == 4
+
+    MICROBATCH_SIZE_ONE = lambda test_vector: test_vector.input_shape[0] == 1
+    MICROBATCH_SIZE_MULTI = lambda test_vector: test_vector.input_shape[0] > 1
+
+    FAILING = (
+        lambda test_vector: test_vector.failing_result is not None
+        and test_vector.failing_result.failing_reason is not None
+    )
+    SKIPED = (
+        lambda test_vector: test_vector.failing_result is not None
+        and test_vector.failing_result.skip_reason is not None
+    )
+    SKIPED_FATAL = (
+        lambda test_vector: test_vector.failing_result is not None
+        and test_vector.failing_result.skip_reason == FailingReasons.FATAL_ERROR
+    )
+    NOT_FAILING = (
+        lambda test_vector: test_vector.failing_result is None or test_vector.failing_result.failing_reason is None
+    )
+    NOT_SKIPED = (
+        lambda test_vector: test_vector.failing_result is None or test_vector.failing_result.skip_reason is None
+    )
+
+    HAS_DATA_FORMAT = lambda test_vector: test_vector.dev_data_format is not None
+    NO_DATA_FORMAT = lambda test_vector: test_vector.dev_data_format is None
+
+    NOT_IMPLEMENTED = (
+        lambda test_vector: test_vector.failing_result is not None
+        and test_vector.failing_result.failing_reason == FailingReasons.NOT_IMPLEMENTED
+    )
+    DATA_MISMATCH = (
+        lambda test_vector: test_vector.failing_result is not None
+        and test_vector.failing_result.failing_reason == FailingReasons.DATA_MISMATCH
+    )
+    UNSUPPORTED_DATA_FORMAT = (
+        lambda test_vector: test_vector.failing_result is not None
+        and test_vector.failing_result.failing_reason == FailingReasons.UNSUPPORTED_DATA_FORMAT
+    )
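+
+# The lambdas above are plain predicates and compose through TestQuery.filter;
+# a hypothetical narrowing (not part of this commit):
+#
+#   test_suite.query_all().filter(
+#       VectorLambdas.QUICK,
+#       VectorLambdas.SHAPES_2D,
+#       VectorLambdas.HAS_DATA_FORMAT,
+#   ).to_params()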
+
+
+test_suite = TestParamsData.test_suite
+
+
+@pytest.mark.parametrize(
+    "test_vector",
+    test_suite.query_all().filter(
+        VectorLambdas.ALL_OPERATORS,
+        VectorLambdas.QUICK,
+        # VectorLambdas.FILTERED,
+        # VectorLambdas.SINGLE_SHAPE,
+        # VectorLambdas.HAS_DATA_FORMAT,
+        # VectorLambdas.NO_DATA_FORMAT,
+    )
+    # .filter(lambda test_vector: test_vector in TestCollection(
+    #     operators=["add", ],
+    #     input_shapes=[
+    #         (1, 1)
+    #     ],
+    #     failing_reason=FailingReasons.DATA_MISMATCH,
+    # ))
+    # .log()
+    # .filter(lambda test_vector: test_vector.dev_data_format in [forge.DataFormat.Bfp2])
+    # .log()
+    # .skip(lambda test_vector: test_vector.kwargs is not None and "rounding_mode" in test_vector.kwargs and test_vector.kwargs["rounding_mode"] in ["trunc", "floor"])
+    # .log()
+    # .range(0, 10)
+    # .log()
+    # .range_skip(2, 5)
+    # .log()
+    # .index(3, 5)
+    # .log()
+    # .range(0, 10)
+    # .log()
+    # .filter(VectorLambdas.NONE)
+    .to_params(),
+)
+def test_custom(test_vector: TestVector, test_device):
+    TestVerification.verify(test_vector, test_device)
+
+
+@pytest.mark.parametrize(
+    "test_vector",
+    test_suite.query_all()
+    .filter(VectorLambdas.FILTERED)
+    .filter(*TestParamsData.build_filter_lambdas())
+    .range(*TestParamsData.get_filter_range())
+    .to_params(),
+)
+def test_query(test_vector: TestVector, test_device):
+    TestVerification.verify(test_vector, test_device)
+
+
+@pytest.mark.parametrize(
+    "test_vector",
+    test_suite.query_all()
+    .filter(VectorLambdas.ALL_OPERATORS, VectorLambdas.SINGLE_SHAPE)
+    .filter(
+        lambda test_vector: test_vector.input_source in [InputSource.FROM_HOST]
+        if (TestCollectionData.all.operators is None or len(TestCollectionData.all.operators) > 5)
+        else True
+    )
+    .group_limit(["operator", "input_source", "kwargs"], 1)
+    .to_params(),
+)
+def test_unique(test_vector: TestVector, test_device):
+    TestVerification.verify(test_vector, test_device)
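+
+# group_limit keeps at most `limit` vectors per combination of the listed
+# TestVector id fields, so test_unique above runs one representative vector per
+# (operator, input_source, kwargs) triple. A hypothetical looser variant:
+#
+#   .group_limit(["operator"], 3)  # up to three vectors per operator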
+
+
+@pytest.mark.parametrize("test_vector", test_suite.query_from_id_list(TestParamsData.get_single_list()).to_params())
+def test_single(test_vector: TestVector, test_device):
+    TestVerification.verify(test_vector, test_device)
+
+
+@pytest.mark.parametrize("test_vector", test_suite.query_all().filter(VectorLambdas.ALL_OPERATORS).to_params())
+def test_plan(test_vector: TestVector, test_device):
+    TestVerification.verify(test_vector, test_device)
+
+
+# 1480 passed, 20 xfailed, 2 warnings in 529.46s (0:08:49)
+# 4 failed, 1352 passed, 212 xfailed, 115 xpassed, 2 warnings in 590.56s (0:09:50)
+# 1 failed, 4041 passed, 20 skipped, 321 xfailed, 2 warnings in 1510.10s (0:25:10)
+# 3894 passed, 108 skipped, 444 xfailed, 252 xpassed, 2 warnings in 1719.04s (0:28:39)
+# 3834 passed, 60 skipped, 372 xfailed, 252 xpassed, 2 warnings in 1511.94s (0:25:11)
+# 10 failed, 3442 passed, 59 skipped, 1030 xfailed, 1 xpassed, 2 warnings in 1787.61s (0:29:47)
+# 12 failed, 3443 passed, 59 skipped, 1028 xfailed, 2 warnings in 1716.62s (0:28:36)
+# 10 failed, 3443 passed, 59 skipped, 1027 xfailed, 2 warnings in 1819.59s (0:30:19)
+# 5 failed, 3443 passed, 59 skipped, 1032 xfailed, 2 warnings in 1715.26s (0:28:35)
+# 3443 passed, 59 skipped, 1037 xfailed, 2 warnings in 1726.30s (0:28:46)
+# 8 failed, 3432 passed, 59 skipped, 1028 xfailed, 8 xpassed in 1591.84s (0:26:31)
+# 3440 passed, 59 skipped, 1035 xfailed in 1587.97s (0:26:27)
+# 3500 passed, 1056 xfailed in 1668.66s (0:27:48)
+# 4401 passed, 1423 xfailed in 2185.56s (0:36:25)
+# 4395 passed, 1429 xfailed in 2577.15s (0:42:57)
+
+
+# Below are examples of custom test functions that utilize filtering lambdas to run specific tests
+
+
+@pytest.mark.parametrize(
+    "test_vector",
+    test_suite.query_all()
+    .filter(VectorLambdas.ALL_OPERATORS, VectorLambdas.FAILING, VectorLambdas.NOT_SKIPED)
+    .to_params(),
+)
+def test_failed(test_vector: TestVector, test_device):
+    TestVerification.verify(test_vector, test_device)
+
+
+@pytest.mark.parametrize(
+    "test_vector",
+    test_suite.query_all().filter(VectorLambdas.ALL_OPERATORS, VectorLambdas.SKIPED).to_params(),
+)
+def test_skipped(test_vector: TestVector, test_device):
+    TestVerification.verify(test_vector, test_device)
+
+
+@pytest.mark.parametrize(
+    "test_vector",
+    test_suite.query_all()
+    .filter(VectorLambdas.ALL_OPERATORS, VectorLambdas.FAILING, VectorLambdas.SKIPED_FATAL)
+    .to_params(),
+)
+def test_fatal(test_vector: TestVector, test_device):
+    TestVerification.verify(test_vector, test_device)
+
+
+@pytest.mark.parametrize(
+    "test_vector",
+    test_suite.query_all()
+    .filter(VectorLambdas.ALL_OPERATORS, VectorLambdas.FAILING, VectorLambdas.NOT_SKIPED)
+    .filter(VectorLambdas.UNSUPPORTED_DATA_FORMAT)
+    .to_params(),
+)
+def test_unsupported_df(test_vector: TestVector, test_device):
+    TestVerification.verify(test_vector, test_device)
+
+
+@pytest.mark.parametrize(
+    "test_vector",
+    test_suite.query_all()
+    .filter(VectorLambdas.ALL_OPERATORS, VectorLambdas.FAILING, VectorLambdas.NOT_SKIPED)
+    .filter(VectorLambdas.NOT_IMPLEMENTED)
+    .to_params(),
+)
+def test_not_implemented(test_vector: TestVector, test_device):
+    TestVerification.verify(test_vector, test_device)
+
+
+@pytest.mark.parametrize(
+    "test_vector",
+    test_suite.query_all()
+    .filter(VectorLambdas.ALL_OPERATORS, VectorLambdas.FAILING, VectorLambdas.NOT_SKIPED)
+    .filter(VectorLambdas.DATA_MISMATCH)
+    .to_params(),
+)
+def test_data_mismatch(test_vector: TestVector, test_device):
+    TestVerification.verify(test_vector, test_device)
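+
+# Each test above pins one expected-failure bucket via VectorLambdas; an
+# equivalent ad-hoc query (illustrative, mirrors test_data_mismatch):
+#
+#   test_suite.query_all().filter(
+#       lambda v: v.failing_result is not None
+#       and v.failing_result.failing_reason == FailingReasons.DATA_MISMATCH
+#   ).to_params()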
diff --git a/forge/test/operators/utils/__init__.py b/forge/test/operators/utils/__init__.py
index fdbc613f8..b36eedaf2 100644
--- a/forge/test/operators/utils/__init__.py
+++ b/forge/test/operators/utils/__init__.py
@@ -18,8 +18,10 @@
 from .plan import TestCollection
 from .plan import TestResultFailing
 from .plan import TestPlan
+from .plan import TestSuite
+from .plan import TestQuery
 from .plan import TestPlanUtils
-from .plan import TestParamsFilter
+from .plan import TestPlanScanner
 from .test_data import TestCollectionCommon
 from .failing_reasons import FailingReasons
 from .failing_reasons import FailingReasonsValidation
@@ -44,8 +46,10 @@
     "TestCollection",
     "TestResultFailing",
     "TestPlan",
+    "TestSuite",
+    "TestQuery",
     "TestPlanUtils",
-    "TestParamsFilter",
+    "TestPlanScanner",
     "TestCollectionCommon",
     "FailingReasons",
     "FailingReasonsValidation",
diff --git a/forge/test/operators/utils/plan.py b/forge/test/operators/utils/plan.py
index 0e2ccb878..3cc737d90 100644
--- a/forge/test/operators/utils/plan.py
+++ b/forge/test/operators/utils/plan.py
@@ -9,6 +9,12 @@
 import forge
 import re
 
+import os
+import importlib
+import inspect
+from types import ModuleType
+from itertools import chain
+
 from _pytest.mark import Mark
 from _pytest.mark import ParameterSet
@@ -22,6 +28,7 @@
 from .datatypes import OperatorParameterTypes
 from .pytest import PytestParamsUtils
+from .compat import TestDevice
 
 
 class InputSource(Enum):
@@ -81,10 +88,21 @@ class TestVector:
     kwargs: Optional[OperatorParameterTypes.Kwargs] = None
     pcc: Optional[float] = None
     failing_result: Optional[TestResultFailing] = None
+    test_plan: Optional["TestPlan"] = None  # Needed for verification
 
-    def get_id(self) -> str:
+    def get_id(self, fields: Optional[List[str]] = None) -> str:
         """Get test vector id"""
-        return f"{self.operator}-{self.input_source.name}-{self.kwargs}-{self.input_shape}-{self.dev_data_format.name if self.dev_data_format else None}-{self.math_fidelity.name if self.math_fidelity else None}"
+        if fields is None:
+            return f"{self.operator}-{self.input_source.name}-{self.kwargs}-{self.input_shape}-{self.dev_data_format.name if self.dev_data_format else None}-{self.math_fidelity.name if self.math_fidelity else None}"
+        else:
+            attr = [
+                (getattr(self, field).name if getattr(self, field) is not None else None)
+                if field in ("input_source", "dev_data_format", "math_fidelity")
+                else getattr(self, field)
+                for field in fields
+            ]
+            attr = [str(a) for a in attr]
+            return "-".join(attr)
 
     def get_marks(self) -> List[Mark]:
         """Get marks for the test vector"""
@@ -95,6 +113,10 @@ class TestVector:
         """Convert test vector to pytest parameter set"""
         return pytest.param(self, marks=self.get_marks(), id=self.get_id())
 
+    def verify(self, test_device: "TestDevice"):
+        """Verify the test vector"""
+        self.test_plan.verify(test_device=test_device, test_vector=self)
+
 
 @dataclass
 class TestCollection:
@@ -143,6 +165,150 @@ class TestCollection:
     def __post_init__(self):
         if self.kwargs is not None and not isinstance(self.kwargs, types.FunctionType):
             self.kwargs = PytestParamsUtils.strip_param_sets(self.kwargs)
 
+    def __contains__(self, item):
+        if isinstance(item, TestVector):
+            return TestPlanUtils.test_vector_in_collection(item, self)
+        raise ValueError(f"Unsupported type: {type(item)} while checking if object is in TestCollection")
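+
+    # Membership sketch (illustrative, not part of this commit): with the hook
+    # above,
+    #   vector in TestCollection(operators=["add"], input_shapes=[(1, 2)])
+    # delegates to TestPlanUtils.test_vector_in_collection, checking the vector
+    # against the collection's criteria.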
+
+
+@dataclass
+class TestQuery:
+    """
+    Dataclass for specifying test vector queries
+
+    Args:
+        test_vectors: Generator of test vectors the query operates on
+    """
+
+    test_vectors: Generator[TestVector, None, None]
+
+    def _filter_allowed(self, *filters: Callable[[TestVector], bool]) -> Generator[TestVector, None, None]:
+        for test_vector in self.test_vectors:
+            if all([filter(test_vector) for filter in filters]):
+                yield test_vector
+
+    def _filter_skiped(self, *filters: Callable[[TestVector], bool]) -> Generator[TestVector, None, None]:
+        for test_vector in self.test_vectors:
+            if any([not filter(test_vector) for filter in filters]):
+                yield test_vector
+
+    def _filter_indices(
+        self, indices: Union[int, Tuple[int, int], List[int]] = None, allow_or_skip=True
+    ) -> Generator[TestVector, None, None]:
+        index = 0
+        for test_vector in self.test_vectors:
+            found = False
+            if isinstance(indices, tuple):
+                # logger.info(f"Tuple type indices: {indices}")
+                range_min, range_max = indices
+                if range_min <= index < range_max:
+                    found = True
+            elif isinstance(indices, list):
+                # logger.info(f"List type indices: {indices}")
+                if index in indices:
+                    found = True
+            else:
+                logger.error(f"Invalid indices: {indices}")
+
+            index += 1
+            if allow_or_skip == found:
+                yield test_vector
+
+    def _filter_group_limit(self, groups: List[str], limit: int) -> Generator[TestVector, None, None]:
+        groups_count = {}
+        for test_vector in self.test_vectors:
+            test_vector_group = test_vector.get_id(fields=groups)
+            if test_vector_group not in groups_count:
+                groups_count[test_vector_group] = 0
+            groups_count[test_vector_group] += 1
+            if groups_count[test_vector_group] <= limit:
+                yield test_vector
+
+    def _calculate_failing_result(self) -> Generator[TestVector, None, None]:
+        for test_vector in self.test_vectors:
+            test_vector.failing_result = test_vector.test_plan.check_test_failing(test_vector)
+            yield test_vector
+
+    def _reverse(self) -> Generator[TestVector, None, None]:
+        test_vectors = list(self.test_vectors)
+        test_vectors = test_vectors[::-1]
+        for test_vector in test_vectors:
+            yield test_vector
+
+    def _log(self) -> Generator[TestVector, None, None]:
+        test_vectors = list(self.test_vectors)
+        print("\nParameters:")
+        for test_vector in test_vectors:
+            print(f"{test_vector.get_id()}")
+            yield test_vector
+        print(f"Count: {len(test_vectors)}\n")
+
+    def filter(self, *filters: Callable[[TestVector], bool]) -> "TestQuery":
+        """Filter test vectors based on the filter functions"""
+        return TestQuery(self._filter_allowed(*filters))
+
+    def skip(self, *filters: Callable[[TestVector], bool]) -> "TestQuery":
+        """Skip test vectors based on the filter functions"""
+        return TestQuery(self._filter_skiped(*filters))
+
+    def index(self, *args: int) -> "TestQuery":
+        """Filter test vectors based on the indices"""
+        indices = list(args)
+        return TestQuery(self._filter_indices(indices, allow_or_skip=True))
+
+    def range(self, start_index: int, end_index: int) -> "TestQuery":
+        """Filter test vectors based on the range of indices"""
+        return TestQuery(self._filter_indices((start_index, end_index), allow_or_skip=True))
+
+    def index_skip(self, *args: int) -> "TestQuery":
+        """Skip test vectors based on the indices"""
+        indices = list(args)
+        return TestQuery(self._filter_indices(indices, allow_or_skip=False))
+
+    def group_limit(self, groups: List[str], limit: int) -> "TestQuery":
+        """Limit the number of test vectors per group"""
+        return TestQuery(self._filter_group_limit(groups, limit))
+
+    def range_skip(self, start_index: int, end_index: int) -> "TestQuery":
+        """Skip test vectors based on the range of indices"""
+        return TestQuery(self._filter_indices((start_index, end_index), allow_or_skip=False))
+
+    def calculate_failing_result(self) -> "TestQuery":
+        """Calculate and set the failing result based on the test plan"""
+        return TestQuery(self._calculate_failing_result())
+
+    def reverse(self) -> "TestQuery":
+        """Reverse the order of test vectors"""
+        return TestQuery(self._reverse())
+
+    def log(self) -> "TestQuery":
+        """Log the test vectors"""
+        return TestQuery(self._log())
+
+    def to_params(self) -> Generator[ParameterSet, None, None]:
+        """Convert test vectors to pytest parameter sets"""
+        test_vectors = self.test_vectors
+        for test_vector in test_vectors:
+            yield test_vector.to_param()
+
+    @classmethod
+    def all(cls, test_plan: Union["TestPlan", "TestSuite"]) -> "TestQuery":
+        test_vectors = test_plan.generate()
+        query = TestQuery(test_vectors)
+        return query.calculate_failing_result()
+
+    @classmethod
+    def query_from_id_file(cls, test_plan: Union["TestPlan", "TestSuite"], test_ids_file: str) -> "TestQuery":
+        test_vectors = test_plan.load_test_vectors_from_id_file(test_ids_file)
+        query = TestQuery(test_vectors)
+        return query.calculate_failing_result()
+
+    @classmethod
+    def query_from_id_list(cls, test_plan: Union["TestPlan", "TestSuite"], test_ids: List[str]) -> "TestQuery":
+        test_vectors = test_plan.load_test_vectors_from_id_list(test_ids)
+        query = TestQuery(test_vectors)
+        return query.calculate_failing_result()
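+
+# TestQuery is lazy: every method wraps the underlying generator in a fresh
+# TestQuery, so nothing is evaluated until to_params() is iterated by pytest.
+# Hypothetical pipeline (illustrative only):
+#
+#   query = test_plan.query_all()  # vectors with failing results calculated
+#   query = query.filter(lambda v: len(v.input_shape) == 2).range(0, 10)
+#   params = list(query.to_params())  # materializes the first ten 2D vectors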
 
 
 @dataclass
 class TestPlan:
@@ -158,8 +324,9 @@ class TestPlan:
 
     collections: Optional[List[TestCollection]] = None
     failing_rules: Optional[List[TestCollection]] = None
+    verify: Optional[Callable[[TestVector, TestDevice], None]] = None
 
-    def _check_test_failing(
+    def check_test_failing(
         self,
         test_vector: TestVector,
     ) -> Optional[TestResultFailing]:
@@ -172,7 +339,7 @@ class TestPlan:
         failing_result = None
 
         for failing_rule in self.failing_rules:
-            if TestPlanUtils.test_vector_in_collection(test_vector, failing_rule):
+            if test_vector in failing_rule:
                 if failing_rule.failing_reason is not None or failing_rule.skip_reason is not None:
                     failing_result = TestResultFailing(failing_rule.failing_reason, failing_rule.skip_reason)
                 else:
@@ -204,7 +371,8 @@ class TestPlan:
 
                 for dev_data_format in dev_data_formats:
                     for math_fidelity in math_fidelities:
-                        test_vector = TestVector(
+                        test_vector_no_kwargs = TestVector(
+                            test_plan=self,  # Set the test plan to support verification
                             operator=input_operator,
                             input_source=input_source,
                             input_shape=input_shape,
@@ -214,14 +382,15 @@ class TestPlan:
                         )
 
                         # filter collection based on criteria
-                        if test_collection.criteria is None or test_collection.criteria(test_vector):
+                        if test_collection.criteria is None or test_collection.criteria(test_vector_no_kwargs):
 
                             if isinstance(test_collection.kwargs, types.FunctionType):
-                                kwargs_list = test_collection.kwargs(test_vector)
+                                kwargs_list = test_collection.kwargs(test_vector_no_kwargs)
 
                             for kwargs in kwargs_list:
 
-                                # instantiate a new test vector to avoid mutating the original test_vector
+                                # instantiate a new test vector to avoid mutating the original test_vector_no_kwargs
                                 test_vector = TestVector(
+                                    test_plan=self,  # Set the test plan to support verification
                                     operator=input_operator,
                                     input_source=input_source,
                                     input_shape=input_shape,
@@ -231,26 +400,82 @@ class TestPlan:
                                     kwargs=kwargs,
                                 )
 
-                                test_vector.failing_result = self._check_test_failing(test_vector)
                                 yield test_vector
 
+    def load_test_vectors_from_id_file(self, test_ids_file: str) -> List[TestVector]:
+        test_ids = TestPlanUtils.load_test_ids_from_file(test_ids_file)
+
+        return self.load_test_vectors_from_id_list(test_ids)
+
+    def load_test_vectors_from_id_list(self, test_ids: List[str]) -> List[TestVector]:
+        test_vectors = TestPlanUtils.test_ids_to_test_vectors(test_ids)
+
+        for test_vector in test_vectors:
+            if test_vector.operator not in self.collections[0].operators:
+                raise ValueError(f"Operator {test_vector.operator} not found in test plan")
+            test_vector.test_plan = self
+
+        return test_vectors
+
+    def query_all(self) -> TestQuery:
+        return TestQuery.all(self)
+
+    def query_from_id_file(self, test_ids_file: str) -> TestQuery:
+        return TestQuery.query_from_id_file(self, test_ids_file)
+
+    def query_from_id_list(self, test_ids: List[str]) -> TestQuery:
+        return TestQuery.query_from_id_list(self, test_ids)
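+
+# A plan couples vector generation with failure expectations; a minimal
+# hypothetical plan (field values are illustrative, not from this commit):
+#
+#   plan = TestPlan(
+#       collections=[TestCollection(operators=["add"], input_shapes=[(1, 2)])],
+#       failing_rules=[TestCollection(operators=["add"],
+#                                     dev_data_formats=[forge.DataFormat.Bfp2],
+#                                     failing_reason=FailingReasons.DATA_MISMATCH)],
+#       verify=my_verify_fn,  # hypothetical callable(test_vector, test_device)
+#   )
+#   params = plan.query_all().to_params()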
 
 
 @dataclass
-class TestParamsFilter:
-    """
-    Dataclass for specifying test parameters filter
+class TestSuite:
 
-    Args:
-        allow: Allow function
-        indices: Indices to filter
-        reversed: Reverse the order
-        log: Log the parameters
-    """
+    __test__ = False  # Avoid collecting TestSuite as a pytest test
+
+    test_plans: List[TestPlan] = None
+    indices: Optional[Dict[str, TestPlan]] = None  # TODO remove optional
 
-    allow: Optional[Callable[[TestVector], bool]] = lambda test_vector: True
-    indices: Optional[Union[int, Tuple[int, int], List[int]]] = None
-    reversed: bool = False
-    log: bool = False
+    @staticmethod
+    def get_test_plan_index(test_plans: List[TestPlan]) -> Dict[str, TestPlan]:
+        indices = {}
+        for test_plan in test_plans:
+            for operator in test_plan.collections[0].operators:
+                if operator not in indices:
+                    indices[operator] = test_plan
+        return indices
+
+    def __post_init__(self):
+        self.indices = self.get_test_plan_index(self.test_plans)
+        logger.trace(f"Test suite indices: {self.indices.keys()} test_plans: {len(self.test_plans)}")
+
+    def generate(self) -> Generator[TestVector, None, None]:
+        """Generate test vectors based on all test plans"""
+        generators = [test_plan.generate() for test_plan in self.test_plans]
+        return chain(*generators)
+
+    def load_test_vectors_from_id_file(self, test_ids_file: str) -> List[TestVector]:
+        test_ids = TestPlanUtils.load_test_ids_from_file(test_ids_file)
+
+        return self.load_test_vectors_from_id_list(test_ids)
+
+    def load_test_vectors_from_id_list(self, test_ids: List[str]) -> List[TestVector]:
+        test_vectors = TestPlanUtils.test_ids_to_test_vectors(test_ids)
+
+        for test_vector in test_vectors:
+            if test_vector.operator not in self.indices:
+                raise ValueError(f"Operator {test_vector.operator} not found in test suite")
+            test_vector.test_plan = self.indices[test_vector.operator]
+
+        return test_vectors
+
+    def query_all(self) -> TestQuery:
+        return TestQuery.all(self)
+
+    def query_from_id_file(self, test_ids_file: str) -> TestQuery:
+        return TestQuery.query_from_id_file(self, test_ids_file)
+
+    def query_from_id_list(self, test_ids: List[str]) -> TestQuery:
+        return TestQuery.query_from_id_list(self, test_ids)
 
 
 class TestPlanUtils:
@@ -403,120 +628,83 @@ class TestPlanUtils:
             math_fidelity=math_fidelity,
         )
 
-    @classmethod
-    def test_vector_to_test_collection(cls, test_vector: TestVector) -> TestCollection:
-
-        return TestCollection(
-            operators=[test_vector.operator],
-            input_sources=[test_vector.input_source],
-            input_shapes=[test_vector.input_shape],
-            kwargs=[test_vector.kwargs],
-            dev_data_formats=[test_vector.dev_data_format],
-            math_fidelities=[test_vector.math_fidelity],
-        )
-
     @classmethod
     def test_ids_to_test_vectors(cls, test_ids: List[str]) -> List[TestVector]:
         return [cls.test_id_to_test_vector(test_id) for test_id in test_ids]
 
-    @classmethod
-    def test_vectors_to_test_collections(cls, test_vectors: List[TestVector]) -> List[TestCollection]:
-        return [cls.test_vector_to_test_collection(test_vector) for test_vector in test_vectors]
-
-    @classmethod
-    def build_test_plan_from_id_list(
-        cls, test_ids: List[str], test_plan_failing: Optional[TestPlan] = None
-    ) -> TestPlan:
-        test_plan = TestPlan(
-            collections=cls.test_vectors_to_test_collections(cls.test_ids_to_test_vectors(test_ids)),
-            failing_rules=test_plan_failing.failing_rules if test_plan_failing is not None else [],
-        )
-
-        return test_plan
-
-    @classmethod
-    def build_test_plan_from_id_file(cls, test_ids_file: str, test_plan_failing: TestPlan) -> TestPlan:
-        test_ids = cls.load_test_ids_from_file(test_ids_file)
-
-        test_plan = cls.build_test_plan_from_id_list(test_ids, test_plan_failing)
-
-        return test_plan
+
+class TestPlanScanner:
+
+    METHOD_COLLECT_TEST_PLANS = "get_test_plans"
 
     @classmethod
-    def generate_params(
-        cls, test_plan: TestPlan, filter: Optional[TestParamsFilter] = None
-    ) -> Generator[ParameterSet, None, None]:
-        test_vectors = test_plan.generate()
-
-        test_vectors = cls.process_filter(test_vectors, filter)
-
-        for test_vector in test_vectors:
-            yield test_vector.to_param()
+    def find_modules_in_directory(cls, directory: str) -> List[str]:
+        """Search for all modules in the directory and subdirectories."""
+        modules = []
+        for root, _, files in os.walk(directory):
+            for file in files:
+                if file.endswith(".py") and not file.startswith("__init__"):
+                    # Convert file path to Python module path
+                    module_path = os.path.relpath(os.path.join(root, file), directory)
+                    module_name = module_path[:-3].replace(os.sep, ".")
+                    modules.append(module_name)
+        return modules
 
     @classmethod
-    def yield_test_vectors(
-        cls, test_vector: Union[TestVector, List[TestVector], Generator[TestVector, None, None]]
-    ) -> Generator[TestVector, None, None]:
-        if test_vector is None:
-            pass
-        elif isinstance(test_vector, TestVector):
-            yield test_vector
-        elif isinstance(test_vector, types.GeneratorType):
-            return test_vector
-        elif isinstance(test_vector, list):
-            for item in test_vector:
-                yield item
+    def find_and_call_method(cls, module: ModuleType, method_name: str) -> Generator:
+        """Find and call all module-level functions named `method_name`."""
+        for name, func in inspect.getmembers(module, inspect.isfunction):
+            if name == method_name and "." not in func.__qualname__:  # skip functions nested inside classes
+                logger.trace(f"Calling {method_name} from function: {name} in module: {module.__name__}")
+                try:
+                    results: List[Union[TestPlan, TestSuite]] = func()  # Call the function
+                    for result in results:
+                        yield result
+                except Exception as e:
+                    logger.error(f"Error calling {name} in {module.__name__}: {e}")
+                    raise e
 
     @classmethod
+    def scan_and_invoke(cls, directory: str, method_name: str) -> Generator:
+        """Scan the directory and invoke all matching module-level functions."""
+        modules = cls.find_modules_in_directory(directory)
+
+        for module_name in modules:
+            try:
+                logger.trace(f"Loading module: {module_name}")
+                # Dynamic module loading
+                module = importlib.import_module(module_name)
+                results = cls.find_and_call_method(module, method_name)
+                for result in results:
+                    yield result
+            except Exception as e:
+                logger.error(f"Problem loading module {module_name}: {e}")
+                raise e
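+
+    # Discovery contract (illustrative): any module found under the scanned
+    # directory that defines a module-level get_test_plans() is imported and
+    # called; it must return an iterable of TestPlan and/or TestSuite objects.
+    # Hypothetical test_my_op.py:
+    #
+    #   def get_test_plans() -> list[TestPlan]:
+    #       return [my_operator_test_plan]  # hypothetical plan object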
 
-    @classmethod
-    def filter_allowed(
-        cls, test_params: Generator[TestVector, None, None], filter: TestParamsFilter
-    ) -> Generator[TestVector, None, None]:
-        index = 0
-        for p in test_params:
-            allowed = False
-            if filter.allow is None:
-                allowed = True
-            elif filter.allow(p):
-                if filter.indices is None:
-                    allowed = True
-                else:
-                    if isinstance(filter.indices, int):
-                        # logger.info(f"Int type filter.indices: {filter.indices}")
-                        if filter.indices == index:
-                            allowed = True
-                    elif isinstance(filter.indices, tuple):
-                        # logger.info(f"Tuple type filter.indices: {filter.indices}")
-                        range_min, range_max = filter.indices
-                        if range_min <= index <= range_max:
-                            allowed = True
-                    elif isinstance(filter.indices, list):
-                        # logger.info(f"List type filter.indices: {filter.indices}")
-                        if index in filter.indices:
-                            allowed = True
-                    else:
-                        logger.error(f"Invalid filter.indices: {filter.indices}")
-
-            index += 1
-            if allowed:
-                yield p
+    @classmethod
+    def collect_test_plans(cls, result: Union[TestPlan, TestSuite]) -> Generator[TestPlan, None, None]:
+        if isinstance(result, TestSuite):
+            test_suite = result
+            for test_plan in test_suite.test_plans:
+                yield test_plan
+        elif isinstance(result, TestPlan):
+            test_plan = result
+            yield test_plan
+        else:
+            raise ValueError(f"Unsupported suite/plan type: {type(result)}")
 
-    @classmethod
-    def process_filter(
-        cls, test_params: Generator[TestVector, None, None], filter: Optional[TestParamsFilter] = None
-    ) -> Generator[TestVector, None, None]:
-        if filter is not None:
-            test_params = cls.filter_allowed(test_params, filter)
-
-            if filter.reversed == True:
-
-                if not isinstance(test_params, list):
-                    test_params = list(test_params)
-
-                test_params = test_params[::-1]
-
-        if filter is not None and filter.log == True:
-            logger.info("Parameters:")
-        for p in test_params:
-            if filter is not None and filter.log == True:
-                logger.info(f"{p.get_id()}")
-            yield p
+    @classmethod
+    def get_all_test_plans(cls, current_directory: str) -> Generator[TestPlan, None, None]:
+        """Get all test plans from the current directory and its subdirectories."""
+        results = cls.scan_and_invoke(current_directory, cls.METHOD_COLLECT_TEST_PLANS)
+        for result in results:
+            for test_plan in cls.collect_test_plans(result):
+                yield test_plan
+
+    @classmethod
+    def build_test_suite(cls, current_directory: str) -> TestSuite:
+        test_plans = TestPlanScanner.get_all_test_plans(current_directory)
+        test_plans = list(test_plans)
+        return TestSuite(test_plans=test_plans)