diff --git a/mlpf/pyg/PFDataset.py b/mlpf/pyg/PFDataset.py index bbea414bd..8cbd2cbcd 100644 --- a/mlpf/pyg/PFDataset.py +++ b/mlpf/pyg/PFDataset.py @@ -115,7 +115,7 @@ def __init__(self, keys_to_get, follow_batch=None, exclude_keys=None, pad_3d=Tru self.exclude_keys = exclude_keys self.keys_to_get = keys_to_get self.pad_3d = pad_3d - self.pad_power_of_two = False + self.pad_power_of_two = pad_power_of_two def __call__(self, inputs): num_samples_in_batch = len(inputs)