From 558dfe4a4fd1935d9ce1472f3b84b55ddea4e158 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com>
Date: Wed, 6 Dec 2023 17:41:12 +0100
Subject: [PATCH 1/5] remove bugged torch.float32 default

---
 src/transformers/modeling_attn_mask_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py
index 9658adc55d5c96..80a607cc38f4aa 100755
--- a/src/transformers/modeling_attn_mask_utils.py
+++ b/src/transformers/modeling_attn_mask_utils.py
@@ -66,7 +66,7 @@ def to_causal_4d(
         batch_size: int,
         query_length: int,
         key_value_length: int,
-        dtype: torch.dtype = torch.float32,
+        dtype: torch.dtype,
         device: Union[torch.device, "str"] = "cpu",
     ) -> torch.Tensor:
         """
@@ -98,8 +98,8 @@ def to_4d(
         self,
         attention_mask_2d: torch.Tensor,
         query_length: int,
+        dtype: torch.dtype,
         key_value_length: Optional[int] = None,
-        dtype: torch.dtype = torch.float32,
     ) -> torch.Tensor:
         """
         Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,

From 0c0b35c7e29c74c3d7739461254c4d6513492723 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com>
Date: Wed, 6 Dec 2023 17:58:10 +0100
Subject: [PATCH 2/5] add test

---
 tests/test_modeling_utils.py | 68 +++++++++++++++++++++++++++++++++++-
 1 file changed, 67 insertions(+), 1 deletion(-)

diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index fd4297de1c1fa6..0c8cfd04ad64e6 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -85,7 +85,12 @@
         T5Config,
         T5ForConditionalGeneration,
     )
-    from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+    from transformers.modeling_attn_mask_utils import (
+        AttentionMaskConverter,
+        _create_4d_causal_attention_mask,
+        _prepare_4d_attention_mask,
+        _prepare_4d_causal_attention_mask,
+    )
     from transformers.modeling_utils import shard_checkpoint
 
     # Fake pretrained models for tests
@@ -150,6 +155,32 @@ def forward(self, x):
         def tie_weights(self):
             self.decoder.weight = self.base.linear.weight
 
+    class Prepare4dCausalAttentionMaskModel(nn.Module):
+        def forward(self, inputs_embeds):
+            batch_size, seq_length, _ = inputs_embeds.shape
+            past_key_values_length = 4
+            attention_mask = _prepare_4d_causal_attention_mask(
+                None, (batch_size, seq_length), inputs_embeds, past_key_values_length
+            )
+            return attention_mask
+
+    class Create4dCausalAttentionMaskModel(nn.Module):
+        def forward(self, inputs_embeds):
+            batch_size, seq_length, _ = inputs_embeds.shape
+            past_key_values_length = 4
+            attention_mask = _create_4d_causal_attention_mask(
+                (batch_size, seq_length),
+                dtype=inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
+            return attention_mask
+
+    class Prepare4dAttentionMaskModel(nn.Module):
+        def forward(self, mask, inputs_embeds):
+            attention_mask = _prepare_4d_attention_mask(mask, dtype=inputs_embeds.dtype)
+            return attention_mask
+
 
 if is_flax_available():
     from transformers import FlaxBertModel
@@ -1621,3 +1652,38 @@ def test_causal_mask_sliding(self):
         self.check_to_causal(mask_converter, q_len=3, kv_len=7)
         # non auto-regressive case
         self.check_to_causal(mask_converter, q_len=7, kv_len=7)
+
+    def test_torch_compile_fullgraph(self):
+        model = Prepare4dCausalAttentionMaskModel()
+
+        inputs_embeds = torch.rand([1, 3, 32])
+        res_non_compiled = model(inputs_embeds)
+
+        compiled_model = torch.compile(model, fullgraph=True)
+
+        res_compiled = compiled_model(inputs_embeds)
+
+        self.assertTrue(torch.equal(res_non_compiled, res_compiled))
+
+        model = Create4dCausalAttentionMaskModel()
+
+        inputs_embeds = torch.rand(2, 4, 16)
+        res_non_compiled = model(inputs_embeds)
+
+        compiled_model = torch.compile(model, fullgraph=True)
+        res_compiled = compiled_model(inputs_embeds)
+
+        self.assertTrue(torch.equal(res_non_compiled, res_compiled))
+
+        model = Prepare4dAttentionMaskModel()
+
+        mask = torch.ones(2, 4)
+        mask[0, :2] = 0
+        inputs_embeds = torch.rand(2, 4, 16)
+
+        res_non_compiled = model(mask, inputs_embeds)
+
+        compiled_model = torch.compile(model, fullgraph=True)
+        res_compiled = compiled_model(mask, inputs_embeds)
+
+        self.assertTrue(torch.equal(res_non_compiled, res_compiled))

From 8358faccb4e7e7240204e6a83c5aa671d1cd6795 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com>
Date: Wed, 6 Dec 2023 18:03:31 +0100
Subject: [PATCH 3/5] fix tests

---
 src/transformers/modeling_attn_mask_utils.py | 2 +-
 tests/test_modeling_utils.py                 | 6 ++++--
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py
index 80a607cc38f4aa..fbfb177815e94b 100755
--- a/src/transformers/modeling_attn_mask_utils.py
+++ b/src/transformers/modeling_attn_mask_utils.py
@@ -33,7 +33,7 @@ class AttentionMaskConverter:
     >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 
     >>> converter = AttentionMaskConverter(True)
-    >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, 5)
+    >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, 5, dtype=torch.float32)
     tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
             [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
             [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 0c8cfd04ad64e6..e1c37ec2687ed0 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -1524,7 +1524,7 @@ def check_to_4d(self, mask_converter, q_len, kv_len, additional_mask=None, bsz=3
         for bsz_idx, seq_idx in additional_mask:
             mask_2d[bsz_idx, seq_idx] = 0
 
-        mask_4d = mask_converter.to_4d(mask_2d, query_length=q_len, key_value_length=kv_len)
+        mask_4d = mask_converter.to_4d(mask_2d, query_length=q_len, key_value_length=kv_len, dtype=torch.float32)
 
         assert mask_4d.shape == (bsz, 1, q_len, kv_len)
 
@@ -1560,7 +1560,9 @@ def check_to_4d(self, mask_converter, q_len, kv_len, additional_mask=None, bsz=3
             self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
 
     def check_to_causal(self, mask_converter, q_len, kv_len, bsz=3):
-        mask_4d = mask_converter.to_causal_4d(bsz, query_length=q_len, key_value_length=kv_len, device=torch_device)
+        mask_4d = mask_converter.to_causal_4d(
+            bsz, query_length=q_len, key_value_length=kv_len, device=torch_device, dtype=torch.float32
+        )
 
         if q_len == 1 and mask_converter.sliding_window is None:
             # no causal mask if q_len is 1

From f262fb770ffeec797af5a3038acebe7bb045fdba Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com>
Date: Thu, 7 Dec 2023 10:52:05 +0100
Subject: [PATCH 4/5] fix test

---
 src/transformers/modeling_attn_mask_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py
index fbfb177815e94b..ba24ebd82f1863 100755
--- a/src/transformers/modeling_attn_mask_utils.py
+++ b/src/transformers/modeling_attn_mask_utils.py
@@ -215,7 +215,7 @@ def _prepare_4d_causal_attention_mask(
     # 4d mask is passed through the layers
     if attention_mask is not None:
         attention_mask = attn_mask_converter.to_4d(
-            attention_mask, input_shape[-1], key_value_length, dtype=inputs_embeds.dtype
+            attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
         )
     else:
         attention_mask = attn_mask_converter.to_causal_4d(

From 3c85215534f0ebb141ac38c27f95edcea7edaef9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com>
Date: Thu, 7 Dec 2023 12:25:42 +0100
Subject: [PATCH 5/5] fix doc

---
 src/transformers/modeling_attn_mask_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py
index ba24ebd82f1863..2c4a839ae19b39 100755
--- a/src/transformers/modeling_attn_mask_utils.py
+++ b/src/transformers/modeling_attn_mask_utils.py
@@ -33,7 +33,7 @@ class AttentionMaskConverter:
     >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 
     >>> converter = AttentionMaskConverter(True)
-    >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, 5, dtype=torch.float32)
+    >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
    tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
            [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
            [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
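For reference, a minimal usage sketch (not part of the patch series above) of the calling convention after these changes: dtype must now be passed explicitly to the converter methods, mirroring the updated docstring example. The float16 choice and the shape assertion below are illustrative only.

    import torch
    from transformers.modeling_attn_mask_utils import AttentionMaskConverter

    # dtype no longer has a torch.float32 default; callers pass it explicitly
    # (see the new test_torch_compile_fullgraph test in patch 2/5).
    converter = AttentionMaskConverter(True)
    mask_2d = torch.tensor([[0, 0, 0, 1, 1]])
    mask_4d = converter.to_4d(mask_2d, 5, key_value_length=5, dtype=torch.float16)
    assert mask_4d.shape == (1, 1, 5, 5)  # (batch, 1, query_length, key_value_length)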