diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index a3326663..b0fc4c61 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -96,7 +96,7 @@ def forward( def _make_tracing_inputs(self, n: int): return [ - {"forward": (torch.randn(6, self.irreps_in.dim), torch.zeros(2))} + {"forward": (torch.randn(6, self.irreps_in.dim), None)} for _ in range(n) ] @@ -142,7 +142,7 @@ def forward( def _make_tracing_inputs(self, n: int): return [ - {"forward": (torch.randn(6, self.irreps_in.dim), torch.zeros(2))} + {"forward": (torch.randn(6, self.irreps_in.dim), None)} for _ in range(n) ] diff --git a/mace/tools/MultKAN_jit.py b/mace/tools/MultKAN_jit.py index 89270cfc..9a51c774 100644 --- a/mace/tools/MultKAN_jit.py +++ b/mace/tools/MultKAN_jit.py @@ -1,3 +1,4 @@ +# pylint: disable=all import os import random diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 0929bdab..54e2a882 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -198,7 +198,7 @@ def radial_to_transform(radial): .non_linearity._modules["acts"][0] .f if model.num_interactions.item() > 1 - and hasattr(model, "KAN_readout") == False + and hasattr(model, "KAN_readout") is False else None ), "atomic_energies": model.atomic_energies_fn.atomic_energies.cpu().numpy(), @@ -223,9 +223,8 @@ def radial_to_transform(radial): def extract_load(f: str, map_location: str = "cpu") -> torch.nn.Module: return extract_model( - torch.load(f=f, map_location=map_location), + torch.load(f=f, map_location=map_location, pickle_module=dill), map_location=map_location, - pickle_module=dill, )