Skip to content

Commit

Permalink
transform() -> update_graph()
Browse files Browse the repository at this point in the history
Co-authored-by: <[email protected]>
Co-authored-by: <[email protected]>
  • Loading branch information
JPXKQX committed Jul 5, 2024
1 parent 12b1b3c commit 64d5178
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 10 deletions.
4 changes: 2 additions & 2 deletions src/anemoi/graphs/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ def generate_graph(self) -> HeteroData:
"""
graph = HeteroData()
for name, nodes_cfg in self.config.nodes.items():
graph = instantiate(nodes_cfg.node_builder).transform(graph, name, nodes_cfg.get("attributes", {}))
graph = instantiate(nodes_cfg.node_builder).update_graph(graph, name, nodes_cfg.get("attributes", {}))

for edges_cfg in self.config.edges:
graph = instantiate(edges_cfg.edge_builder, **edges_cfg.nodes).transform(
graph = instantiate(edges_cfg.edge_builder, **edges_cfg.nodes).update_graph(
graph, edges_cfg.get("attributes", {})
)

Expand Down
4 changes: 2 additions & 2 deletions src/anemoi/graphs/edges/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def prepare_node_data(self, graph: HeteroData) -> tuple[NodeStorage, NodeStorage
"""Prepare nodes information."""
return graph[self.source_name], graph[self.target_name]

def transform(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData:
"""Transform the graph.
def update_graph(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData:
"""Update the graph with the edges.
Parameters
----------
Expand Down
6 changes: 2 additions & 4 deletions src/anemoi/graphs/nodes/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,8 @@ def reshape_coords(self, latitudes: np.ndarray, longitudes: np.ndarray) -> torch
coords = np.deg2rad(coords)
return torch.tensor(coords, dtype=torch.float32)

def transform(self, graph: HeteroData, name: str, attr_config: Optional[DotDict] = None) -> HeteroData:
"""Transform the graph.
It includes nodes to the graph.
def update_graph(self, graph: HeteroData, name: str, attr_config: Optional[DotDict] = None) -> HeteroData:
"""Update the graph with new nodes.
Parameters
----------
Expand Down
2 changes: 1 addition & 1 deletion tests/edges/test_cutoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ def test_fail_init(cutoff_factor: str):
def test_cutoff(graph_with_nodes):
"""Test CutOffEdgeBuilder."""
builder = CutOffEdges("test_nodes", "test_nodes", 0.5)
graph = builder.transform(graph_with_nodes)
graph = builder.update_graph(graph_with_nodes)
assert ("test_nodes", "to", "test_nodes") in graph.edge_types
2 changes: 1 addition & 1 deletion tests/edges/test_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ def test_fail_init(num_nearest_neighbours: str):
def test_knn(graph_with_nodes):
"""Test KNNEdgeBuilder."""
builder = KNNEdges("test_nodes", "test_nodes", 3)
graph = builder.transform(graph_with_nodes)
graph = builder.update_graph(graph_with_nodes)
assert ("test_nodes", "to", "test_nodes") in graph.edge_types

0 comments on commit 64d5178

Please sign in to comment.