Skip to content

Commit

Permalink
test inner cutoff to the density block
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Nov 22, 2024
1 parent 49fb8b2 commit dfa93e9
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
17 changes: 14 additions & 3 deletions mace/modules/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ def forward(
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
edge_lengths: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError

Expand Down Expand Up @@ -659,6 +660,7 @@ def forward(
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
edge_lengths: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, None]:
sender = edge_index[0]
receiver = edge_index[1]
Expand Down Expand Up @@ -741,6 +743,7 @@ def forward(
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
edge_lengths: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
sender = edge_index[0]
receiver = edge_index[1]
Expand All @@ -766,6 +769,9 @@ class RealAgnosticDensityInteractionBlock(InteractionBlock):
def _setup(self) -> None:
if not hasattr(self, "cueq_config"):
self.cueq_config = None

# Inner radial basis
self.inner_cutoff = PolynomialCutoff(r_max=2.0, p=5)
# First linear
self.linear_up = Linear(
self.node_feats_irreps,
Expand Down Expand Up @@ -833,13 +839,15 @@ def forward(
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
edge_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, None]:
sender = edge_index[0]
receiver = edge_index[1]
num_nodes = node_feats.shape[0]
node_feats = self.linear_up(node_feats)
tp_weights = self.conv_tp_weights(edge_feats)
edge_density = torch.tanh(self.density_fn(edge_feats) ** 2)
inner_cutoff = self.inner_cutoff(edge_lengths)
edge_density = torch.tanh(self.density_fn(edge_feats * inner_cutoff) ** 2)
mji = self.conv_tp(
node_feats[sender], edge_attrs, tp_weights
) # [n_edges, irreps]
Expand All @@ -862,7 +870,8 @@ class RealAgnosticDensityResidualInteractionBlock(InteractionBlock):
def _setup(self) -> None:
if not hasattr(self, "cueq_config"):
self.cueq_config = None

# Cutoff
self.inner_cutoff = PolynomialCutoff(r_max=2.0, p=5)
# First linear
self.linear_up = Linear(
self.node_feats_irreps,
Expand Down Expand Up @@ -931,14 +940,16 @@ def forward(
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
edge_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
sender = edge_index[0]
receiver = edge_index[1]
num_nodes = node_feats.shape[0]
sc = self.skip_tp(node_feats, node_attrs)
node_feats = self.linear_up(node_feats)
tp_weights = self.conv_tp_weights(edge_feats)
edge_density = torch.tanh(self.density_fn(edge_feats) ** 2)
inner_cutoff = self.inner_cutoff(edge_lengths)
edge_density = torch.tanh(self.density_fn(edge_feats * inner_cutoff) ** 2)
mji = self.conv_tp(
node_feats[sender], edge_attrs, tp_weights
) # [n_edges, irreps]
Expand Down
2 changes: 2 additions & 0 deletions mace/modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ def forward(
edge_attrs=edge_attrs,
edge_feats=edge_feats,
edge_index=data["edge_index"],
edge_lengths=lengths,
)
node_feats = product(
node_feats=node_feats,
Expand Down Expand Up @@ -415,6 +416,7 @@ def forward(
edge_attrs=edge_attrs,
edge_feats=edge_feats,
edge_index=data["edge_index"],
edge_lengths=lengths,
)
node_feats = product(
node_feats=node_feats, sc=sc, node_attrs=data["node_attrs"]
Expand Down

0 comments on commit dfa93e9

Please sign in to comment.