diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 4eac3d2105..acf231558e 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -178,7 +178,7 @@ def scaled_multihead_dot_product_attention(
             min_val,
         )
 
-    if is_causal and (not q.size(2) == 1):
+    if is_causal and (not s_q == 1):
         s = max(s_q, s_k)
         causal_mask = attn_weight.new_ones(s, s, dtype=torch.float32)
         causal_mask = causal_mask.tril()
@@ -194,14 +194,21 @@
         window_mask = torch.ones((s_q, s_k),
                                  dtype=torch.bool,
                                  device=attn_weight.device)
-        window_mask = torch.tril(
-            window_mask,
-            diagonal=sliding_window_size,
-        )
-        window_mask = torch.triu(
-            window_mask,
-            diagonal=-sliding_window_size,
-        )
+        if (not s_q == 1):
+            if s_q != s_k:
+                raise ValueError(
+                    'Number of queries should be equal to the number of keys.',
+                )
+            window_mask = torch.tril(
+                window_mask,
+                diagonal=sliding_window_size,
+            )
+            window_mask = torch.triu(
+                window_mask,
+                diagonal=-sliding_window_size,
+            )
+        else:
+            window_mask[:, :-(sliding_window_size + 1)] = False
         window_mask = ~window_mask
         attn_weight = attn_weight.masked_fill(
             window_mask.view(1, 1, s_q, s_k),
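
For reference, here is a minimal standalone sketch of the sliding-window masking logic this patch introduces, not the llm-foundry implementation itself. The helper name `sliding_window_mask` is hypothetical; `s_q`, `s_k`, and `sliding_window_size` follow the names used in the diff. During prefill (`s_q > 1`) the mask is the usual band built with `tril`/`triu`; during single-token decode (`s_q == 1`) only the last `sliding_window_size + 1` keys are left visible.

```python
import torch


def sliding_window_mask(
    s_q: int,
    s_k: int,
    sliding_window_size: int,
) -> torch.Tensor:
    """Return a boolean mask where True marks positions to fill with -inf.

    Hypothetical helper mirroring the patched logic for illustration only.
    """
    keep = torch.ones((s_q, s_k), dtype=torch.bool)
    if s_q != 1:
        # Prefill: queries and keys must line up one-to-one so the band
        # (diagonal +/- sliding_window_size) is well defined.
        if s_q != s_k:
            raise ValueError(
                'Number of queries should be equal to the number of keys.',
            )
        keep = torch.tril(keep, diagonal=sliding_window_size)
        keep = torch.triu(keep, diagonal=-sliding_window_size)
    else:
        # Decode: the single query may only attend to the last
        # `sliding_window_size + 1` keys (the current token plus the window).
        keep[:, :-(sliding_window_size + 1)] = False
    return ~keep


# Example: decoding against a KV cache of 8 keys with a window of 3 keeps
# only the last 4 key positions visible.
mask = sliding_window_mask(s_q=1, s_k=8, sliding_window_size=3)
# tensor([[ True,  True,  True,  True, False, False, False, False]])
print(mask)
```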