From 57842ab9f39a10aee2e940608ea4f3675224adb7 Mon Sep 17 00:00:00 2001
From: Longjie Zheng
Date: Sat, 4 May 2024 07:23:42 +0200
Subject: [PATCH] fix sliding window

---
 src/transformers/models/mistral/modeling_mistral.py | 8 +-------
 1 file changed, 1 insertion(+), 7 deletions(-)

diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index ca96e778c4e43f..d2f711f143a02b 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -1094,16 +1094,10 @@ def _update_causal_mask(
         )
         causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
-        if sequence_length != 1:
-            causal_mask = torch.triu(causal_mask, diagonal=1)
-        exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
-        if self.config.sliding_window is not None and (attention_mask is None or attention_mask.dim() != 4):
-            # assume signed int tensor for cache_position
-            exclude_mask |= torch.arange(target_length, device=device) <= (cache_position.reshape(-1,1) - self.config.sliding_window)
-
+        exclude_mask |= torch.arange(target_length, device=device) < (cache_position.reshape(-1,1) - self.config.sliding_window)
         causal_mask *= exclude_mask
         causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
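
The masking logic this hunk touches can be exercised in isolation. Below is a minimal runnable sketch of the sliding-window causal mask construction, assuming a standalone helper; the function name `build_sliding_window_causal_mask` and the toy sizes are illustrative conveniences, not part of the patch or of the transformers API.

import torch

def build_sliding_window_causal_mask(sequence_length, target_length, cache_position, sliding_window, dtype=torch.float32):
    # Start from a mask filled with the most negative representable value;
    # positions left at this value are suppressed by the softmax.
    min_dtype = torch.finfo(dtype).min
    device = cache_position.device
    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
    # Entry (i, j) is True when key position j lies in the future of query
    # position i (a row of key indices broadcast against a column of query
    # positions taken from cache_position).
    exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    if sliding_window is not None:
        # Also exclude keys too far behind the query. The patch uses a strict
        # `<` here (replacing the previous `<=`), so the key exactly
        # `sliding_window` steps back remains visible.
        exclude_mask |= torch.arange(target_length, device=device) < (cache_position.reshape(-1, 1) - sliding_window)
    # Allowed positions (False in exclude_mask) become 0.0; excluded ones keep min_dtype.
    return causal_mask * exclude_mask

# Toy run: 5 query positions over 5 key positions with a window of 3.
mask = build_sliding_window_causal_mask(5, 5, torch.arange(5), 3)
print((mask == 0).int())  # 1 marks attendable positions: a banded lower triangle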