
Commit

fix sliding window
zhenglongjiepheonix committed May 4, 2024
1 parent 7059235 commit 57842ab
Showing 1 changed file with 1 addition and 7 deletions.
src/transformers/models/mistral/modeling_mistral.py: 8 changes (1 addition, 7 deletions)
@@ -1094,16 +1094,10 @@ def _update_causal_mask(
         )
-
         causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
         if sequence_length != 1:
             causal_mask = torch.triu(causal_mask, diagonal=1)
-
-
         exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
-
         if self.config.sliding_window is not None and (attention_mask is None or attention_mask.dim() != 4):
             # assume signed int tensor for cache_position
-            exclude_mask |= torch.arange(target_length, device=device) <= (cache_position.reshape(-1,1) - self.config.sliding_window)
-
+            exclude_mask |= torch.arange(target_length, device=device) < (cache_position.reshape(-1,1) - self.config.sliding_window)
         causal_mask *= exclude_mask
-
         causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
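For context, below is a minimal, self-contained sketch (not the repository code) of the exclusion mask this hunk builds, using toy sizes so the effect of changing <= to < at the sliding-window boundary is easy to see. The names target_length, cache_position, and sliding_window follow the diff; the concrete values are illustrative assumptions.

# Toy reconstruction of the exclusion mask; sizes are illustrative, not from the model config.
import torch

target_length = 8
sliding_window = 3
cache_position = torch.arange(target_length)  # prefill case: one query row per position

cols = torch.arange(target_length)            # key/value positions
rows = cache_position.reshape(-1, 1)          # query positions, one per row

# True = masked out (future token, or too far in the past for the sliding window).
old_exclude = (cols > rows) | (cols <= rows - sliding_window)  # behaviour before this commit
new_exclude = (cols > rows) | (cols < rows - sliding_window)   # behaviour after this commit

# Number of key positions each query row may attend to:
print((~old_exclude).sum(dim=-1))  # tensor([1, 2, 3, 3, 3, 3, 3, 3]) -> capped at sliding_window
print((~new_exclude).sum(dim=-1))  # tensor([1, 2, 3, 4, 4, 4, 4, 4]) -> capped at sliding_window + 1

In other words, with the < comparison each query position can additionally attend to the key exactly sliding_window steps back, so the effective window grows by one token relative to the <= version.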
