diff --git a/test/NnxMapping.py b/test/NnxMapping.py index 221d35e..4cdbaf0 100644 --- a/test/NnxMapping.py +++ b/test/NnxMapping.py @@ -1,4 +1,5 @@ -from typing import List, Literal, get_args +from enum import Enum +from typing import Dict, NamedTuple, Type from Ne16TestConf import Ne16TestConf from Ne16Weight import Ne16Weight @@ -6,26 +7,21 @@ from NeurekaWeight import NeurekaWeight from NnxTestClasses import NnxTestConf, NnxWeight -NnxName = Literal["ne16", "neureka"] +class NnxName(Enum): + ne16 = "ne16" + neureka = "neureka" -def valid_nnx_names() -> List[str]: - return get_args(NnxName) + def __str__(self): + return self.value -def is_valid_nnx_name(name: str) -> bool: - return name in valid_nnx_names() +class NnxAcceleratorClasses(NamedTuple): + testConfCls: Type[NnxTestConf] + weightCls: Type[NnxWeight] -def NnxWeightClsFromName(name: NnxName) -> NnxWeight: - if name == "ne16": - return Ne16Weight - elif name == "neureka": - return NeurekaWeight - - -def NnxTestConfClsFromName(name: NnxName) -> NnxTestConf: - if name == "ne16": - return Ne16TestConf - elif name == "neureka": - return NeurekaTestConf +NnxMapping: Dict[NnxName, NnxAcceleratorClasses] = { + NnxName.ne16: NnxAcceleratorClasses(Ne16TestConf, Ne16Weight), + NnxName.neureka: NnxAcceleratorClasses(NeurekaTestConf, NeurekaWeight), +} diff --git a/test/TestClasses.py b/test/TestClasses.py index fb84009..e7e7500 100644 --- a/test/TestClasses.py +++ b/test/TestClasses.py @@ -97,30 +97,14 @@ def ctype(self) -> Optional[str]: def __str__(self) -> str: return self.name - def __eq__(self, __value: object) -> bool: - if isinstance(__value, str): - return self.name == __value - elif isinstance(__value, IntegerType): - return self.name == __value.name + def __eq__(self, other: object) -> bool: + if isinstance(other, str): + return self.name == other + elif isinstance(other, IntegerType): + return self.name == other.name else: return False @model_serializer def ser_model(self) -> str: return self.name - - if TYPE_CHECKING: - # Ensure type checkers see the correct return type - def model_dump( - self, - *, - mode: Literal["json", "python"] | str = "python", - include: Any = None, - exclude: Any = None, - by_alias: bool = False, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - round_trip: bool = False, - warnings: bool = True, - ) -> dict[str, Any]: ... diff --git a/test/conftest.py b/test/conftest.py index fdb4bb8..ba379ef 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -22,12 +22,7 @@ import pydantic import pytest -from NnxMapping import ( - NnxName, - NnxTestConfClsFromName, - is_valid_nnx_name, - valid_nnx_names, -) +from NnxMapping import NnxMapping, NnxName from NnxTestClasses import NnxTest, NnxTestGenerator from TestClasses import implies @@ -53,8 +48,9 @@ def pytest_addoption(parser): parser.addoption( "-A", "--accelerator", - choices=valid_nnx_names(), - default="ne16", + type=NnxName, + choices=list(NnxName), + default=NnxName.ne16, help="Choose an accelerator to test. Default: ne16", ) parser.addoption( @@ -82,10 +78,6 @@ def pytest_generate_tests(metafunc): timeout = metafunc.config.getoption("timeout") nnxName = metafunc.config.getoption("accelerator") - assert is_valid_nnx_name( - nnxName - ), f"Given accelerator {nnxName} not supported. Supported accelerators: {valid_nnx_names()}" - if recursive: tests_dirs = test_dirs test_dirs = [] @@ -94,7 +86,7 @@ def pytest_generate_tests(metafunc): # Load valid tests nnxTestNames = [] - nnxTestConfCls = NnxTestConfClsFromName(nnxName) + nnxTestConfCls = NnxMapping[nnxName].testConfCls for test_dir in test_dirs: try: test = NnxTest.load(nnxTestConfCls, test_dir) diff --git a/test/test.py b/test/test.py index ccbdad7..9b48f83 100644 --- a/test/test.py +++ b/test/test.py @@ -23,7 +23,7 @@ from pathlib import Path from typing import Dict, Optional, Tuple, Type, Union -from NnxMapping import NnxName, NnxTestConfClsFromName, NnxWeightClsFromName +from NnxMapping import NnxMapping, NnxName from NnxTestClasses import NnxTest, NnxTestConf, NnxTestHeaderGenerator, NnxWeight HORIZONTAL_LINE = "\n" + "-" * 100 + "\n" @@ -110,20 +110,20 @@ def assert_message( def test( nnxName: NnxName, - nnxTestName: Tuple[NnxTest, str], + nnxTestName: str, timeout: int, ): - nnxTestConfCls = NnxTestConfClsFromName(nnxName) + testConfCls, weightCls = NnxMapping[nnxName] + # conftest.py makes sure the test is valid and generated - nnxTest = NnxTest.load(nnxTestConfCls, nnxTestName) + nnxTest = NnxTest.load(testConfCls, nnxTestName) - nnxWeightCls = NnxWeightClsFromName(nnxName) - NnxTestHeaderGenerator(nnxWeightCls).generate(nnxTestName, nnxTest) + NnxTestHeaderGenerator(weightCls).generate(nnxTestName, nnxTest) Path("app/src/nnx_layer.c").touch() cmd = f"make -C app all run platform=gvsoc" passed, msg, stdout, stderr = execute_command( - cmd=cmd, timeout=timeout, envflags={"ACCELERATOR": nnxName} + cmd=cmd, timeout=timeout, envflags={"ACCELERATOR": str(nnxName)} ) assert passed, assert_message(msg, nnxTestName, cmd, stdout, stderr)