Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JPXKQX committed Jul 9, 2024
1 parent 9a6decc commit 4ca717b
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 100 deletions.
1 change: 0 additions & 1 deletion src/anemoi/graphs/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def __init__(
else:
self.config = config


self.path = path # Output path
self.cache = cache
self.print = print
Expand Down
60 changes: 35 additions & 25 deletions src/anemoi/graphs/edges/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,36 +270,39 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor
class TriIcosahedralEdges(BaseEdgeBuilder):
"""Computes icosahedral edges and adds them to a HeteroData graph."""

def __init__(self, src_name: str, xhops: int):
super().__init__(src_name, src_name)
def __init__(self, source_name: str, target_name: str, xhops: int):
super().__init__(source_name, target_name)

assert source_name == target_name, "TriIcosahedralEdges requires source and target nodes to be the same."
assert isinstance(xhops, int), "Number of xhops must be an integer"
assert xhops > 0, "Number of xhops must be positive"

self.xhops = xhops

def transform(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData:
def update_graph(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData:

assert (
graph[self.src_name].node_type == TriRefinedIcosahedralNodes.__name__
), "IcosahedralConnection requires TriRefinedIcosahedralNodes."
graph[self.source_name].node_type == TriRefinedIcosahedralNodes.__name__
), f"{self.__class__.__name__} requires TriRefinedIcosahedralNodes."

# TODO: Next assert doesn't exist anymore since filters were moved, make sure this is checked where appropriate
# assert filter_src is None and filter_dst is None, "InheritConnection does not support filtering with attributes."

return super().transform(graph, attrs_config)
return super().update_graph(graph, attrs_config)

def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage):
def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage):

src_nodes["nx_graph"] = icosahedral.add_edges_to_nx_graph(
src_nodes["nx_graph"],
resolutions=src_nodes["resolutions"],
source_nodes["nx_graph"] = icosahedral.add_edges_to_nx_graph(
source_nodes["nx_graph"],
resolutions=source_nodes["resolutions"],
xhops=self.xhops,
) # HeteroData refuses to accept None

adjmat = nx.to_scipy_sparse_array(src_nodes["nx_graph"], nodelist=list(src_nodes["nx_graph"]), format="coo")
graph_1_sorted = dict(zip(range(len(src_nodes["nx_graph"].nodes)), list(src_nodes["nx_graph"].nodes)))
graph_2_sorted = dict(zip(src_nodes.node_ordering, range(len(src_nodes.node_ordering))))
adjmat = nx.to_scipy_sparse_array(
source_nodes["nx_graph"], nodelist=list(source_nodes["nx_graph"]), format="coo"
)
graph_1_sorted = dict(zip(range(len(source_nodes["nx_graph"].nodes)), list(source_nodes["nx_graph"].nodes)))
graph_2_sorted = dict(zip(source_nodes.node_ordering, range(len(source_nodes.node_ordering))))
sort_func1 = np.vectorize(graph_1_sorted.get)
sort_func2 = np.vectorize(graph_2_sorted.get)
adjmat.row = sort_func2(sort_func1(adjmat.row))
Expand All @@ -310,35 +313,42 @@ def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage):
class HexagonalEdges(BaseEdgeBuilder):
"""Computes hexagonal edges and adds them to a HeteroData graph."""

def __init__(self, src_name: str, add_neighbouring_children: bool = False, depth_children: Optional[int] = 1):
super().__init__(src_name, src_name)
def __init__(
self,
source_name: str,
target_name: str,
add_neighbouring_children: bool = False,
depth_children: Optional[int] = 1,
):
super().__init__(source_name, source_name)
self.add_neighbouring_children = add_neighbouring_children

assert source_name == target_name, "TriIcosahedralEdges requires source and target nodes to be the same."
assert isinstance(depth_children, int), "Depth of children must be an integer"
assert depth_children > 0, "Depth of children must be positive"
self.depth_children = depth_children

