Skip to content

Commit

Permalink
add return type hints to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
liopeer committed Nov 19, 2024
1 parent 991cf0d commit c3056c7
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions tests/models/test_ModelUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,38 +36,38 @@
)
class TestMaskReduce:
@pytest.fixture()
def mask1(self):
def mask1(self) -> Tensor:
return torch.tensor([[0, 0], [1, 2]], dtype=torch.int64)

@pytest.fixture()
def mask2(self):
def mask2(self) -> Tensor:
return torch.tensor([[1, 0], [0, 1]], dtype=torch.int64)

@pytest.fixture()
def feature_map1(self):
def feature_map1(self) -> Tensor:
feature_map = torch.tensor(
[[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]]],
dtype=torch.float32,
) # (C H W) = (3, 2, 2)
return feature_map

@pytest.fixture()
def feature_map2(self):
def feature_map2(self) -> Tensor:
feature_map = torch.tensor(
[[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]],
dtype=torch.float32,
) # (C H W) = (3, 2, 2)
return feature_map

@pytest.fixture()
def expected_result1(self):
def expected_result1(self) -> Tensor:
res = torch.tensor(
[[0.5, 2.0, 3.0], [4.5, 6.0, 7.0], [8.5, 10.0, 11.0]], dtype=torch.float32
)
return res

@pytest.fixture()
def expected_result2(self):
def expected_result2(self) -> Tensor:
res = torch.tensor(
[[2.5, 2.5, 0.0], [6.5, 6.5, 0.0], [10.5, 10.5, 0.0]], dtype=torch.float32
)
Expand Down

0 comments on commit c3056c7

Please sign in to comment.