Help with Key Error when Running Examples from Readme #256

Open
hohoCode opened this issue Aug 8, 2024 · 3 comments

Comments

@hohoCode

hohoCode commented Aug 8, 2024

Following the README instructions, I tried the listed examples:
python gnn_node.py --dataset rel-f1 --task driver-position
python gnn_node.py --dataset rel-avito --task ad-ctr

My environment has fresh installations of the latest relbench (v1.1), PyG, etc.

Both runs fail with the same KeyError, "Tried to collect 'num_sampled_nodes' but did not find any occurrences of it in any node and/or edge type", as shown below:

../lib/python3.10/site-packages/relbench/modeling/utils.py:14: FutureWarning: casting datetime64[ns] values to int64 with .astype(...) is deprecated and will raise in a future version. Use .view(...) instead.
  unix_time = ser.astype("int64").values
  0%|          | 0/10 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/yye/relbench/examples/gnn_node.py", line 193, in <module>
    train_loss = train()
  File "/home/yye/relbench/examples/gnn_node.py", line 133, in train
    pred = model(
  File ".../python/torch/2/0/dist/lib/python3.10/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yye/relbench/examples/model.py", line 102, in forward
    batch.num_sampled_nodes_dict,
  File ".../python3.10/site-packages/torch_geometric/data/hetero_data.py", line 161, in __getattr__
    return self.collect(key[:-5])
  File ".../python3.10/site-packages/torch_geometric/data/hetero_data.py", line 565, in collect
    raise KeyError(f"Tried to collect '{key}' but did not find any "
KeyError: "Tried to collect 'num_sampled_nodes' but did not find any occurrences of it in any node and/or edge type"
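To see why the failure surfaces at model.py line 102, here is a hypothetical, heavily simplified sketch of the lookup in the last two traceback frames: reading batch.num_sampled_nodes_dict on a HeteroData goes through __getattr__, which strips the "_dict" suffix and calls collect("num_sampled_nodes"); collect raises KeyError when no node or edge type store carries that attribute, i.e. when the loader did not attach per-hop sample counts to the batch. MiniHeteroData is an illustrative stand-in modeled only on those two frames, not PyG's actual class.

```python
from typing import Any, Dict


class MiniHeteroData:
    """Hypothetical, heavily simplified stand-in for PyG's HeteroData.

    Mirrors only the two frames in the traceback: __getattr__ strips the
    "_dict" suffix and calls collect(), which raises KeyError when no
    node/edge-type store holds the requested attribute.
    """

    def __init__(self, stores: Dict[str, Dict[str, Any]]) -> None:
        # One attribute dict per node or edge type.
        self._stores = stores

    def collect(self, key: str) -> Dict[str, Any]:
        # Gather the attribute from every store that has it.
        mapping = {t: s[key] for t, s in self._stores.items() if key in s}
        if not mapping:
            raise KeyError(
                f"Tried to collect '{key}' but did not find any "
                f"occurrences of it in any node and/or edge type"
            )
        return mapping

    def __getattr__(self, key: str) -> Any:
        # Only called when normal attribute lookup fails.
        if key.endswith("_dict"):
            return self.collect(key[:-5])  # drop the 5-char "_dict" suffix
        raise AttributeError(key)


# A batch whose sampler attached per-hop counts works as expected ...
sampled = MiniHeteroData({"drivers": {"num_sampled_nodes": [4, 9]}})
print(sampled.num_sampled_nodes_dict)  # {'drivers': [4, 9]}

# ... while a batch without them reproduces the reported error.
unsampled = MiniHeteroData({"drivers": {"x": [0.5, 0.7]}})
try:
    _ = unsampled.num_sampled_nodes_dict
except KeyError as err:
    print(err)
```

In other words, the model code is fine as long as the batch carries num_sampled_nodes; the question is why the loader in this environment did not attach it.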
@rishabh-ranjan
Collaborator

I was unable to reproduce this error with a fresh installation. @weihua916 @rusty1s, any ideas what might be going on?

@hohoCode
Author

What is your PyG version? Mine is the latest release.
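To compare environments precisely, both sides could print their exact installed versions rather than "latest". Below is a minimal sketch using the standard library's importlib.metadata; the distribution names relbench, torch and torch_geometric are assumptions about how these packages are registered in your environment, and installed_versions is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if missing."""
    result: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # Not installed (or registered under a different distribution name).
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions(["relbench", "torch", "torch_geometric"]).items():
        print(f"{name}: {ver or 'not installed'}")
```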

@quang-truong

quang-truong commented Sep 24, 2024

Have you tried setting subgraph_type="directional" for NeighborLoader? Also, the current baseline's GNN doesn't actually use num_sampled_nodes_dict and num_sampled_edges_dict (its forward accepts them as optional arguments but ignores them), so I think you can safely remove them from the call.

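from typing import Dict, List, Optional

from torch import Tensor
from torch_geometric.typing import EdgeType, NodeType


def forward(
    self,
    x_dict: Dict[NodeType, Tensor],
    edge_index_dict: Dict[EdgeType, Tensor],
    # Accepted for API compatibility but never used below, so callers can
    # omit batch.num_sampled_nodes_dict / batch.num_sampled_edges_dict.
    num_sampled_nodes_dict: Optional[Dict[NodeType, List[int]]] = None,
    num_sampled_edges_dict: Optional[Dict[EdgeType, List[int]]] = None,
) -> Dict[NodeType, Tensor]:
    for conv, norm_dict in zip(self.convs, self.norms):
        # Heterogeneous message passing, then per-node-type norm and ReLU.
        x_dict = conv(x_dict, edge_index_dict)
        x_dict = {key: norm_dict[key](x) for key, x in x_dict.items()}
        x_dict = {key: x.relu() for key, x in x_dict.items()}

    return x_dict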
