Skip to content

Commit

Permalink
Fix pyright errors and add the NnxMapping
Browse files Browse the repository at this point in the history
  • Loading branch information
lukamac committed Dec 12, 2024
1 parent bd79135 commit 6d80c96
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 59 deletions.
32 changes: 14 additions & 18 deletions test/NnxMapping.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,27 @@
from typing import List, Literal, get_args
from enum import Enum
from typing import Dict, NamedTuple, Type

from Ne16TestConf import Ne16TestConf
from Ne16Weight import Ne16Weight
from NeurekaTestConf import NeurekaTestConf
from NeurekaWeight import NeurekaWeight
from NnxTestClasses import NnxTestConf, NnxWeight

NnxName = Literal["ne16", "neureka"]

class NnxName(Enum):
ne16 = "ne16"
neureka = "neureka"

def valid_nnx_names() -> List[str]:
return get_args(NnxName)
def __str__(self):
return self.value


def is_valid_nnx_name(name: str) -> bool:
return name in valid_nnx_names()
class NnxAcceleratorClasses(NamedTuple):
testConfCls: Type[NnxTestConf]
weightCls: Type[NnxWeight]


def NnxWeightClsFromName(name: NnxName) -> NnxWeight:
if name == "ne16":
return Ne16Weight
elif name == "neureka":
return NeurekaWeight


def NnxTestConfClsFromName(name: NnxName) -> NnxTestConf:
if name == "ne16":
return Ne16TestConf
elif name == "neureka":
return NeurekaTestConf
NnxMapping: Dict[NnxName, NnxAcceleratorClasses] = {
NnxName.ne16: NnxAcceleratorClasses(Ne16TestConf, Ne16Weight),
NnxName.neureka: NnxAcceleratorClasses(NeurekaTestConf, NeurekaWeight),
}
26 changes: 5 additions & 21 deletions test/TestClasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,30 +97,14 @@ def ctype(self) -> Optional[str]:
def __str__(self) -> str:
return self.name

def __eq__(self, __value: object) -> bool:
if isinstance(__value, str):
return self.name == __value
elif isinstance(__value, IntegerType):
return self.name == __value.name
def __eq__(self, other: object) -> bool:
if isinstance(other, str):
return self.name == other
elif isinstance(other, IntegerType):
return self.name == other.name
else:
return False

@model_serializer
def ser_model(self) -> str:
return self.name

if TYPE_CHECKING:
# Ensure type checkers see the correct return type
def model_dump(
self,
*,
mode: Literal["json", "python"] | str = "python",
include: Any = None,
exclude: Any = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool = True,
) -> dict[str, Any]: ...
18 changes: 5 additions & 13 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,7 @@
import pydantic
import pytest

from NnxMapping import (
NnxName,
NnxTestConfClsFromName,
is_valid_nnx_name,
valid_nnx_names,
)
from NnxMapping import NnxMapping, NnxName
from NnxTestClasses import NnxTest, NnxTestGenerator
from TestClasses import implies

Expand All @@ -53,8 +48,9 @@ def pytest_addoption(parser):
parser.addoption(
"-A",
"--accelerator",
choices=valid_nnx_names(),
default="ne16",
type=NnxName,
choices=list(NnxName),
default=NnxName.ne16,
help="Choose an accelerator to test. Default: ne16",
)
parser.addoption(
Expand Down Expand Up @@ -82,10 +78,6 @@ def pytest_generate_tests(metafunc):
timeout = metafunc.config.getoption("timeout")
nnxName = metafunc.config.getoption("accelerator")

assert is_valid_nnx_name(
nnxName
), f"Given accelerator {nnxName} not supported. Supported accelerators: {valid_nnx_names()}"

if recursive:
tests_dirs = test_dirs
test_dirs = []
Expand All @@ -94,7 +86,7 @@ def pytest_generate_tests(metafunc):

# Load valid tests
nnxTestNames = []
nnxTestConfCls = NnxTestConfClsFromName(nnxName)
nnxTestConfCls = NnxMapping[nnxName].testConfCls
for test_dir in test_dirs:
try:
test = NnxTest.load(nnxTestConfCls, test_dir)
Expand Down
14 changes: 7 additions & 7 deletions test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pathlib import Path
from typing import Dict, Optional, Tuple, Type, Union

from NnxMapping import NnxName, NnxTestConfClsFromName, NnxWeightClsFromName
from NnxMapping import NnxMapping, NnxName
from NnxTestClasses import NnxTest, NnxTestConf, NnxTestHeaderGenerator, NnxWeight

HORIZONTAL_LINE = "\n" + "-" * 100 + "\n"
Expand Down Expand Up @@ -110,20 +110,20 @@ def assert_message(

def test(
nnxName: NnxName,
nnxTestName: Tuple[NnxTest, str],
nnxTestName: str,
timeout: int,
):
nnxTestConfCls = NnxTestConfClsFromName(nnxName)
testConfCls, weightCls = NnxMapping[nnxName]

# conftest.py makes sure the test is valid and generated
nnxTest = NnxTest.load(nnxTestConfCls, nnxTestName)
nnxTest = NnxTest.load(testConfCls, nnxTestName)

nnxWeightCls = NnxWeightClsFromName(nnxName)
NnxTestHeaderGenerator(nnxWeightCls).generate(nnxTestName, nnxTest)
NnxTestHeaderGenerator(weightCls).generate(nnxTestName, nnxTest)

Path("app/src/nnx_layer.c").touch()
cmd = f"make -C app all run platform=gvsoc"
passed, msg, stdout, stderr = execute_command(
cmd=cmd, timeout=timeout, envflags={"ACCELERATOR": nnxName}
cmd=cmd, timeout=timeout, envflags={"ACCELERATOR": str(nnxName)}
)

assert passed, assert_message(msg, nnxTestName, cmd, stdout, stderr)
Expand Down

0 comments on commit 6d80c96

Please sign in to comment.