diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 12a0e5e3f2..4eac3d2105 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -202,7 +202,6 @@ def scaled_multihead_dot_product_attention(
             window_mask,
             diagonal=-sliding_window_size,
         )
-        window_mask = window_mask[-s_q:, -s_k:]
         window_mask = ~window_mask
         attn_weight = attn_weight.masked_fill(
             window_mask.view(1, 1, s_q, s_k),
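
For context, here is a minimal runnable sketch of the sliding-window masking pattern this hunk sits in, assuming the mask is built directly at shape `(s_q, s_k)`. The tensor names mirror the hunk, but the sizes, the all-ones mask construction, and the fill value are illustrative assumptions, not the exact code in `attention.py`:

```python
import torch

# Illustrative sketch only: sliding-window attention masking in the style of
# this hunk. Shapes and the fill value are assumptions for demonstration.
s_q = s_k = 5
sliding_window_size = 2
attn_weight = torch.randn(1, 1, s_q, s_k)

# Build a band mask of shape (s_q, s_k): True where the key is within
# `sliding_window_size` positions of the query.
window_mask = torch.ones(s_q, s_k, dtype=torch.bool)
window_mask = torch.tril(window_mask, diagonal=sliding_window_size)
window_mask = torch.triu(
    window_mask,
    diagonal=-sliding_window_size,
)
# Invert so True marks positions outside the window, then mask those
# attention logits out.
window_mask = ~window_mask
attn_weight = attn_weight.masked_fill(
    window_mask.view(1, 1, s_q, s_k),
    torch.finfo(attn_weight.dtype).min,
)
```

In a sketch like this, where `window_mask` is already created at shape `(s_q, s_k)`, the slice `window_mask[-s_q:, -s_k:]` removed by the diff selects the whole mask and changes nothing; whether that matches the reasoning behind the change in the actual file is not stated in the hunk itself.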