def transform(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData:
def update_graph(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData:
assert (
graph[self.src_name].node_type == HexRefinedIcosahedralNodes.__name__
), "HexagonalEdges requires HexRefinedIcosahedralNodes."
graph[self.source_name].node_type == HexRefinedIcosahedralNodes.__name__
), f"{self.__class__.__name__} requires HexRefinedIcosahedralNodes."

# TODO: Next assert doesn't exist anymore since filters were moved, make sure this is checked where appropriate
# assert filter_src is None and filter_dst is None, "InheritConnection does not support filtering with attributes."

return super().transform(graph, attrs_config)
return super().update_graph(graph, attrs_config)

def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage):
def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage):

src_nodes["nx_graph"] = hexagonal.add_edges_to_nx_graph(
src_nodes["nx_graph"],
resolutions=src_nodes["resolutions"],
source_nodes["nx_graph"] = hexagonal.add_edges_to_nx_graph(
source_nodes["nx_graph"],
resolutions=source_nodes["resolutions"],
neighbour_children=self.add_neighbouring_children,
depth_children=self.depth_children,
)

adjmat = nx.to_scipy_sparse_array(src_nodes["nx_graph"], format="coo")
graph_2_sorted = dict(zip(src_nodes["node_ordering"], range(len(src_nodes.node_ordering))))
adjmat = nx.to_scipy_sparse_array(source_nodes["nx_graph"], format="coo")
graph_2_sorted = dict(zip(source_nodes["node_ordering"], range(len(source_nodes.node_ordering))))
sort_func = np.vectorize(graph_2_sorted.get)
adjmat.row = sort_func(adjmat.row)
adjmat.col = sort_func(adjmat.col)
Expand Down
19 changes: 8 additions & 11 deletions src/anemoi/graphs/nodes/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,31 +211,28 @@ class RefinedIcosahedralNodes(BaseNodeBuilder, ABC):
def __init__(
self,
resolution: Union[int, list[int]],
np_dtype: np.dtype = np.float32,
name: str,
) -> None:
# TODO: Discuss np_dtype
self.np_dtype = np_dtype

if isinstance(resolution, int):
self.resolutions = list(range(resolution + 1))
else:
self.resolutions = resolution

super().__init__()
super().__init__(name)

def get_coordinates(self) -> torch.Tensor:
self.nx_graph, coords_rad, self.node_ordering = self.create_nodes()
return torch.tensor(coords_rad[self.node_ordering])
return torch.tensor(coords_rad[self.node_ordering], dtype=torch.float32)

@abstractmethod
def create_nodes(self) -> np.ndarray: ...

def register_attributes(self, graph: HeteroData, name: str, config: DotDict) -> HeteroData:
graph[name]["resolutions"] = self.resolutions
graph[name]["nx_graph"] = self.nx_graph
graph[name]["node_ordering"] = self.node_ordering
def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData:
graph[self.name]["resolutions"] = self.resolutions
graph[self.name]["nx_graph"] = self.nx_graph
graph[self.name]["node_ordering"] = self.node_ordering
# TODO: AOI mask builder is not used in the current implementation.
return super().register_attributes(graph, name, config)
return super().register_attributes(graph, config)


class TriRefinedIcosahedralNodes(RefinedIcosahedralNodes):
Expand Down
20 changes: 0 additions & 20 deletions tests/edges/test_attributes.py

This file was deleted.

19 changes: 12 additions & 7 deletions tests/edges/test_hexagonal_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,18 @@
class TestTriIcosahedralEdgesInit:
def test_init(self):
"""Test TriIcosahedralEdges initialization."""
assert isinstance(HexagonalEdges("test_nodes"), HexagonalEdges)
assert isinstance(HexagonalEdges("test_nodes", "test_nodes"), HexagonalEdges)

@pytest.mark.parametrize("depth_children", [-0.5, "hello", None, -4])
def test_fail_init(self, depth_children: str):
"""Test HexagonalEdges initialization with invalid cutoff."""
with pytest.raises(AssertionError):
HexagonalEdges("test_nodes", True, depth_children)
HexagonalEdges("test_nodes", "test_nodes", True, depth_children)

def test_fail_init_diff_nodes(self):
"""Test HexagonalEdges initialization with invalid nodes."""
with pytest.raises(AssertionError):
HexagonalEdges("test_nodes", "test_nodes2", 0)


class TestTriIcosahedralEdgesTransform:
Expand All @@ -23,20 +28,20 @@ class TestTriIcosahedralEdgesTransform:
def ico_graph(self) -> HeteroData:
"""Return a HeteroData object with HexRefinedIcosahedralNodes."""
graph = HeteroData()
graph = HexRefinedIcosahedralNodes(0).transform(graph, "test_nodes", {})
graph = HexRefinedIcosahedralNodes(0, "test_nodes").update_graph(graph, {})
graph["fail_nodes"].x = [1, 2, 3]
graph["fail_nodes"].node_type = "FailNodes"
return graph

def test_transform_same_src_dst_nodes(self, ico_graph: HeteroData):
"""Test HexagonalEdges transform method."""

tri_icosahedral_edges = HexagonalEdges("test_nodes")
graph = tri_icosahedral_edges.transform(ico_graph)
tri_icosahedral_edges = HexagonalEdges("test_nodes", "test_nodes")
graph = tri_icosahedral_edges.update_graph(ico_graph)
assert ("test_nodes", "to", "test_nodes") in graph.edge_types

def test_transform_fail_nodes(self, ico_graph: HeteroData):
"""Test HexagonalEdges transform method with wrong node type."""
tri_icosahedral_edges = HexagonalEdges("fail_nodes")
tri_icosahedral_edges = HexagonalEdges("fail_nodes", "fail_nodes")
with pytest.raises(AssertionError):
tri_icosahedral_edges.transform(ico_graph)
tri_icosahedral_edges.update_graph(ico_graph)
25 changes: 15 additions & 10 deletions tests/edges/test_tri_icosahedral_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,18 @@
class TestTriIcosahedralEdgesInit:
def test_init(self):
"""Test TriIcosahedralEdges initialization."""
assert isinstance(TriIcosahedralEdges("test_nodes", 1), TriIcosahedralEdges)
assert isinstance(TriIcosahedralEdges("test_nodes", "test_nodes", 1), TriIcosahedralEdges)

@pytest.mark.parametrize("xhops", [-0.5, "hello", None, -4])
def test_fail_init(self, xhops: str):
"""Test TriIcosahedralEdges initialization with invalid cutoff."""
"""Test TriIcosahedralEdges initialization with invalid xhops."""
with pytest.raises(AssertionError):
TriIcosahedralEdges("test_nodes", xhops)
TriIcosahedralEdges("test_nodes", "test_nodes", xhops)

def test_fail_init_diff_nodes(self):
"""Test TriIcosahedralEdges initialization with invalid nodes."""
with pytest.raises(AssertionError):
TriIcosahedralEdges("test_nodes", "test_nodes2", 0)


class TestTriIcosahedralEdgesTransform:
Expand All @@ -23,20 +28,20 @@ class TestTriIcosahedralEdgesTransform:
def ico_graph(self) -> HeteroData:
"""Return a HeteroData object with TriRefinedIcosahedralNodes."""
graph = HeteroData()
graph = TriRefinedIcosahedralNodes(0).transform(graph, "test_nodes", {})
graph = TriRefinedIcosahedralNodes(1, "test_nodes").update_graph(graph, {})
graph["fail_nodes"].x = [1, 2, 3]
graph["fail_nodes"].node_type = "FailNodes"
return graph

def test_transform_same_src_dst_nodes(self, ico_graph: HeteroData):
"""Test TriIcosahedralEdges transform method."""
"""Test TriIcosahedralEdges update method."""

tri_icosahedral_edges = TriIcosahedralEdges("test_nodes", 1)
graph = tri_icosahedral_edges.transform(ico_graph)
tri_icosahedral_edges = TriIcosahedralEdges("test_nodes", "test_nodes", 1)
graph = tri_icosahedral_edges.update_graph(ico_graph)
assert ("test_nodes", "to", "test_nodes") in graph.edge_types

def test_transform_fail_nodes(self, ico_graph: HeteroData):
"""Test TriIcosahedralEdges transform method with wrong node type."""
tri_icosahedral_edges = TriIcosahedralEdges("fail_nodes", 1)
"""Test TriIcosahedralEdges update method with wrong node type."""
tri_icosahedral_edges = TriIcosahedralEdges("fail_nodes", "fail_nodes", 1)
with pytest.raises(AssertionError):
tri_icosahedral_edges.transform(ico_graph)
tri_icosahedral_edges.update_graph(ico_graph)
27 changes: 14 additions & 13 deletions tests/nodes/test_hex_refined_icosahedral.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,33 @@
import torch
from torch_geometric.data import HeteroData

from anemoi.graphs.nodes import builder
from anemoi.graphs.nodes.builder import BaseNodeBuilder
from anemoi.graphs.nodes.builder import HexRefinedIcosahedralNodes


@pytest.mark.parametrize("resolution", [0, 2])
def test_init(resolution: list[int]):
"""Test TrirefinedIcosahedralNodes initialization."""

node_builder = builder.HexRefinedIcosahedralNodes(resolution)
assert isinstance(node_builder, builder.BaseNodeBuilder)
assert isinstance(node_builder, builder.HexRefinedIcosahedralNodes)
node_builder = HexRefinedIcosahedralNodes(resolution, "test_nodes")
assert isinstance(node_builder, BaseNodeBuilder)
assert isinstance(node_builder, HexRefinedIcosahedralNodes)


def test_get_coordinates():
"""Test get_coordinates method."""
node_builder = builder.HexRefinedIcosahedralNodes(0)
node_builder = HexRefinedIcosahedralNodes(0, "test_nodes")
coords = node_builder.get_coordinates()
assert isinstance(coords, torch.Tensor)
assert coords.shape == (122, 2)


def test_transform():
"""Test transform method."""
node_builder = builder.HexRefinedIcosahedralNodes(0)
def test_update_graph():
"""Test update_graph method."""
node_builder = HexRefinedIcosahedralNodes(0, "test_nodes")
graph = HeteroData()
graph = node_builder.transform(graph, "test", {})
assert "resolutions" in graph["test"]
assert "nx_graph" in graph["test"]
assert "node_ordering" in graph["test"]
assert len(graph["test"]["node_ordering"]) == graph["test"].num_nodes
graph = node_builder.update_graph(graph, {})
assert "resolutions" in graph["test_nodes"]
assert "nx_graph" in graph["test_nodes"]
assert "node_ordering" in graph["test_nodes"]
assert len(graph["test_nodes"]["node_ordering"]) == graph["test_nodes"].num_nodes
27 changes: 14 additions & 13 deletions tests/nodes/test_tri_refined_icosahedral.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,33 @@
import torch
from torch_geometric.data import HeteroData

from anemoi.graphs.nodes import builder
from anemoi.graphs.nodes.builder import BaseNodeBuilder
from anemoi.graphs.nodes.builder import TriRefinedIcosahedralNodes


@pytest.mark.parametrize("resolution", [0, 2])
def test_init(resolution: list[int]):
"""Test TrirefinedIcosahedralNodes initialization."""

node_builder = builder.TriRefinedIcosahedralNodes(resolution)
assert isinstance(node_builder, builder.BaseNodeBuilder)
assert isinstance(node_builder, builder.TriRefinedIcosahedralNodes)
node_builder = TriRefinedIcosahedralNodes(resolution, "test_nodes")
assert isinstance(node_builder, BaseNodeBuilder)
assert isinstance(node_builder, TriRefinedIcosahedralNodes)


def test_get_coordinates():
"""Test get_coordinates method."""
node_builder = builder.TriRefinedIcosahedralNodes(2)
node_builder = TriRefinedIcosahedralNodes(2, "test_nodes")
coords = node_builder.get_coordinates()
assert isinstance(coords, torch.Tensor)
assert coords.shape == (162, 2)


def test_transform():
"""Test transform method."""
node_builder = builder.TriRefinedIcosahedralNodes(1)
def test_update_graph():
"""Test update_graph method."""
node_builder = TriRefinedIcosahedralNodes(1, "test_nodes")
graph = HeteroData()
graph = node_builder.transform(graph, "test", {})
assert "resolutions" in graph["test"]
assert "nx_graph" in graph["test"]
assert "node_ordering" in graph["test"]
assert len(graph["test"]["node_ordering"]) == graph["test"].num_nodes
graph = node_builder.update_graph(graph, {})
assert "resolutions" in graph["test_nodes"]
assert "nx_graph" in graph["test_nodes"]
assert "node_ordering" in graph["test_nodes"]
assert len(graph["test_nodes"]["node_ordering"]) == graph["test_nodes"].num_nodes

0 comments on commit 4ca717b

Please sign in to comment.