From ab38be683c99578fcadf7c2840ddff8bfa7f1bed Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Wed, 1 May 2024 13:36:33 +0200 Subject: [PATCH] Allow computation of force/stress/virial uncertainties --- mace/modules/models.py | 60 ++++++++++-------------- mace/modules/utils.py | 103 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 128 insertions(+), 35 deletions(-) diff --git a/mace/modules/models.py b/mace/modules/models.py index 628f88eb..3474c6df 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -1158,6 +1158,18 @@ def __init__( self.covariance_gradients_computed = False self.inv_covariance_computed = False + def aggregate_features(self, ll_feats: torch.Tensor, indices: torch.Tensor, num_graphs: int, num_atoms: torch.Tensor) -> torch.Tensor: + ll_feats_list = torch.split(ll_feats, self.hidden_sizes_before_readout, dim=-1) + ll_feats_list = [(ll_feats if is_linear else readout.non_linearity(readout.linear_1(ll_feats)))[:, :size] for ll_feats, readout, size, is_linear in zip(ll_feats_list, self.orig_model.readouts.children(), self.hidden_sizes, self.readouts_are_linear)] + + # Aggregate node features + ll_feats_cat = torch.cat(ll_feats_list, dim=-1) + ll_feats_agg = scatter_sum( + src=ll_feats_cat, index=indices, dim=0, dim_size=num_graphs + ) + + return ll_feats_agg + def forward( self, data: Dict[str, torch.Tensor], @@ -1178,11 +1190,12 @@ def forward( raise RuntimeError("Cannot compute stress uncertainty without computing stress") num_graphs = data["ptr"].numel() - 1 - + num_atoms = data["ptr"][1:] - data["ptr"][:-1] + output = self.orig_model( data, (compute_force_uncertainty or compute_stress_uncertainty or compute_virial_uncertainty), compute_force, compute_virials, compute_stress, compute_displacement ) - ll_feats = self.aggregate_features(output["node_feats"], data["batch"], num_graphs) + ll_feats = self.aggregate_features(output["node_feats"], data["batch"], num_graphs, num_atoms) energy_uncertainty = None force_uncertainty = None @@ -1233,34 +1246,19 @@ def forward( return output - def aggregate_features( - self, - ll_feats: torch.Tensor, - indices: torch.Tensor, - num_graphs: int - ) -> torch.Tensor: - # Aggregates (sums) node features over each structure - ll_feats_list = torch.split(ll_feats, self.hidden_sizes_before_readout, dim=-1) - ll_feats_list = [(ll_feats if is_linear else readout.non_linearity(readout.linear_1(ll_feats)))[:, :size] for ll_feats, readout, size, is_linear in zip(ll_feats_list, self.orig_model.readouts.children(), self.hidden_sizes, self.readouts_are_linear)] - - # Aggregate node features - ll_feats_cat = torch.cat(ll_feats_list, dim=-1) - ll_feats_agg = scatter_sum( - src=ll_feats_cat, index=indices, dim=0, dim_size=num_graphs - ) - - return ll_feats_agg - def compute_covariance( self, train_loader: DataLoader, - include_energy: bool = True, include_forces: bool = False, include_virials: bool = False, include_stresses: bool = False, is_universal: bool = False, huber_delta: float = 0.01, ) -> None: + # if not is_universal: + # raise NotImplementedError("Only universal loss models are supported for LLPR") + + import tqdm # Utility function to compute the covariance matrix for a training set. # Note that this function computes the covariance step-wise, so it can # be used to accumulate multiple times on subsets of the same training set @@ -1283,7 +1281,7 @@ def compute_covariance( num_graphs = batch_dict["ptr"].numel() - 1 num_atoms = batch_dict["ptr"][1:] - batch_dict["ptr"][:-1] - ll_feats = self.aggregate_features(output["node_feats"], batch_dict["batch"], num_graphs) + ll_feats = self.aggregate_features(output["node_feats"], batch_dict["batch"], num_graphs, num_atoms) if include_forces or include_virials or include_stresses: f_grads, v_grads, s_grads = compute_ll_feat_gradients( @@ -1309,7 +1307,7 @@ def compute_covariance( ) cur_weights *= huber_mask ll_feats = torch.mul(ll_feats, cur_weights.unsqueeze(-1)**(0.5)) - self.covariance += (ll_feats / num_atoms).T @ (ll_feats / num_atoms) + self.covariance += ll_feats.T @ ll_feats if include_forces: # Account for the weighting of structures and targets @@ -1317,30 +1315,24 @@ def compute_covariance( f_conf_weights = torch.stack([batch.weight[ii] for ii in batch.batch]) f_forces_weights = torch.stack([batch.forces_weight[ii] for ii in batch.batch]) cur_f_weights = torch.mul(f_conf_weights, f_forces_weights) - cur_f_weights = cur_f_weights.view(-1, 1).expand(-1, 3) if is_universal: huber_mask_force = get_conditional_huber_force_mask( output["forces"], batch["forces"], huber_delta, ) - cur_f_weights = torch.mul(cur_f_weights, huber_mask_force) - f_grads = torch.mul(f_grads, cur_f_weights.unsqueeze(-1)**(0.5)) + cur_f_weights *= huber_mask_force + f_grads = torch.mul(f_grads, cur_f_weights.view(-1, 1, 1)**(0.5)) f_grads = f_grads.reshape(-1, ll_feats.shape[-1]) self.covariance += f_grads.T @ f_grads if include_virials: - # No Huber mask in the case of virials as it was not used in the - # universal model cur_v_weights = torch.mul(batch.weight, batch.virials_weight) - cur_v_weights = cur_v_weights.view(-1, 1, 1).expand(-1, 3, 3) - v_grads = torch.mul(v_grads, cur_v_weights.unsqueeze(-1)**(0.5)) - v_grads = v_grads.reshape(-1, ll_feats.shape[-1]) + v_grads = torch.mul(v_grads, cur_v_weights.view(-1, 1, 1, 1)**(0.5)) + v_grads = v_grads.reshape(-1, ll_feats.shape[-1]) self.covariance += v_grads.T @ v_grads if include_stresses: - # Account for the weighting of structures and targets - # Apply Huber loss mask if universal model cur_s_weights = torch.mul(batch.weight, batch.stress_weight) cur_s_weights = cur_s_weights.view(-1, 1, 1).expand(-1, 3, 3) if is_universal: @@ -1351,8 +1343,6 @@ def compute_covariance( ) cur_s_weights *= huber_mask_stress s_grads = torch.mul(s_grads, cur_s_weights.view(-1, 1, 1, 1)**(0.5)) - # The stresses seem to be normalized by n_atoms in the normal loss, but - # not in the universal loss. Here, we don't normalize s_grads = s_grads.reshape(-1, ll_feats.shape[-1]) # The stresses seem to be normalized by n_atoms in the normal loss, but # not in the universal loss. diff --git a/mace/modules/utils.py b/mace/modules/utils.py index 14f6f8b2..fd9e7125 100644 --- a/mace/modules/utils.py +++ b/mace/modules/utils.py @@ -70,6 +70,109 @@ def compute_forces_virials( return -1 * forces, -1 * virials, stress +@torch.jit.script +def compute_ll_feat_gradients( + ll_feats: torch.Tensor, + displacement: torch.Tensor, + batch_dict: Dict[str, torch.Tensor], + compute_force: bool = True, + compute_virials: bool = False, + compute_stress: bool = False, +) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + + grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(ll_feats[:, 0])] + positions = batch_dict["positions"] + + if compute_force and not (compute_virials or compute_stress): + f_grads_list = [] + for i in range(ll_feats.shape[-1]): + cur_grad_f = torch.autograd.grad( + [ll_feats[:, i]], + [positions], + grad_outputs=grad_outputs, + retain_graph=(i != ll_feats.shape[-1] - 1), + create_graph=False, + allow_unused=True, + )[0] + if cur_grad_f is None: + cur_grad_f = torch.zeros_like(positions) + f_grads_list.append(cur_grad_f) + f_grads = torch.stack(f_grads_list) + f_grads = f_grads.permute(1, 2, 0) + v_grads = None + s_grads = None + + elif compute_force and (compute_virials or compute_stress): + cell = batch_dict["cell"] + f_grads_list = [] + v_grads_list = [] + s_grads_list = [] + for i in range(ll_feats.shape[-1]): + cur_grad_f, cur_grad_v = torch.autograd.grad( + [ll_feats[:, i]], + [positions, displacement], + grad_outputs=grad_outputs, + retain_graph=(i != ll_feats.shape[-1] - 1), + create_graph=False, + allow_unused=True, + ) + if cur_grad_f is None: + cur_grad_f = torch.zeros_like(positions) + f_grads_list.append(cur_grad_f) + if cur_grad_v is None: + cur_grad_v = torch.zeros_like(displacement) + v_grads_list.append(cur_grad_v) + f_grads = torch.stack(f_grads_list) + f_grads = f_grads.permute(1, 2, 0) # [num_atoms_batch, 3, num_ll_feats] + v_grads = torch.stack(v_grads_list) + v_grads = v_grads.permute(1, 2, 3, 0) # [num_batch, 3, 3, num_ll_feats] + + if compute_stress: + cell = cell.view(-1, 3, 3) + volume = torch.einsum( + "zi,zi->z", + cell[:, 0, :], + torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1), + ).unsqueeze(-1) + s_grads = v_grads / volume.view(-1, 1, 1, 1) + else: + s_grads = None + + elif not compute_force and (compute_virials or compute_stress): + cell = batch_dict["cell"] + v_grads_list = [] + for i in range(ll_feats.shape[-1]): + cur_grad_v = torch.autograd.grad( + [ll_feats[:, i]], + [displacement], + grad_outputs=grad_outputs, + retain_graph=(i != ll_feats.shape[-1] - 1), + create_graph=False, + allow_unused=True, + )[0] + if cur_grad_v is None: + cur_grad_v = torch.zeros_like(displacement) + v_grads_list.append(cur_grad_v) + v_grads = torch.stack(v_grads_list) + v_grads = v_grads.permute(1, 2, 3, 0) # [num_batch, 3, 3, num_ll_feats] + + if compute_stress: + cell = cell.view(-1, 3, 3) + volume = torch.einsum( + "zi,zi->z", + cell[:, 0, :], + torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1), + ).unsqueeze(-1) + s_grads = v_grads / volume.view(-1, 1, 1, 1) + else: + s_grads = None + f_grads = None + else: + raise RuntimeError("Unsupported configuration for computing gradients") + + return f_grads, v_grads, s_grads + + def get_symmetric_displacement( positions: torch.Tensor, unit_shifts: torch.Tensor,