
remove 2 redundant normalizations
epwalsh committed Nov 19, 2024
1 parent a24e58c commit 206cfc9
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/olmo_core/nn/transformer/block.py
@@ -303,7 +303,7 @@ def forward(
     ) -> torch.Tensor:
         h = l2_normalize(
             torch.lerp(
-                l2_normalize(x),
+                x,
                 l2_normalize(self.attention(x, max_doc_len=max_doc_len, cu_doc_lens=cu_doc_lens)),
                 (
                     self.attn_alpha * (self.attn_alpha_init_value / self.attn_alpha_init_scaling)
@@ -313,7 +313,7 @@ def forward(
 
         return l2_normalize(
             torch.lerp(
-                l2_normalize(h),
+                h,
                 l2_normalize(self.feed_forward(h)),
                 (self.mlp_alpha * (self.mlp_alpha_init_value / self.mlp_alpha_init_scaling)).abs(),
             )
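Note on the change: the two deleted calls normalize tensors that should already be unit-normalized. `h` is literally the output of the `l2_normalize(...)` a few lines above, and `x` is the block input, which in this normalized transformer block is presumably kept on the unit hypersphere by the preceding block or the embedding normalization. The sketch below illustrates why re-normalizing is a no-op; `l2_normalize` here is a local stand-in with an assumed signature, not the olmo_core helper, and the shapes are arbitrary.

import torch


def l2_normalize(x: torch.Tensor, dim: int = -1, eps: float = 1e-12) -> torch.Tensor:
    """Scale vectors along `dim` to unit L2 norm (stand-in for the helper used in block.py)."""
    return x / x.norm(dim=dim, keepdim=True).clamp_min(eps)


x = torch.randn(4, 16)
x_unit = l2_normalize(x)              # what the block receives: already on the unit sphere
renormalized = l2_normalize(x_unit)   # the call the commit removes

# Normalizing an already-normalized tensor changes nothing beyond floating-point noise.
print(torch.allclose(x_unit, renormalized, atol=1e-6))  # True

Since `torch.lerp(input, end, weight)` computes `input + weight * (end - input)`, the interpolation now starts from the already-normalized `x` (or `h`) directly, and the outer `l2_normalize` still projects the blended result back onto the unit sphere, so the block's output is unchanged up to numerical precision while two normalization kernels are saved per block.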
