diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py index 3fd6550..0b09649 100644 --- a/src/anemoi/graphs/create.py +++ b/src/anemoi/graphs/create.py @@ -47,10 +47,10 @@ def generate_graph(self) -> HeteroData: """ graph = HeteroData() for name, nodes_cfg in self.config.nodes.items(): - graph = instantiate(nodes_cfg.node_builder).transform(graph, name, nodes_cfg.get("attributes", {})) + graph = instantiate(nodes_cfg.node_builder).update_graph(graph, name, nodes_cfg.get("attributes", {})) for edges_cfg in self.config.edges: - graph = instantiate(edges_cfg.edge_builder, **edges_cfg.nodes).transform( + graph = instantiate(edges_cfg.edge_builder, **edges_cfg.nodes).update_graph( graph, edges_cfg.get("attributes", {}) ) diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index 980acf2..3c926d6 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -73,8 +73,8 @@ def prepare_node_data(self, graph: HeteroData) -> tuple[NodeStorage, NodeStorage """Prepare nodes information.""" return graph[self.source_name], graph[self.target_name] - def transform(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData: - """Transform the graph. + def update_graph(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData: + """Update the graph with the edges. Parameters ---------- diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index 5a65746..11e99f6 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -77,10 +77,8 @@ def reshape_coords(self, latitudes: np.ndarray, longitudes: np.ndarray) -> torch coords = np.deg2rad(coords) return torch.tensor(coords, dtype=torch.float32) - def transform(self, graph: HeteroData, name: str, attr_config: Optional[DotDict] = None) -> HeteroData: - """Transform the graph. - - It includes nodes to the graph. + def update_graph(self, graph: HeteroData, name: str, attr_config: Optional[DotDict] = None) -> HeteroData: + """Update the graph with new nodes. Parameters ---------- diff --git a/tests/edges/test_cutoff.py b/tests/edges/test_cutoff.py index efe4ee8..a75a1df 100644 --- a/tests/edges/test_cutoff.py +++ b/tests/edges/test_cutoff.py @@ -18,5 +18,5 @@ def test_fail_init(cutoff_factor: str): def test_cutoff(graph_with_nodes): """Test CutOffEdgeBuilder.""" builder = CutOffEdges("test_nodes", "test_nodes", 0.5) - graph = builder.transform(graph_with_nodes) + graph = builder.update_graph(graph_with_nodes) assert ("test_nodes", "to", "test_nodes") in graph.edge_types diff --git a/tests/edges/test_knn.py b/tests/edges/test_knn.py index 7149d0e..f0636a5 100644 --- a/tests/edges/test_knn.py +++ b/tests/edges/test_knn.py @@ -18,5 +18,5 @@ def test_fail_init(num_nearest_neighbours: str): def test_knn(graph_with_nodes): """Test KNNEdgeBuilder.""" builder = KNNEdges("test_nodes", "test_nodes", 3) - graph = builder.transform(graph_with_nodes) + graph = builder.update_graph(graph_with_nodes) assert ("test_nodes", "to", "test_nodes") in graph.edge_types