Merge pull request ACEsuit#634 from hatemhelal/compile-tests-fix

Updating compiler support and test cases
ilyes319 authored and Thomas Warford committed Oct 26, 2024
2 parents 3867c26 + 0f19126 commit 0fbaa95
Showing 4 changed files with 58 additions and 3 deletions.
49 changes: 49 additions & 0 deletions mace/modules/loss.py
@@ -4,6 +4,8 @@
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################

from abc import abstractmethod

import torch

from mace.tools import TensorDict
@@ -381,3 +383,50 @@ def __repr__(self):
f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, "
f"forces_weight={self.forces_weight:.3f}, dipole_weight={self.dipole_weight:.3f})"
)


class RegularizedLoss(torch.nn.Module):
    def __init__(self, base_loss: torch.nn.Module, reg_weight=1.0) -> None:
        super().__init__()
        self.base_loss = base_loss
        self.reg_weight = reg_weight

    def forward(
        self, ref: Batch, pred: TensorDict, model: torch.nn.Module
    ) -> torch.Tensor:
        base_loss_value = self.base_loss(ref, pred)
        reg_term = self.reg_weight * self.compute_regularization(model)
        return base_loss_value + reg_term

    @abstractmethod
    def compute_regularization(
        self,
        model: torch.nn.Module,
    ) -> torch.Tensor:
        return None

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(base_loss={self.base_loss}, "
            f"reg_weight={self.reg_weight:.3f})"
        )


class L2PairwiseLoss(RegularizedLoss):
    def compute_regularization(self, model: torch.nn.Module) -> torch.Tensor:
        # TODO: add argument to apply regularization for some parts only. Should this function take torch.nn.Module or MACE  # pylint: disable=fixme
        # L2 penalty on the parameter registered as "fully_connected"; all other
        # parameters are left unregularized, and the loop exits on the first match.
        for name, parameters in model.named_parameters():
            if name == "fully_connected":
                return torch.sum(torch.square(parameters))

        return 0.0
        # TODO: Check if raising error here is appropriate  # pylint: disable=fixme

        # do some stuff with indexing!!

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(reg_weight={self.reg_weight:.3f}, "
            f"base_loss={self.base_loss})"
        )
2 changes: 1 addition & 1 deletion mace/tools/compile.py
@@ -36,7 +36,7 @@ def prepare(func: ModuleFactory, allow_autograd: bool = True) -> ModuleFactory:
"""
if allow_autograd:
dynamo.allow_in_graph(autograd.grad)
elif dynamo.allowed_functions.is_allowed(autograd.grad):
else:
dynamo.disallow_in_graph(autograd.grad)

@wraps(func)
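For context on the prepare change above: the removed elif only disallowed autograd.grad when dynamo.allowed_functions still reported it as allowed (a helper that appears to have been dropped in newer PyTorch releases); the new else disallows it unconditionally whenever allow_autograd is False. A hedged sketch of how prepare is used, mirroring the test further down; the import alias is an assumption:

# Sketch only; assumes the create_mace factory defined in tests/test_compile.py.
import torch
from mace.tools import compile as mace_compile

model = mace_compile.prepare(create_mace)("cuda")  # default allow_autograd=True keeps autograd.grad in-graph
compiled = torch.compile(model, mode="default")

# If autograd.grad should be kept out of the compiled graph:
model_no_grad = mace_compile.prepare(create_mace, allow_autograd=False)("cuda")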
1 change: 1 addition & 0 deletions setup.cfg
@@ -52,5 +52,6 @@ dev =
    mypy
    pre-commit
    pytest
    pytest-benchmark
    pylint
schedulefree = schedulefree
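pytest-benchmark is added to the dev extras because the compile tests below use its benchmark fixture. A minimal sketch of that fixture; the test name and timed callable are illustrative only:

# Illustrative pytest-benchmark usage, not code from this commit.
def test_forward_benchmark(benchmark):
    model = create_mace("cuda")    # assumed: factory from tests/test_compile.py
    batch = create_batch("cuda")   # assumed: helper from tests/test_compile.py
    benchmark(model, batch.to_dict())  # benchmark(callable, *args) times repeated calls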
9 changes: 7 additions & 2 deletions tests/test_compile.py
@@ -42,8 +42,10 @@ def create_mace(device: str, seed: int = 1702):
"atomic_numbers": table.zs,
"correlation": 3,
"radial_type": "bessel",
"atomic_inter_scale": 1.0,
"atomic_inter_shift": 0.0,
}
model = modules.MACE(**model_config)
model = modules.ScaleShiftMACE(**model_config)
return model.to(device)


@@ -122,11 +124,14 @@ def test_eager_benchmark(benchmark, default_dtype): # pylint: disable=W0621
@pytest.mark.parametrize("compile_mode", ["default", "reduce-overhead", "max-autotune"])
@pytest.mark.parametrize("enable_amp", [False, True], ids=["fp32", "mixed"])
def test_compile_benchmark(benchmark, compile_mode, enable_amp):
if enable_amp:
pytest.skip(reason="autocast compiler assertion aten.slice_scatter.default")

with tools.torch_tools.default_dtype(torch.float32):
batch = create_batch("cuda")
torch.compiler.reset()
model = mace_compile.prepare(create_mace)("cuda")
model = torch.compile(model, mode=compile_mode, fullgraph=True)
model = torch.compile(model, mode=compile_mode)
model = time_func(model)

with torch.autocast("cuda", enabled=enable_amp):
