diff --git a/mace/modules/loss.py b/mace/modules/loss.py index aebae2b4..2d6522d2 100644 --- a/mace/modules/loss.py +++ b/mace/modules/loss.py @@ -282,8 +282,8 @@ def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: * conditional_huber_forces(ref, pred, huber_delta=self.huber_delta) + self.stress_weight * self.huber_loss( - configs_weight * configs_stress_weight * ref["stress"], - configs_weight * configs_stress_weight * pred["stress"], + ref["stress"], + pred["stress"], ) )