diff --git a/mlpf/pyg/mlpf.py b/mlpf/pyg/mlpf.py index 19fa63143..f0ef187c3 100644 --- a/mlpf/pyg/mlpf.py +++ b/mlpf/pyg/mlpf.py @@ -122,10 +122,6 @@ def __init__( # elementwise DNN for node charge regression, classes (-1, 0, 1) self.nn_charge = ffn(decoding_dim + num_classes, 3, width, self.act, dropout) - def forward_batch(self, batched_events): - batch_or_mask = batched_events.batch if self.conv_type == "gravnet" else batched_events.mask - return self(batched_events.X, batch_or_mask) - def forward(self, X_features, batch_or_mask): embeddings_id, embeddings_reg = [], []