Skip to content

Commit

Permalink
update KANNonLinearReadoutBlock: delete additional linear layer which…
Browse files Browse the repository at this point in the history
… shows better performance
  • Loading branch information
Hongyu-yu committed Oct 27, 2024
1 parent 7ed86ba commit 1c5b0fd
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions mace/modules/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def __init__(
self.hidden_irreps = MLP_irreps
self.num_heads = num_heads
self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.hidden_irreps)
self.linear_2 = o3.Linear(irreps_in=irreps_in, irreps_out=irrep_out)
# self.linear_2 = o3.Linear(irreps_in=irreps_in, irreps_out=irrep_out)
assert MLP_irreps.dim >= 8, "MLP_irreps at least 8!"
dim = [MLP_irreps.dim, MLP_irreps.dim // 2, MLP_irreps.dim // 4, irrep_out.dim]
self.kan = MultKAN(
Expand All @@ -138,7 +138,7 @@ def forward(
if self.num_heads > 1 and heads is not None:
x = mask_head(x, heads, self.num_heads)
x1 = self.linear_1(x)
return self.kan(x1) + self.linear_2(x) # [n_nodes, irrep_out.dim]
return self.kan(x1) # + self.linear_2(x) # [n_nodes, irrep_out.dim]

def _make_tracing_inputs(self, n: int):
return [
Expand Down

0 comments on commit 1c5b0fd

Please sign in to comment.