Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(graphs,plots): expand support for multi-dimensional node attributes #48

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions graphs/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ Keep it human-readable, your future self will thank you!

## [Unreleased](https://github.com/ecmwf/anemoi-graphs/compare/0.4.2...HEAD)

### Added

- feat: Support for multi-dimensional node attributes in plots (#48)

## [0.4.2 - Optimisations and lat-lon](https://github.com/ecmwf/anemoi-graphs/compare/0.4.1...0.4.2) - 2024-12-19

### Added
Expand Down
40 changes: 21 additions & 19 deletions graphs/src/anemoi/graphs/plotting/interactive_html.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import torch
from matplotlib.colors import rgb2hex
from torch_geometric.data import HeteroData

Expand Down Expand Up @@ -197,25 +198,26 @@ def plot_interactive_nodes(graph: HeteroData, nodes_name: str, out_file: Optiona
for node_attr in node_attrs:
node_attr_values = graph[nodes_name][node_attr].float().numpy()

# Skip multi-dimensional attributes. Supported only: (N, 1) or (N,) tensors
if node_attr_values.ndim > 1 and node_attr_values.shape[1] > 1:
continue

node_traces[node_attr] = go.Scattergeo(
lat=node_latitudes,
lon=node_longitudes,
name=" ".join(node_attr.split("_")).capitalize(),
mode="markers",
hoverinfo="text",
marker={
"color": node_attr_values.squeeze().tolist(),
"showscale": True,
"colorscale": "RdBu",
"colorbar": {"thickness": 15, "title": node_attr, "xanchor": "left"},
"size": 5,
},
visible=False,
)
if node_attr_values.ndim == 1:
node_attr_values = torch.unsqueeze(node_attr_values, -1)

for attr_dim in range(node_attr_values.shape[1]):
suffix = "" if node_attr_values.shape[1] == 1 else f"_[{attr_dim}]"
node_traces[node_attr + suffix] = go.Scattergeo(
lat=node_latitudes,
lon=node_longitudes,
name=" ".join((node_attr + suffix).split("_")).capitalize(),
mode="markers",
hoverinfo="text",
marker={
"color": node_attr_values[:, attr_dim].squeeze().tolist(),
"showscale": True,
"colorscale": "RdBu",
"colorbar": {"thickness": 15, "title": node_attr + suffix, "xanchor": "left"},
"size": 5,
},
visible=False,
)

# Create and add slider
slider_steps = []
Expand Down
Loading