From 51b8942e64ecc5d6372808e80329284af29bee9a Mon Sep 17 00:00:00 2001
From: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com>
Date: Sat, 13 Jan 2024 12:40:13 -0800
Subject: [PATCH] Update attention.py

---
 llmfoundry/models/layers/attention.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 9c8cc50620..ab8d4c4ccf 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -574,11 +574,11 @@ def __init__(
 
         if self.qk_ln or self.qk_gn:
             norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
-            _div = n_heads if self.qk_gn else 1
-            self.q_ln = norm_class(self.d_model // _div, device=device)
-            _div = kv_n_heads if self.qk_gn else 1
-            self.k_ln = norm_class(self.kv_n_heads * self.head_dim // _div,
-                                   device=device)
+            norm_size = self.head_dim if qk_gn else d_model
+            self.q_ln = norm_class(norm_size, device=device)
+            if qk_ln:
+                norm_size = self.head_dim * kv_n_heads
+            self.k_ln = norm_class(norm_size, device=device)
 
         if self.attn_impl == 'flash':
             self.attn_fn = flash_attn_fn
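
A minimal sketch (not part of the patch) of what the new sizing logic computes: under `qk_ln` the norms span the full query projection (`d_model`) and the full key projection (`kv_n_heads * head_dim`), while under `qk_gn` both norms span a single `head_dim`. The config values and the `qk_norm_sizes` helper below are hypothetical, chosen only to illustrate the patched branch.

```python
# Illustration only; mirrors the patched norm-size logic with assumed config values.
import torch.nn as nn

d_model, n_heads, kv_n_heads = 768, 12, 4   # hypothetical model config
head_dim = d_model // n_heads               # 64

def qk_norm_sizes(qk_ln: bool, qk_gn: bool) -> tuple[int, int]:
    """Return (q_ln size, k_ln size) following the patched branch."""
    q_size = head_dim if qk_gn else d_model                 # qk_gn: per-head; qk_ln: all query dims
    k_size = head_dim * kv_n_heads if qk_ln else head_dim   # qk_ln: all kv dims; qk_gn: per-head
    return q_size, k_size

# qk_ln -> norms over the full projections
assert qk_norm_sizes(qk_ln=True, qk_gn=False) == (768, 256)
# qk_gn -> group-style norm over a single head_dim
assert qk_norm_sizes(qk_ln=False, qk_gn=True) == (64, 64)

# e.g. the layers the module would construct in the qk_ln case
q_ln, k_ln = (nn.LayerNorm(size) for size in qk_norm_sizes(True, False))
```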