Skip to content

Commit

Permalink
Global Encoder-Processor-Decoder graph (#9)
Browse files Browse the repository at this point in the history
* feat: Initial implementation of global graphs

Co-authored-by: Mario Santa Cruz <[email protected]>
Co-authored-by: Helen Theissen <[email protected]>
Co-authored-by: Sara Hahner <[email protected]>
Co-authored-by: Jesper Dramsch <[email protected]>
  • Loading branch information
5 people committed Jul 8, 2024
1 parent a654d21 commit 9231b56
Show file tree
Hide file tree
Showing 27 changed files with 1,542 additions and 39 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,6 @@ _build/
*.sync
_version.py
*.code-workspace

/config*
*.pt
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ Install via `pip` with:
$ pip install anemoi-graphs
```

## Usage

Create your graph using the configuration given in the config file. The resulting graph will be saved in the given path.

```
$ anemoi-graphs create recipe.yaml my_graph.pt
```

## License

```
Expand Down
8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ dynamic = [
"version",
]
dependencies = [
"anemoi-datasets",
"anemoi-datasets[data]>=0.3.3",
"anemoi-utils>=0.3.6",
"hydra-core>=1.3",
"torch>=2.2",
"torch-geometric>=2.3.1,<2.5",
]

optional-dependencies.all = [
Expand All @@ -59,6 +63,7 @@ optional-dependencies.dev = [
"nbsphinx",
"pandoc",
"pytest",
"pytest-mock",
"requests",
"sphinx",
"sphinx-argparse",
Expand All @@ -80,6 +85,7 @@ optional-dependencies.docs = [

optional-dependencies.tests = [
"pytest",
"pytest-mock",
]

urls.Documentation = "https://anemoi-graphs.readthedocs.io/"
Expand Down
5 changes: 3 additions & 2 deletions src/anemoi/graphs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from ._version import __version__

from ._version import __version__ as __version__
EARTH_RADIUS = 6371.0 # km
28 changes: 28 additions & 0 deletions src/anemoi/graphs/commands/create.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from anemoi.graphs.create import GraphCreator

from . import Command


class Create(Command):
"""Create a graph."""

internal = True
timestamp = True

def add_arguments(self, command_parser):
command_parser.add_argument(
"--overwrite",
action="store_true",
help="Overwrite existing files. This will delete the target graph if it already exists.",
)
command_parser.add_argument("config", help="Configuration yaml file defining the recipe to create the graph.")
command_parser.add_argument("path", help="Path to store the created graph.")

def run(self, args):
kwargs = vars(args)

c = GraphCreator(**kwargs)
c.create()


command = Create
32 changes: 0 additions & 32 deletions src/anemoi/graphs/commands/hello.py

This file was deleted.

80 changes: 80 additions & 0 deletions src/anemoi/graphs/create.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import logging
import os

import torch
from anemoi.utils.config import DotDict
from hydra.utils import instantiate
from torch_geometric.data import HeteroData

LOGGER = logging.getLogger(__name__)


class GraphCreator:
"""Graph creator."""

def __init__(
self,
path,
config=None,
cache=None,
print=print,
overwrite=False,
**kwargs,
):
if isinstance(config, str) or isinstance(config, os.PathLike):
self.config = DotDict.from_file(config)
else:
self.config = config

self.path = path # Output path
self.cache = cache
self.print = print
self.overwrite = overwrite

def init(self):
if self._path_readable() and not self.overwrite:
raise Exception(f"{self.path} already exists. Use overwrite=True to overwrite.")

def generate_graph(self) -> HeteroData:
"""Generate the graph.
It instantiates the node builders and edge builders defined in the configuration
file and applies them to the graph.
Returns
-------
HeteroData: The generated graph.
"""
graph = HeteroData()
for name, nodes_cfg in self.config.nodes.items():
graph = instantiate(nodes_cfg.node_builder).update_graph(graph, name, nodes_cfg.get("attributes", {}))

for edges_cfg in self.config.edges:
graph = instantiate(edges_cfg.edge_builder, **edges_cfg.nodes).update_graph(
graph, edges_cfg.get("attributes", {})
)

return graph

def save(self, graph: HeteroData) -> None:
"""Save the graph to the output path."""
if not os.path.exists(self.path) or self.overwrite:
torch.save(graph, self.path)
self.print(f"Graph saved at {self.path}.")

def create(self) -> HeteroData:
"""Create the graph and save it to the output path."""
self.init()
graph = self.generate_graph()
self.save(graph)
return graph

def _path_readable(self) -> bool:
"""Check if the output path is readable."""
import torch

try:
torch.load(self.path)
return True
except FileNotFoundError:
return False
4 changes: 4 additions & 0 deletions src/anemoi/graphs/edges/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .builder import CutOffEdges
from .builder import KNNEdges

__all__ = ["KNNEdges", "CutOffEdges"]
148 changes: 148 additions & 0 deletions src/anemoi/graphs/edges/attributes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import logging
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
from torch_geometric.data import HeteroData

from anemoi.graphs.edges.directional import directional_edge_features
from anemoi.graphs.normalizer import NormalizerMixin
from anemoi.graphs.utils import haversine_distance

LOGGER = logging.getLogger(__name__)


@dataclass
class BaseEdgeAttribute(ABC, NormalizerMixin):
"""Base class for edge attributes."""

norm: Optional[str] = None

@abstractmethod
def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str, *args, **kwargs) -> np.ndarray: ...

def post_process(self, values: np.ndarray) -> torch.Tensor:
"""Post-process the values."""
if values.ndim == 1:
values = values[:, np.newaxis]

return torch.tensor(values)

def compute(self, graph: HeteroData, source_name: str, target_name: str, *args, **kwargs) -> torch.Tensor:
"""Compute the edge attributes."""
assert (
source_name in graph.node_types
), f"Node \"{source_name}\" not found in graph. Optional nodes are {', '.join(graph.node_types)}."
assert (
target_name in graph.node_types
), f"Node \"{target_name}\" not found in graph. Optional nodes are {', '.join(graph.node_types)}."

values = self.get_raw_values(graph, source_name, target_name, *args, **kwargs)
normed_values = self.normalize(values)
return self.post_process(normed_values)


@dataclass
class EdgeDirection(BaseEdgeAttribute):
"""Compute directional features for edges.
If using the rotated features, the direction of the edge is computed
rotating the target nodes to the north pole. If not, it is computed
as the diference in latitude and longitude between the source and
target nodes.
Attributes
----------
norm : Optional[str]
Normalization method.
luse_rotated_features : bool
Whether to use rotated features.
Methods
-------
get_raw_values(graph, source_name, target_name)
Compute directions between nodes connected by edges.
compute(graph, source_name, target_name)
Compute directional attributes.
"""

norm: Optional[str] = None
luse_rotated_features: bool = True

def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str) -> np.ndarray:
"""Compute directional features for edges.
Parameters
----------
graph : HeteroData
The graph.
source_name : str
The name of the source nodes.
target_name : str
The name of the target nodes.
Returns
-------
np.ndarray
The directional features.
"""
edge_index = graph[(source_name, "to", target_name)].edge_index
source_coords = graph[source_name].x.numpy()[edge_index[0]].T
target_coords = graph[target_name].x.numpy()[edge_index[1]].T
edge_dirs = directional_edge_features(source_coords, target_coords, self.luse_rotated_features).T
return edge_dirs


@dataclass
class EdgeLength(BaseEdgeAttribute):
"""Edge length feature.
Attributes
----------
norm : str
Normalization method.
invert : bool
Whether to invert the edge lengths, i.e. 1 - edge_length.
Methods
-------
get_raw_values(graph, source_name, target_name)
Compute haversine distance between nodes connected by edges.
compute(graph, source_name, target_name)
Compute edge lengths attributes.
"""

norm: str = "l1"
invert: bool = True

def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str) -> np.ndarray:
"""Compute haversine distance (in kilometers) between nodes connected by edges.
Parameters
----------
graph : HeteroData
The graph.
source_name : str
The name of the source nodes.
target_name : str
The name of the target nodes.
Returns
-------
np.ndarray
The edge lengths.
"""
edge_index = graph[(source_name, "to", target_name)].edge_index
source_coords = graph[source_name].x.numpy()[edge_index[0]]
target_coords = graph[target_name].x.numpy()[edge_index[1]]
edge_lengths = haversine_distance(source_coords, target_coords)
return edge_lengths

def post_process(self, values: np.ndarray) -> torch.Tensor:
"""Post-process edge lengths."""
if self.invert:
values = 1 - values
return super().post_process(values)
Loading

0 comments on commit 9231b56

Please sign in to comment.