Skip to content

Commit

Permalink
Include test for healpix
Browse files Browse the repository at this point in the history
  • Loading branch information
JPXKQX committed Jul 9, 2024
1 parent 1c8b158 commit 0235853
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/anemoi/graphs/nodes/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,11 @@ class HEALPixNodes(BaseNodeBuilder):
Update the graph with new nodes and attributes.
"""

def __init__(self, resolution: int) -> None:
def __init__(self, resolution: int, name: str) -> None:
"""Initialize the HEALPixNodes builder."""
self.resolution = resolution
super().__init__(name)

assert isinstance(resolution, int), "Resolution must be an integer."
assert resolution > 0, "Resolution must be positive."

Expand Down
51 changes: 51 additions & 0 deletions tests/nodes/test_healpix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import pytest
import torch
from torch_geometric.data import HeteroData

from anemoi.graphs.nodes.attributes import AreaWeights
from anemoi.graphs.nodes.attributes import UniformWeights
from anemoi.graphs.nodes.builder import BaseNodeBuilder
from anemoi.graphs.nodes.builder import HEALPixNodes


@pytest.mark.parametrize("resolution", [2, 5, 7])
def test_init(resolution: int):
"""Test HEALPixNodes initialization."""
node_builder = HEALPixNodes(resolution, "test_nodes")
assert isinstance(node_builder, BaseNodeBuilder)
assert isinstance(node_builder, HEALPixNodes)


@pytest.mark.parametrize("resolution", ["2", 4.3, -7])
def test_fail_init(resolution: int):
"""Test HEALPixNodes initialization with invalid resolution."""
with pytest.raises(AssertionError):
HEALPixNodes(resolution, "test_nodes")


@pytest.mark.parametrize("resolution", [2, 5, 7])
def test_register_nodes(resolution: int):
"""Test HEALPixNodes register correctly the nodes."""
node_builder = HEALPixNodes(resolution, "test_nodes")
graph = HeteroData()

graph = node_builder.register_nodes(graph)

assert graph["test_nodes"].x is not None
assert isinstance(graph["test_nodes"].x, torch.Tensor)
assert graph["test_nodes"].x.shape[1] == 2
assert graph["test_nodes"].node_type == "HEALPixNodes"


@pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights])
@pytest.mark.parametrize("resolution", [2, 5, 7])
def test_register_attributes(graph_with_nodes: HeteroData, attr_class, resolution: int):
"""Test HEALPixNodes register correctly the weights."""
node_builder = HEALPixNodes(resolution, "test_nodes")
config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.attributes.{attr_class.__name__}"}}

graph = node_builder.register_attributes(graph_with_nodes, config)

assert graph["test_nodes"]["test_attr"] is not None
assert isinstance(graph["test_nodes"]["test_attr"], torch.Tensor)
assert graph["test_nodes"]["test_attr"].shape[0] == graph["test_nodes"].x.shape[0]

0 comments on commit 0235853

Please sign in to comment.