diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 9658adc55d5c96..2c4a839ae19b39 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -33,7 +33,7 @@ class AttentionMaskConverter: >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter >>> converter = AttentionMaskConverter(True) - >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, 5) + >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32) tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], @@ -66,7 +66,7 @@ def to_causal_4d( batch_size: int, query_length: int, key_value_length: int, - dtype: torch.dtype = torch.float32, + dtype: torch.dtype, device: Union[torch.device, "str"] = "cpu", ) -> torch.Tensor: """ @@ -98,8 +98,8 @@ def to_4d( self, attention_mask_2d: torch.Tensor, query_length: int, + dtype: torch.dtype, key_value_length: Optional[int] = None, - dtype: torch.dtype = torch.float32, ) -> torch.Tensor: """ Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, @@ -215,7 +215,7 @@ def _prepare_4d_causal_attention_mask( # 4d mask is passed through the layers if attention_mask is not None: attention_mask = attn_mask_converter.to_4d( - attention_mask, input_shape[-1], key_value_length, dtype=inputs_embeds.dtype + attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype ) else: attention_mask = attn_mask_converter.to_causal_4d( diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index fd4297de1c1fa6..e1c37ec2687ed0 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -85,7 +85,12 @@ T5Config, T5ForConditionalGeneration, ) - from transformers.modeling_attn_mask_utils import AttentionMaskConverter + from transformers.modeling_attn_mask_utils import ( + AttentionMaskConverter, + _create_4d_causal_attention_mask, + _prepare_4d_attention_mask, + _prepare_4d_causal_attention_mask, + ) from transformers.modeling_utils import shard_checkpoint # Fake pretrained models for tests @@ -150,6 +155,32 @@ def forward(self, x): def tie_weights(self): self.decoder.weight = self.base.linear.weight + class Prepare4dCausalAttentionMaskModel(nn.Module): + def forward(self, inputs_embeds): + batch_size, seq_length, _ = inputs_embeds.shape + past_key_values_length = 4 + attention_mask = _prepare_4d_causal_attention_mask( + None, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + return attention_mask + + class Create4dCausalAttentionMaskModel(nn.Module): + def forward(self, inputs_embeds): + batch_size, seq_length, _ = inputs_embeds.shape + past_key_values_length = 4 + attention_mask = _create_4d_causal_attention_mask( + (batch_size, seq_length), + dtype=inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + return attention_mask + + class Prepare4dAttentionMaskModel(nn.Module): + def forward(self, mask, inputs_embeds): + attention_mask = _prepare_4d_attention_mask(mask, dtype=inputs_embeds.dtype) + return attention_mask + if is_flax_available(): from transformers import FlaxBertModel @@ -1493,7 +1524,7 @@ def check_to_4d(self, mask_converter, q_len, kv_len, additional_mask=None, bsz=3 for bsz_idx, seq_idx in additional_mask: mask_2d[bsz_idx, seq_idx] = 0 - mask_4d = mask_converter.to_4d(mask_2d, query_length=q_len, key_value_length=kv_len) + mask_4d = mask_converter.to_4d(mask_2d, query_length=q_len, key_value_length=kv_len, dtype=torch.float32) assert mask_4d.shape == (bsz, 1, q_len, kv_len) @@ -1529,7 +1560,9 @@ def check_to_4d(self, mask_converter, q_len, kv_len, additional_mask=None, bsz=3 self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d) def check_to_causal(self, mask_converter, q_len, kv_len, bsz=3): - mask_4d = mask_converter.to_causal_4d(bsz, query_length=q_len, key_value_length=kv_len, device=torch_device) + mask_4d = mask_converter.to_causal_4d( + bsz, query_length=q_len, key_value_length=kv_len, device=torch_device, dtype=torch.float32 + ) if q_len == 1 and mask_converter.sliding_window is None: # no causal mask if q_len is 1 @@ -1621,3 +1654,38 @@ def test_causal_mask_sliding(self): self.check_to_causal(mask_converter, q_len=3, kv_len=7) # non auto-regressive case self.check_to_causal(mask_converter, q_len=7, kv_len=7) + + def test_torch_compile_fullgraph(self): + model = Prepare4dCausalAttentionMaskModel() + + inputs_embeds = torch.rand([1, 3, 32]) + res_non_compiled = model(inputs_embeds) + + compiled_model = torch.compile(model, fullgraph=True) + + res_compiled = compiled_model(inputs_embeds) + + self.assertTrue(torch.equal(res_non_compiled, res_compiled)) + + model = Create4dCausalAttentionMaskModel() + + inputs_embeds = torch.rand(2, 4, 16) + res_non_compiled = model(inputs_embeds) + + compiled_model = torch.compile(model, fullgraph=True) + res_compiled = compiled_model(inputs_embeds) + + self.assertTrue(torch.equal(res_non_compiled, res_compiled)) + + model = Prepare4dAttentionMaskModel() + + mask = torch.ones(2, 4) + mask[0, :2] = 0 + inputs_embeds = torch.rand(2, 4, 16) + + res_non_compiled = model(mask, inputs_embeds) + + compiled_model = torch.compile(model, fullgraph=True) + res_compiled = compiled_model(mask, inputs_embeds) + + self.assertTrue(torch.equal(res_non_compiled, res_compiled))