-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Global Encoder-Processor-Decoder graph (#9)
* feat: Initial implementation of global graphs Co-authored-by: Mario Santa Cruz <[email protected]> Co-authored-by: Helen Theissen <[email protected]> Co-authored-by: Sara Hahner <[email protected]> Co-authored-by: Jesper Dramsch <[email protected]>
- Loading branch information
1 parent
a654d21
commit 9231b56
Showing
27 changed files
with
1,542 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -186,3 +186,6 @@ _build/ | |
*.sync | ||
_version.py | ||
*.code-workspace | ||
|
||
/config* | ||
*.pt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,10 @@ | ||
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. | ||
# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts. | ||
# This software is licensed under the terms of the Apache Licence Version 2.0 | ||
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
# In applying this licence, ECMWF does not waive the privileges and immunities | ||
# granted to it by virtue of its status as an intergovernmental organisation | ||
# nor does it submit to any jurisdiction. | ||
|
||
from ._version import __version__ | ||
|
||
from ._version import __version__ as __version__ | ||
EARTH_RADIUS = 6371.0 # km |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from anemoi.graphs.create import GraphCreator | ||
|
||
from . import Command | ||
|
||
|
||
class Create(Command): | ||
"""Create a graph.""" | ||
|
||
internal = True | ||
timestamp = True | ||
|
||
def add_arguments(self, command_parser): | ||
command_parser.add_argument( | ||
"--overwrite", | ||
action="store_true", | ||
help="Overwrite existing files. This will delete the target graph if it already exists.", | ||
) | ||
command_parser.add_argument("config", help="Configuration yaml file defining the recipe to create the graph.") | ||
command_parser.add_argument("path", help="Path to store the created graph.") | ||
|
||
def run(self, args): | ||
kwargs = vars(args) | ||
|
||
c = GraphCreator(**kwargs) | ||
c.create() | ||
|
||
|
||
command = Create |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import logging | ||
import os | ||
|
||
import torch | ||
from anemoi.utils.config import DotDict | ||
from hydra.utils import instantiate | ||
from torch_geometric.data import HeteroData | ||
|
||
LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
class GraphCreator: | ||
"""Graph creator.""" | ||
|
||
def __init__( | ||
self, | ||
path, | ||
config=None, | ||
cache=None, | ||
print=print, | ||
overwrite=False, | ||
**kwargs, | ||
): | ||
if isinstance(config, str) or isinstance(config, os.PathLike): | ||
self.config = DotDict.from_file(config) | ||
else: | ||
self.config = config | ||
|
||
self.path = path # Output path | ||
self.cache = cache | ||
self.print = print | ||
self.overwrite = overwrite | ||
|
||
def init(self): | ||
if self._path_readable() and not self.overwrite: | ||
raise Exception(f"{self.path} already exists. Use overwrite=True to overwrite.") | ||
|
||
def generate_graph(self) -> HeteroData: | ||
"""Generate the graph. | ||
It instantiates the node builders and edge builders defined in the configuration | ||
file and applies them to the graph. | ||
Returns | ||
------- | ||
HeteroData: The generated graph. | ||
""" | ||
graph = HeteroData() | ||
for name, nodes_cfg in self.config.nodes.items(): | ||
graph = instantiate(nodes_cfg.node_builder).update_graph(graph, name, nodes_cfg.get("attributes", {})) | ||
|
||
for edges_cfg in self.config.edges: | ||
graph = instantiate(edges_cfg.edge_builder, **edges_cfg.nodes).update_graph( | ||
graph, edges_cfg.get("attributes", {}) | ||
) | ||
|
||
return graph | ||
|
||
def save(self, graph: HeteroData) -> None: | ||
"""Save the graph to the output path.""" | ||
if not os.path.exists(self.path) or self.overwrite: | ||
torch.save(graph, self.path) | ||
self.print(f"Graph saved at {self.path}.") | ||
|
||
def create(self) -> HeteroData: | ||
"""Create the graph and save it to the output path.""" | ||
self.init() | ||
graph = self.generate_graph() | ||
self.save(graph) | ||
return graph | ||
|
||
def _path_readable(self) -> bool: | ||
"""Check if the output path is readable.""" | ||
import torch | ||
|
||
try: | ||
torch.load(self.path) | ||
return True | ||
except FileNotFoundError: | ||
return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .builder import CutOffEdges | ||
from .builder import KNNEdges | ||
|
||
__all__ = ["KNNEdges", "CutOffEdges"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
import logging | ||
from abc import ABC | ||
from abc import abstractmethod | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
|
||
import numpy as np | ||
import torch | ||
from torch_geometric.data import HeteroData | ||
|
||
from anemoi.graphs.edges.directional import directional_edge_features | ||
from anemoi.graphs.normalizer import NormalizerMixin | ||
from anemoi.graphs.utils import haversine_distance | ||
|
||
LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
@dataclass | ||
class BaseEdgeAttribute(ABC, NormalizerMixin): | ||
"""Base class for edge attributes.""" | ||
|
||
norm: Optional[str] = None | ||
|
||
@abstractmethod | ||
def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str, *args, **kwargs) -> np.ndarray: ... | ||
|
||
def post_process(self, values: np.ndarray) -> torch.Tensor: | ||
"""Post-process the values.""" | ||
if values.ndim == 1: | ||
values = values[:, np.newaxis] | ||
|
||
return torch.tensor(values) | ||
|
||
def compute(self, graph: HeteroData, source_name: str, target_name: str, *args, **kwargs) -> torch.Tensor: | ||
"""Compute the edge attributes.""" | ||
assert ( | ||
source_name in graph.node_types | ||
), f"Node \"{source_name}\" not found in graph. Optional nodes are {', '.join(graph.node_types)}." | ||
assert ( | ||
target_name in graph.node_types | ||
), f"Node \"{target_name}\" not found in graph. Optional nodes are {', '.join(graph.node_types)}." | ||
|
||
values = self.get_raw_values(graph, source_name, target_name, *args, **kwargs) | ||
normed_values = self.normalize(values) | ||
return self.post_process(normed_values) | ||
|
||
|
||
@dataclass | ||
class EdgeDirection(BaseEdgeAttribute): | ||
"""Compute directional features for edges. | ||
If using the rotated features, the direction of the edge is computed | ||
rotating the target nodes to the north pole. If not, it is computed | ||
as the diference in latitude and longitude between the source and | ||
target nodes. | ||
Attributes | ||
---------- | ||
norm : Optional[str] | ||
Normalization method. | ||
luse_rotated_features : bool | ||
Whether to use rotated features. | ||
Methods | ||
------- | ||
get_raw_values(graph, source_name, target_name) | ||
Compute directions between nodes connected by edges. | ||
compute(graph, source_name, target_name) | ||
Compute directional attributes. | ||
""" | ||
|
||
norm: Optional[str] = None | ||
luse_rotated_features: bool = True | ||
|
||
def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str) -> np.ndarray: | ||
"""Compute directional features for edges. | ||
Parameters | ||
---------- | ||
graph : HeteroData | ||
The graph. | ||
source_name : str | ||
The name of the source nodes. | ||
target_name : str | ||
The name of the target nodes. | ||
Returns | ||
------- | ||
np.ndarray | ||
The directional features. | ||
""" | ||
edge_index = graph[(source_name, "to", target_name)].edge_index | ||
source_coords = graph[source_name].x.numpy()[edge_index[0]].T | ||
target_coords = graph[target_name].x.numpy()[edge_index[1]].T | ||
edge_dirs = directional_edge_features(source_coords, target_coords, self.luse_rotated_features).T | ||
return edge_dirs | ||
|
||
|
||
@dataclass | ||
class EdgeLength(BaseEdgeAttribute): | ||
"""Edge length feature. | ||
Attributes | ||
---------- | ||
norm : str | ||
Normalization method. | ||
invert : bool | ||
Whether to invert the edge lengths, i.e. 1 - edge_length. | ||
Methods | ||
------- | ||
get_raw_values(graph, source_name, target_name) | ||
Compute haversine distance between nodes connected by edges. | ||
compute(graph, source_name, target_name) | ||
Compute edge lengths attributes. | ||
""" | ||
|
||
norm: str = "l1" | ||
invert: bool = True | ||
|
||
def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str) -> np.ndarray: | ||
"""Compute haversine distance (in kilometers) between nodes connected by edges. | ||
Parameters | ||
---------- | ||
graph : HeteroData | ||
The graph. | ||
source_name : str | ||
The name of the source nodes. | ||
target_name : str | ||
The name of the target nodes. | ||
Returns | ||
------- | ||
np.ndarray | ||
The edge lengths. | ||
""" | ||
edge_index = graph[(source_name, "to", target_name)].edge_index | ||
source_coords = graph[source_name].x.numpy()[edge_index[0]] | ||
target_coords = graph[target_name].x.numpy()[edge_index[1]] | ||
edge_lengths = haversine_distance(source_coords, target_coords) | ||
return edge_lengths | ||
|
||
def post_process(self, values: np.ndarray) -> torch.Tensor: | ||
"""Post-process edge lengths.""" | ||
if self.invert: | ||
values = 1 - values | ||
return super().post_process(values) |
Oops, something went wrong.