Skip to content

Commit

Permalink
fix copy
Browse files Browse the repository at this point in the history
  • Loading branch information
yikangshen committed May 13, 2024
1 parent 1b8ed08 commit 060af34
Showing 1 changed file with 15 additions and 25 deletions.
40 changes: 15 additions & 25 deletions src/transformers/models/jetmoe/modeling_jetmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,37 +1215,27 @@ def _update_causal_mask(
else past_seen_tokens + sequence_length + 1
)

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2:
if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
elif attention_mask.dim() == 4:
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only.
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
logger.warning_once(
"Passing a 4d mask shorter than the input length is deprecated and will be removed in "
"transformers v4.42.0"
)
offset = cache_position[0]
else:
offset = 0
mask_shape = attention_mask.shape
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
causal_mask[
: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
] = mask_slice

if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
Expand Down

0 comments on commit 060af34

Please sign in to comment.