Allow computation of force/stress/virial uncertainties
frostedoyster committed Sep 23, 2024
1 parent c1a1b31 commit ab38be6
Showing 2 changed files with 128 additions and 35 deletions.
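This commit extends the LLPR uncertainty machinery from energies to forces, virials, and stresses: the aggregated last-layer features of a structure define an energy variance through the inverse covariance of the training features, and differentiating those features with respect to positions (or to a strain displacement) propagates the same covariance to the derivative targets. A minimal sketch of that idea, with editorial names and shapes (not the committed code):

    # phi: aggregated last-layer features of one structure, [n_feats]
    # inv_cov: inverse covariance of the training features, [n_feats, n_feats]
    energy_var = phi @ inv_cov @ phi  # scalar variance of the energy
    # dphi: gradients of phi w.r.t. atomic positions, [n_atoms, 3, n_feats];
    # the same quadratic form then gives per-component force variances:
    force_var = torch.einsum("iaf,fg,iag->ia", dphi, inv_cov, dphi)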
60 changes: 25 additions & 35 deletions mace/modules/models.py
@@ -1158,6 +1158,18 @@ def __init__(
self.covariance_gradients_computed = False
self.inv_covariance_computed = False

+    def aggregate_features(self, ll_feats: torch.Tensor, indices: torch.Tensor, num_graphs: int, num_atoms: torch.Tensor) -> torch.Tensor:
+        ll_feats_list = torch.split(ll_feats, self.hidden_sizes_before_readout, dim=-1)
+        ll_feats_list = [(ll_feats if is_linear else readout.non_linearity(readout.linear_1(ll_feats)))[:, :size] for ll_feats, readout, size, is_linear in zip(ll_feats_list, self.orig_model.readouts.children(), self.hidden_sizes, self.readouts_are_linear)]
+
+        # Aggregate node features
+        ll_feats_cat = torch.cat(ll_feats_list, dim=-1)
+        ll_feats_agg = scatter_sum(
+            src=ll_feats_cat, index=indices, dim=0, dim_size=num_graphs
+        )
+
+        return ll_feats_agg

def forward(
self,
data: Dict[str, torch.Tensor],
@@ -1178,11 +1190,12 @@ def forward(
raise RuntimeError("Cannot compute stress uncertainty without computing stress")

num_graphs = data["ptr"].numel() - 1

+        num_atoms = data["ptr"][1:] - data["ptr"][:-1]

output = self.orig_model(
data, (compute_force_uncertainty or compute_stress_uncertainty or compute_virial_uncertainty), compute_force, compute_virials, compute_stress, compute_displacement
)
-        ll_feats = self.aggregate_features(output["node_feats"], data["batch"], num_graphs)
+        ll_feats = self.aggregate_features(output["node_feats"], data["batch"], num_graphs, num_atoms)

energy_uncertainty = None
force_uncertainty = None
@@ -1233,34 +1246,19 @@ def forward(

return output
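With the new flags, a caller requests derivative uncertainties alongside the derivatives themselves. A hypothetical call of the extended forward (the parameter names appear in the hunks above; the output key names are assumptions mirroring the local variable names):

    output = llpr_model(
        batch_dict,
        compute_force=True,
        compute_force_uncertainty=True,
    )
    forces = output["forces"]                 # [n_atoms_batch, 3]
    force_unc = output["force_uncertainty"]   # assumed key, matching the local above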

-    def aggregate_features(
-        self,
-        ll_feats: torch.Tensor,
-        indices: torch.Tensor,
-        num_graphs: int
-    ) -> torch.Tensor:
-        # Aggregates (sums) node features over each structure
-        ll_feats_list = torch.split(ll_feats, self.hidden_sizes_before_readout, dim=-1)
-        ll_feats_list = [(ll_feats if is_linear else readout.non_linearity(readout.linear_1(ll_feats)))[:, :size] for ll_feats, readout, size, is_linear in zip(ll_feats_list, self.orig_model.readouts.children(), self.hidden_sizes, self.readouts_are_linear)]
-
-        # Aggregate node features
-        ll_feats_cat = torch.cat(ll_feats_list, dim=-1)
-        ll_feats_agg = scatter_sum(
-            src=ll_feats_cat, index=indices, dim=0, dim_size=num_graphs
-        )
-
-        return ll_feats_agg

def compute_covariance(
self,
train_loader: DataLoader,
include_energy: bool = True,
include_forces: bool = False,
include_virials: bool = False,
include_stresses: bool = False,
is_universal: bool = False,
huber_delta: float = 0.01,
) -> None:
# if not is_universal:
# raise NotImplementedError("Only universal loss models are supported for LLPR")

import tqdm
# Utility function to compute the covariance matrix for a training set.
# Note that this function computes the covariance step-wise, so it can
# be used to accumulate multiple times on subsets of the same training set
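The step-wise accumulation the comment describes allows the covariance to be built up over several passes. A sketch of that pattern, with hypothetical loader names:

    llpr_model.compute_covariance(train_loader_part1, include_forces=True)
    llpr_model.compute_covariance(train_loader_part2, include_forces=True)
    # Both calls accumulate into the same covariance buffer before it is inverted.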
@@ -1283,7 +1281,7 @@ def compute_covariance(

num_graphs = batch_dict["ptr"].numel() - 1
num_atoms = batch_dict["ptr"][1:] - batch_dict["ptr"][:-1]
-            ll_feats = self.aggregate_features(output["node_feats"], batch_dict["batch"], num_graphs)
+            ll_feats = self.aggregate_features(output["node_feats"], batch_dict["batch"], num_graphs, num_atoms)

if include_forces or include_virials or include_stresses:
f_grads, v_grads, s_grads = compute_ll_feat_gradients(
@@ -1309,38 +1307,32 @@ def compute_covariance(
)
cur_weights *= huber_mask
ll_feats = torch.mul(ll_feats, cur_weights.unsqueeze(-1)**(0.5))
-                self.covariance += (ll_feats / num_atoms).T @ (ll_feats / num_atoms)
+                self.covariance += ll_feats.T @ ll_feats

if include_forces:
# Account for the weighting of structures and targets
# Apply Huber loss mask if universal model
f_conf_weights = torch.stack([batch.weight[ii] for ii in batch.batch])
f_forces_weights = torch.stack([batch.forces_weight[ii] for ii in batch.batch])
cur_f_weights = torch.mul(f_conf_weights, f_forces_weights)
-                cur_f_weights = cur_f_weights.view(-1, 1).expand(-1, 3)
if is_universal:
huber_mask_force = get_conditional_huber_force_mask(
output["forces"],
batch["forces"],
huber_delta,
)
-                    cur_f_weights = torch.mul(cur_f_weights, huber_mask_force)
-                f_grads = torch.mul(f_grads, cur_f_weights.unsqueeze(-1)**(0.5))
+                    cur_f_weights *= huber_mask_force
+                f_grads = torch.mul(f_grads, cur_f_weights.view(-1, 1, 1)**(0.5))
f_grads = f_grads.reshape(-1, ll_feats.shape[-1])
self.covariance += f_grads.T @ f_grads

if include_virials:
# No Huber mask in the case of virials as it was not used in the
# universal model
cur_v_weights = torch.mul(batch.weight, batch.virials_weight)
-                cur_v_weights = cur_v_weights.view(-1, 1, 1).expand(-1, 3, 3)
-                v_grads = torch.mul(v_grads, cur_v_weights.unsqueeze(-1)**(0.5))
-                v_grads = v_grads.reshape(-1, ll_feats.shape[-1])
+                v_grads = torch.mul(v_grads, cur_v_weights.view(-1, 1, 1, 1)**(0.5))
+                v_grads = v_grads.reshape(-1, ll_feats.shape[-1])
self.covariance += v_grads.T @ v_grads

if include_stresses:
# Account for the weighting of structures and targets
# Apply Huber loss mask if universal model
cur_s_weights = torch.mul(batch.weight, batch.stress_weight)
-                cur_s_weights = cur_s_weights.view(-1, 1, 1).expand(-1, 3, 3)
if is_universal:
@@ -1351,8 +1343,6 @@ def compute_covariance(
)
cur_s_weights *= huber_mask_stress
s_grads = torch.mul(s_grads, cur_s_weights.view(-1, 1, 1, 1)**(0.5))
-                # The stresses seem to be normalized by n_atoms in the normal loss, but
-                # not in the universal loss. Here, we don't normalize
s_grads = s_grads.reshape(-1, ll_feats.shape[-1])
+                # The stresses seem to be normalized by n_atoms in the normal loss, but
+                # not in the universal loss.
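The hunks above only fill the covariance; before uncertainties can be evaluated, the matrix still has to be inverted (the inv_covariance_computed flag in __init__ tracks this). A hypothetical sketch of that step, with an assumed regularization strength not taken from the commit:

    reg = 1e-8  # assumed Tikhonov regularization
    n = llpr_model.covariance.shape[0]
    llpr_model.inv_covariance = torch.linalg.inv(
        llpr_model.covariance + reg * torch.eye(n, device=llpr_model.covariance.device)
    )
    llpr_model.inv_covariance_computed = True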
103 changes: 103 additions & 0 deletions mace/modules/utils.py
@@ -70,6 +70,109 @@ def compute_forces_virials(
return -1 * forces, -1 * virials, stress


@torch.jit.script
def compute_ll_feat_gradients(
ll_feats: torch.Tensor,
displacement: torch.Tensor,
batch_dict: Dict[str, torch.Tensor],
compute_force: bool = True,
compute_virials: bool = False,
compute_stress: bool = False,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:

grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(ll_feats[:, 0])]
positions = batch_dict["positions"]

if compute_force and not (compute_virials or compute_stress):
f_grads_list = []
for i in range(ll_feats.shape[-1]):
cur_grad_f = torch.autograd.grad(
[ll_feats[:, i]],
[positions],
grad_outputs=grad_outputs,
retain_graph=(i != ll_feats.shape[-1] - 1),
create_graph=False,
allow_unused=True,
)[0]
if cur_grad_f is None:
cur_grad_f = torch.zeros_like(positions)
f_grads_list.append(cur_grad_f)
f_grads = torch.stack(f_grads_list)
f_grads = f_grads.permute(1, 2, 0)
v_grads = None
s_grads = None

elif compute_force and (compute_virials or compute_stress):
cell = batch_dict["cell"]
f_grads_list = []
v_grads_list = []
s_grads_list = []
for i in range(ll_feats.shape[-1]):
cur_grad_f, cur_grad_v = torch.autograd.grad(
[ll_feats[:, i]],
[positions, displacement],
grad_outputs=grad_outputs,
retain_graph=(i != ll_feats.shape[-1] - 1),
create_graph=False,
allow_unused=True,
)
if cur_grad_f is None:
cur_grad_f = torch.zeros_like(positions)
f_grads_list.append(cur_grad_f)
if cur_grad_v is None:
cur_grad_v = torch.zeros_like(displacement)
v_grads_list.append(cur_grad_v)
f_grads = torch.stack(f_grads_list)
f_grads = f_grads.permute(1, 2, 0) # [num_atoms_batch, 3, num_ll_feats]
v_grads = torch.stack(v_grads_list)
v_grads = v_grads.permute(1, 2, 3, 0) # [num_batch, 3, 3, num_ll_feats]

if compute_stress:
cell = cell.view(-1, 3, 3)
volume = torch.einsum(
"zi,zi->z",
cell[:, 0, :],
torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1),
).unsqueeze(-1)
s_grads = v_grads / volume.view(-1, 1, 1, 1)
else:
s_grads = None

elif not compute_force and (compute_virials or compute_stress):
cell = batch_dict["cell"]
v_grads_list = []
for i in range(ll_feats.shape[-1]):
cur_grad_v = torch.autograd.grad(
[ll_feats[:, i]],
[displacement],
grad_outputs=grad_outputs,
retain_graph=(i != ll_feats.shape[-1] - 1),
create_graph=False,
allow_unused=True,
)[0]
if cur_grad_v is None:
cur_grad_v = torch.zeros_like(displacement)
v_grads_list.append(cur_grad_v)
v_grads = torch.stack(v_grads_list)
v_grads = v_grads.permute(1, 2, 3, 0) # [num_batch, 3, 3, num_ll_feats]

if compute_stress:
cell = cell.view(-1, 3, 3)
volume = torch.einsum(
"zi,zi->z",
cell[:, 0, :],
torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1),
).unsqueeze(-1)
s_grads = v_grads / volume.view(-1, 1, 1, 1)
else:
s_grads = None
f_grads = None
else:
raise RuntimeError("Unsupported configuration for computing gradients")

return f_grads, v_grads, s_grads


def get_symmetric_displacement(
positions: torch.Tensor,
unit_shifts: torch.Tensor,
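For reference, a minimal usage sketch of the new helper, mirroring its call site in compute_covariance above (the "displacement" output key is an assumption):

    f_grads, v_grads, s_grads = compute_ll_feat_gradients(
        ll_feats=ll_feats,                    # [num_graphs, num_ll_feats], built with autograd enabled
        displacement=output["displacement"],  # assumed to be returned by the wrapped model
        batch_dict=batch_dict,
        compute_force=True,
        compute_virials=False,
        compute_stress=False,
    )
    # f_grads: [num_atoms_batch, 3, num_ll_feats]; v_grads and s_grads are None here.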
