Update attention.py
vchiley authored Jan 13, 2024
1 parent e0971ff commit 51b8942
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions llmfoundry/models/layers/attention.py
@@ -574,11 +574,11 @@ def __init__(
 
         if self.qk_ln or self.qk_gn:
             norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
-            _div = n_heads if self.qk_gn else 1
-            self.q_ln = norm_class(self.d_model // _div, device=device)
-            _div = kv_n_heads if self.qk_gn else 1
-            self.k_ln = norm_class(self.kv_n_heads * self.head_dim // _div,
-                                   device=device)
+            norm_size = self.head_dim if qk_gn else d_model
+            self.q_ln = norm_class(norm_size, device=device)
+            if qk_ln:
+                norm_size = self.head_dim * kv_n_heads
+            self.k_ln = norm_class(norm_size, device=device)
 
         if self.attn_impl == 'flash':
             self.attn_fn = flash_attn_fn
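For context, the refactor replaces the _div arithmetic with an explicit norm_size: under qk_gn both norms act per head (over head_dim), while under qk_ln the query norm spans d_model and the key norm spans head_dim * kv_n_heads. The sketch below is not part of the commit; it reproduces the size selection with plain torch.nn.LayerNorm standing in for the NORM_CLASS_REGISTRY lookup, and d_model, n_heads, and kv_n_heads are arbitrary example values.

    import torch.nn as nn

    # Example dimensions (assumptions, not taken from the commit).
    d_model, n_heads, kv_n_heads = 768, 12, 4
    head_dim = d_model // n_heads  # 64

    for qk_ln, qk_gn in [(True, False), (False, True)]:
        # Query norm: per-head groups under qk_gn, the full model dim under qk_ln.
        norm_size = head_dim if qk_gn else d_model
        q_ln = nn.LayerNorm(norm_size)
        # Key norm: per-head groups under qk_gn, all kv heads together under qk_ln.
        if qk_ln:
            norm_size = head_dim * kv_n_heads
        k_ln = nn.LayerNorm(norm_size)
        print(qk_ln, qk_gn, q_ln.normalized_shape, k_ln.normalized_shape)
        # qk_ln=True:  q_ln over 768, k_ln over 256
        # qk_gn=True:  q_ln over 64,  k_ln over 64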
