Commit 307a7d0

Make AttentionMaskConverter compatible with `torch.compile(..., fullgraph=True)` (#27868)

* remove bugged torch.float32 default

* add test

* fix tests

* fix test

* fix doc
fxmarty authored Dec 8, 2023
1 parent 633215b commit 307a7d0
Showing 2 changed files with 75 additions and 7 deletions.
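For context (not part of the commit): after this change, `AttentionMaskConverter.to_4d` and `to_causal_4d` no longer default `dtype` to `torch.float32`, so callers must pass it explicitly. A minimal before/after sketch, mirroring the updated class docstring:

import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

converter = AttentionMaskConverter(True)  # causal converter, as in the class docstring

# Before this commit, dtype silently defaulted to torch.float32:
#     converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, 5)
# After this commit, dtype is required and key_value_length is passed by keyword:
mask_4d = converter.to_4d(
    torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32
)
print(mask_4d.shape)  # torch.Size([1, 1, 5, 5])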
8 changes: 4 additions & 4 deletions src/transformers/modeling_attn_mask_utils.py
@@ -33,7 +33,7 @@ class AttentionMaskConverter:
    >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
    >>> converter = AttentionMaskConverter(True)
-   >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, 5)
+   >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
    tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
              [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
              [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
@@ -66,7 +66,7 @@ def to_causal_4d(
        batch_size: int,
        query_length: int,
        key_value_length: int,
-       dtype: torch.dtype = torch.float32,
+       dtype: torch.dtype,
        device: Union[torch.device, "str"] = "cpu",
    ) -> torch.Tensor:
        """
@@ -98,8 +98,8 @@ def to_4d(
        self,
        attention_mask_2d: torch.Tensor,
        query_length: int,
+       dtype: torch.dtype,
        key_value_length: Optional[int] = None,
-       dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        """
        Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
@@ -215,7 +215,7 @@ def _prepare_4d_causal_attention_mask(
    # 4d mask is passed through the layers
    if attention_mask is not None:
        attention_mask = attn_mask_converter.to_4d(
-           attention_mask, input_shape[-1], key_value_length, dtype=inputs_embeds.dtype
+           attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
        )
    else:
        attention_mask = attn_mask_converter.to_causal_4d(
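As a reading aid (not part of the diff): model code typically reaches the converter through `_prepare_4d_causal_attention_mask`, which takes `dtype` from `inputs_embeds` and now forwards `key_value_length` and `dtype` by keyword. A minimal sketch, with shapes borrowed from the new test models below:

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

inputs_embeds = torch.rand(1, 3, 32)  # (batch, seq_len, hidden)
batch_size, seq_length, _ = inputs_embeds.shape

# With attention_mask=None the helper builds a purely causal mask via to_causal_4d;
# with a 2D padding mask it goes through to_4d instead (see the hunk above).
causal_mask = _prepare_4d_causal_attention_mask(
    None,                      # attention_mask
    (batch_size, seq_length),  # input_shape
    inputs_embeds,             # supplies dtype (and device) for the mask
    0,                         # past_key_values_length
)
print(causal_mask.shape)  # torch.Size([1, 1, 3, 3])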
74 changes: 71 additions & 3 deletions tests/test_modeling_utils.py
@@ -85,7 +85,12 @@
    T5Config,
    T5ForConditionalGeneration,
)
-from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_attn_mask_utils import (
+    AttentionMaskConverter,
+    _create_4d_causal_attention_mask,
+    _prepare_4d_attention_mask,
+    _prepare_4d_causal_attention_mask,
+)
from transformers.modeling_utils import shard_checkpoint

# Fake pretrained models for tests
@@ -150,6 +155,32 @@ def forward(self, x):
    def tie_weights(self):
        self.decoder.weight = self.base.linear.weight

+class Prepare4dCausalAttentionMaskModel(nn.Module):
+    def forward(self, inputs_embeds):
+        batch_size, seq_length, _ = inputs_embeds.shape
+        past_key_values_length = 4
+        attention_mask = _prepare_4d_causal_attention_mask(
+            None, (batch_size, seq_length), inputs_embeds, past_key_values_length
+        )
+        return attention_mask
+
+class Create4dCausalAttentionMaskModel(nn.Module):
+    def forward(self, inputs_embeds):
+        batch_size, seq_length, _ = inputs_embeds.shape
+        past_key_values_length = 4
+        attention_mask = _create_4d_causal_attention_mask(
+            (batch_size, seq_length),
+            dtype=inputs_embeds.dtype,
+            device=inputs_embeds.device,
+            past_key_values_length=past_key_values_length,
+        )
+        return attention_mask
+
+class Prepare4dAttentionMaskModel(nn.Module):
+    def forward(self, mask, inputs_embeds):
+        attention_mask = _prepare_4d_attention_mask(mask, dtype=inputs_embeds.dtype)
+        return attention_mask
+

if is_flax_available():
    from transformers import FlaxBertModel
@@ -1493,7 +1524,7 @@ def check_to_4d(self, mask_converter, q_len, kv_len, additional_mask=None, bsz=3
        for bsz_idx, seq_idx in additional_mask:
            mask_2d[bsz_idx, seq_idx] = 0

-       mask_4d = mask_converter.to_4d(mask_2d, query_length=q_len, key_value_length=kv_len)
+       mask_4d = mask_converter.to_4d(mask_2d, query_length=q_len, key_value_length=kv_len, dtype=torch.float32)

        assert mask_4d.shape == (bsz, 1, q_len, kv_len)

@@ -1529,7 +1560,9 @@ def check_to_4d(self, mask_converter, q_len, kv_len, additional_mask=None, bsz=3
            self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)

    def check_to_causal(self, mask_converter, q_len, kv_len, bsz=3):
-       mask_4d = mask_converter.to_causal_4d(bsz, query_length=q_len, key_value_length=kv_len, device=torch_device)
+       mask_4d = mask_converter.to_causal_4d(
+           bsz, query_length=q_len, key_value_length=kv_len, device=torch_device, dtype=torch.float32
+       )

        if q_len == 1 and mask_converter.sliding_window is None:
            # no causal mask if q_len is 1
@@ -1621,3 +1654,38 @@ def test_causal_mask_sliding(self):
        self.check_to_causal(mask_converter, q_len=3, kv_len=7)
        # non auto-regressive case
        self.check_to_causal(mask_converter, q_len=7, kv_len=7)
+
+    def test_torch_compile_fullgraph(self):
+        model = Prepare4dCausalAttentionMaskModel()
+
+        inputs_embeds = torch.rand([1, 3, 32])
+        res_non_compiled = model(inputs_embeds)
+
+        compiled_model = torch.compile(model, fullgraph=True)
+
+        res_compiled = compiled_model(inputs_embeds)
+
+        self.assertTrue(torch.equal(res_non_compiled, res_compiled))
+
+        model = Create4dCausalAttentionMaskModel()
+
+        inputs_embeds = torch.rand(2, 4, 16)
+        res_non_compiled = model(inputs_embeds)
+
+        compiled_model = torch.compile(model, fullgraph=True)
+        res_compiled = compiled_model(inputs_embeds)
+
+        self.assertTrue(torch.equal(res_non_compiled, res_compiled))
+
+        model = Prepare4dAttentionMaskModel()
+
+        mask = torch.ones(2, 4)
+        mask[0, :2] = 0
+        inputs_embeds = torch.rand(2, 4, 16)
+
+        res_non_compiled = model(mask, inputs_embeds)
+
+        compiled_model = torch.compile(model, fullgraph=True)
+        res_compiled = compiled_model(mask, inputs_embeds)
+
+        self.assertTrue(torch.equal(res_non_compiled, res_compiled))

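Distilled from the new test (not part of the commit): the property being guarded is that these mask helpers trace cleanly under `torch.compile(..., fullgraph=True)`, which errors out on any graph break, while still matching eager results. A standalone sketch, assuming PyTorch 2.x and the `Prepare4dCausalAttentionMaskModel` wrapper defined in the diff above:

import torch

model = Prepare4dCausalAttentionMaskModel()  # wrapper module from the test diff above
inputs_embeds = torch.rand(1, 3, 32)

eager_out = model(inputs_embeds)
compiled_model = torch.compile(model, fullgraph=True)  # fullgraph=True: fail on any graph break
compiled_out = compiled_model(inputs_embeds)

assert torch.equal(eager_out, compiled_out)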