Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
JPXKQX committed Jul 9, 2024
1 parent 92e7dfe commit 77f150d
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/anemoi/graphs/inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def get_node_summary(self) -> list[list]:
name,
number(nodes.num_nodes),
", ".join(attributes),
sum(nodes[attr].shape[1] for attr in attributes),
sum(nodes[attr].shape[1] for attr in attributes if isinstance(nodes[attr], torch.Tensor)),
number(nodes.x[:, 0].min().item() / 2 / math.pi * 360),
number(nodes.x[:, 0].max().item() / 2 / math.pi * 360),
number(nodes.x[:, 1].min().item() / 2 / math.pi * 360),
Expand Down
3 changes: 2 additions & 1 deletion src/anemoi/graphs/plotting/prepare.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional

import numpy as np
import torch
from torch_geometric.data import HeteroData


Expand Down Expand Up @@ -108,7 +109,7 @@ def _get_node_attribute_dims(graph: HeteroData) -> dict[str, int]:
attr_dims = {}
for nodes in graph.node_stores:
for attr in nodes.node_attrs():
if attr == "x":
if attr == "x" and not isinstance(nodes[attr], torch.Tensor):
continue
elif attr not in attr_dims:
attr_dims[attr] = nodes[attr].shape[1]
Expand Down

0 comments on commit 77f150d

Please sign in to comment.