
[⚠️ removed a default argument] Make AttentionMaskConverter compatible with torch.compile(..., fullgraph=True) #27868

Merged (5 commits) on Dec 8, 2023
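For context, a minimal sketch of the behavior this PR targets, mirroring the tests added below: a module that builds an attention mask through these helpers should compile under torch.compile(..., fullgraph=True), which disallows graph breaks. The wrapper class name here is illustrative, not part of the PR.

```python
import torch
from torch import nn
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask


class CausalMaskModule(nn.Module):
    # Illustrative wrapper, analogous to the test models added in this PR.
    def forward(self, inputs_embeds):
        batch_size, seq_length, _ = inputs_embeds.shape
        return _prepare_4d_causal_attention_mask(None, (batch_size, seq_length), inputs_embeds, 0)


model = CausalMaskModule()
inputs_embeds = torch.rand(1, 3, 32)
eager_mask = model(inputs_embeds)

# fullgraph=True forbids graph breaks; the tests below check compiled and eager outputs match.
compiled_mask = torch.compile(model, fullgraph=True)(inputs_embeds)
assert torch.equal(eager_mask, compiled_mask)
```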
8 changes: 4 additions & 4 deletions src/transformers/modeling_attn_mask_utils.py
@@ -33,7 +33,7 @@ class AttentionMaskConverter:
>>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
>>> converter = AttentionMaskConverter(True)
- >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, 5)
+ >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
@@ -66,7 +66,7 @@ def to_causal_4d(
batch_size: int,
query_length: int,
key_value_length: int,
- dtype: torch.dtype = torch.float32,
+ dtype: torch.dtype,
Contributor Author:
It is fine to remove the default (it is the cause of the error): in the privately exposed methods _prepare_4d_causal_attention_mask, _prepare_4d_attention_mask and _create_4d_causal_attention_mask, dtype is always passed explicitly.
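For illustration, a small sketch (mirroring the test added in this PR) of how callers reach these helpers: dtype is always derived from inputs_embeds and supplied explicitly, so no default on the converter is needed.

```python
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

inputs_embeds = torch.rand(2, 4, 16, dtype=torch.float16)
padding_mask = torch.ones(2, 4)
padding_mask[0, :2] = 0  # first two positions of sample 0 are padding

# dtype comes from inputs_embeds, never from a converter default
bias = _prepare_4d_attention_mask(padding_mask, dtype=inputs_embeds.dtype)
print(bias.shape, bias.dtype)  # torch.Size([2, 1, 4, 4]) torch.float16
```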

Collaborator:
But to_causal_4d is part of the public API, no? Or this is breaking 😅
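A hedged before/after illustration (not part of the PR) of this concern, based on the signature change above:

```python
import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

converter = AttentionMaskConverter(True)  # is_causal=True, as in the docstring example above

# Previously this worked because dtype defaulted to torch.float32:
#     converter.to_causal_4d(1, query_length=3, key_value_length=3, device="cpu")
# After this change, callers that relied on the default must pass dtype explicitly:
mask = converter.to_causal_4d(1, query_length=3, key_value_length=3, dtype=torch.float32, device="cpu")
print(mask.shape)  # torch.Size([1, 1, 3, 3])
```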

Contributor Author:
AFAIK it is used nowhere other than in _create_4d_causal_attention_mask, _prepare_4d_attention_mask and _prepare_4d_causal_attention_mask. cc @patrickvonplaten, I don't think the mask converter API was meant to be exposed?

device: Union[torch.device, "str"] = "cpu",
) -> torch.Tensor:
"""
@@ -98,8 +98,8 @@ def to_4d(
self,
attention_mask_2d: torch.Tensor,
query_length: int,
+ dtype: torch.dtype,
key_value_length: Optional[int] = None,
- dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
@@ -215,7 +215,7 @@ def _prepare_4d_causal_attention_mask(
# 4d mask is passed through the layers
if attention_mask is not None:
attention_mask = attn_mask_converter.to_4d(
- attention_mask, input_shape[-1], key_value_length, dtype=inputs_embeds.dtype
+ attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
)
else:
attention_mask = attn_mask_converter.to_causal_4d(
74 changes: 71 additions & 3 deletions tests/test_modeling_utils.py
@@ -85,7 +85,12 @@
T5Config,
T5ForConditionalGeneration,
)
- from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+ from transformers.modeling_attn_mask_utils import (
+     AttentionMaskConverter,
+     _create_4d_causal_attention_mask,
+     _prepare_4d_attention_mask,
+     _prepare_4d_causal_attention_mask,
+ )
from transformers.modeling_utils import shard_checkpoint

# Fake pretrained models for tests
@@ -150,6 +155,32 @@ def forward(self, x):
def tie_weights(self):
self.decoder.weight = self.base.linear.weight

class Prepare4dCausalAttentionMaskModel(nn.Module):
def forward(self, inputs_embeds):
batch_size, seq_length, _ = inputs_embeds.shape
past_key_values_length = 4
attention_mask = _prepare_4d_causal_attention_mask(
None, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
return attention_mask

class Create4dCausalAttentionMaskModel(nn.Module):
def forward(self, inputs_embeds):
batch_size, seq_length, _ = inputs_embeds.shape
past_key_values_length = 4
attention_mask = _create_4d_causal_attention_mask(
(batch_size, seq_length),
dtype=inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
return attention_mask

class Prepare4dAttentionMaskModel(nn.Module):
def forward(self, mask, inputs_embeds):
attention_mask = _prepare_4d_attention_mask(mask, dtype=inputs_embeds.dtype)
return attention_mask


if is_flax_available():
from transformers import FlaxBertModel
@@ -1493,7 +1524,7 @@ def check_to_4d(self, mask_converter, q_len, kv_len, additional_mask=None, bsz=3
for bsz_idx, seq_idx in additional_mask:
mask_2d[bsz_idx, seq_idx] = 0

- mask_4d = mask_converter.to_4d(mask_2d, query_length=q_len, key_value_length=kv_len)
+ mask_4d = mask_converter.to_4d(mask_2d, query_length=q_len, key_value_length=kv_len, dtype=torch.float32)

assert mask_4d.shape == (bsz, 1, q_len, kv_len)

@@ -1529,7 +1560,9 @@ def check_to_4d(self, mask_converter, q_len, kv_len, additional_mask=None, bsz=3
self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)

def check_to_causal(self, mask_converter, q_len, kv_len, bsz=3):
- mask_4d = mask_converter.to_causal_4d(bsz, query_length=q_len, key_value_length=kv_len, device=torch_device)
+ mask_4d = mask_converter.to_causal_4d(
+     bsz, query_length=q_len, key_value_length=kv_len, device=torch_device, dtype=torch.float32
+ )

if q_len == 1 and mask_converter.sliding_window is None:
# no causal mask if q_len is 1
@@ -1621,3 +1654,38 @@ def test_causal_mask_sliding(self):
self.check_to_causal(mask_converter, q_len=3, kv_len=7)
# non auto-regressive case
self.check_to_causal(mask_converter, q_len=7, kv_len=7)

def test_torch_compile_fullgraph(self):
model = Prepare4dCausalAttentionMaskModel()

inputs_embeds = torch.rand([1, 3, 32])
res_non_compiled = model(inputs_embeds)

compiled_model = torch.compile(model, fullgraph=True)

res_compiled = compiled_model(inputs_embeds)

self.assertTrue(torch.equal(res_non_compiled, res_compiled))

model = Create4dCausalAttentionMaskModel()

inputs_embeds = torch.rand(2, 4, 16)
res_non_compiled = model(inputs_embeds)

compiled_model = torch.compile(model, fullgraph=True)
res_compiled = compiled_model(inputs_embeds)

self.assertTrue(torch.equal(res_non_compiled, res_compiled))

model = Prepare4dAttentionMaskModel()

mask = torch.ones(2, 4)
mask[0, :2] = 0
inputs_embeds = torch.rand(2, 4, 16)

res_non_compiled = model(mask, inputs_embeds)

compiled_model = torch.compile(model, fullgraph=True)
res_compiled = compiled_model(mask, inputs_embeds)

self.assertTrue(torch.equal(res_non_compiled, res_compiled))