[⚠️ removed a default argument] Make AttentionMaskConverter compatible with torch.compile(..., fullgraph=True) #27868

Merged 5 commits on Dec 8, 2023
4 changes: 2 additions & 2 deletions src/transformers/modeling_attn_mask_utils.py
@@ -66,7 +66,7 @@ def to_causal_4d(
         batch_size: int,
         query_length: int,
         key_value_length: int,
-        dtype: torch.dtype = torch.float32,
+        dtype: torch.dtype,
Contributor Author

It is fine to remove the default (it is the cause of the error): in the privately exposed methods _prepare_4d_causal_attention_mask, _prepare_4d_attention_mask and _create_4d_causal_attention_mask, dtype is always passed explicitly.
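For illustration, this is what a direct call looks like after the change (a minimal sketch, assuming the public `AttentionMaskConverter(is_causal=True)` constructor; the caller now has to pass `dtype` explicitly, just as the private helpers already do):

```python
import torch

from transformers.modeling_attn_mask_utils import AttentionMaskConverter

converter = AttentionMaskConverter(is_causal=True)

# dtype no longer has a default, so the caller must supply it.
causal_mask = converter.to_causal_4d(
    batch_size=1,
    query_length=3,
    key_value_length=7,
    dtype=torch.float16,
    device="cpu",
)
print(causal_mask.shape)  # expected: torch.Size([1, 1, 3, 7])
```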

Collaborator

But to_causal_4d is part of the public API, no? So this is breaking 😅

Contributor Author

AFAIK it is used nowhere other than in _create_4d_causal_attention_mask, _prepare_4d_attention_mask and _prepare_4d_causal_attention_mask. cc @patrickvonplaten, I don't think the mask converter API was meant to be exposed?
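For context, a sketch of how modeling code typically reaches the converter, i.e. through the private helpers rather than through `AttentionMaskConverter` itself (hypothetical call site, not taken from any particular model):

```python
import torch

from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

# Stand-ins for values a decoder model has in its forward pass.
inputs_embeds = torch.rand(2, 5, 16)  # (batch, seq_len, hidden)
batch_size, seq_length, _ = inputs_embeds.shape

# The helper reads dtype/device from inputs_embeds, so the converter methods
# never needed a float32 default at this level.
causal_mask = _prepare_4d_causal_attention_mask(
    None, (batch_size, seq_length), inputs_embeds, past_key_values_length=0
)
print(causal_mask.shape)  # expected: torch.Size([2, 1, 5, 5])
```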

         device: Union[torch.device, "str"] = "cpu",
     ) -> torch.Tensor:
         """
@@ -98,8 +98,8 @@ def to_4d(
         self,
         attention_mask_2d: torch.Tensor,
         query_length: int,
+        dtype: torch.dtype,
         key_value_length: Optional[int] = None,
-        dtype: torch.dtype = torch.float32,
     ) -> torch.Tensor:
         """
         Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
68 changes: 67 additions & 1 deletion tests/test_modeling_utils.py
@@ -85,7 +85,12 @@
         T5Config,
         T5ForConditionalGeneration,
     )
-    from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+    from transformers.modeling_attn_mask_utils import (
+        AttentionMaskConverter,
+        _create_4d_causal_attention_mask,
+        _prepare_4d_attention_mask,
+        _prepare_4d_causal_attention_mask,
+    )
     from transformers.modeling_utils import shard_checkpoint
 
     # Fake pretrained models for tests
@@ -150,6 +155,32 @@ def forward(self, x):
         def tie_weights(self):
             self.decoder.weight = self.base.linear.weight
 
+    class Prepare4dCausalAttentionMaskModel(nn.Module):
+        def forward(self, inputs_embeds):
+            batch_size, seq_length, _ = inputs_embeds.shape
+            past_key_values_length = 4
+            attention_mask = _prepare_4d_causal_attention_mask(
+                None, (batch_size, seq_length), inputs_embeds, past_key_values_length
+            )
+            return attention_mask
+
+    class Create4dCausalAttentionMaskModel(nn.Module):
+        def forward(self, inputs_embeds):
+            batch_size, seq_length, _ = inputs_embeds.shape
+            past_key_values_length = 4
+            attention_mask = _create_4d_causal_attention_mask(
+                (batch_size, seq_length),
+                dtype=inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
+            return attention_mask
+
+    class Prepare4dAttentionMaskModel(nn.Module):
+        def forward(self, mask, inputs_embeds):
+            attention_mask = _prepare_4d_attention_mask(mask, dtype=inputs_embeds.dtype)
+            return attention_mask
+
 
 if is_flax_available():
     from transformers import FlaxBertModel
@@ -1621,3 +1652,38 @@ def test_causal_mask_sliding(self):
         self.check_to_causal(mask_converter, q_len=3, kv_len=7)
         # non auto-regressive case
         self.check_to_causal(mask_converter, q_len=7, kv_len=7)
+
+    def test_torch_compile_fullgraph(self):
+        model = Prepare4dCausalAttentionMaskModel()
+
+        inputs_embeds = torch.rand([1, 3, 32])
+        res_non_compiled = model(inputs_embeds)
+
+        compiled_model = torch.compile(model, fullgraph=True)
+
+        res_compiled = compiled_model(inputs_embeds)
+
+        self.assertTrue(torch.equal(res_non_compiled, res_compiled))
+
+        model = Create4dCausalAttentionMaskModel()
+
+        inputs_embeds = torch.rand(2, 4, 16)
+        res_non_compiled = model(inputs_embeds)
+
+        compiled_model = torch.compile(model, fullgraph=True)
+        res_compiled = compiled_model(inputs_embeds)
+
+        self.assertTrue(torch.equal(res_non_compiled, res_compiled))
+
+        model = Prepare4dAttentionMaskModel()
+
+        mask = torch.ones(2, 4)
+        mask[0, :2] = 0
+        inputs_embeds = torch.rand(2, 4, 16)
+
+        res_non_compiled = model(mask, inputs_embeds)
+
+        compiled_model = torch.compile(model, fullgraph=True)
+        res_compiled = compiled_model(mask, inputs_embeds)
+
+        self.assertTrue(torch.equal(res_non_compiled, res_compiled))
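The small `nn.Module` wrappers above mirror how the helpers are called inside real models; the same compatibility check should also work on the free functions directly, e.g. (a usage sketch, not part of this PR):

```python
import torch

from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

# fullgraph=True makes TorchDynamo error out on any graph break instead of
# silently falling back to eager mode.
compiled_fn = torch.compile(_prepare_4d_attention_mask, fullgraph=True)

mask = torch.ones(2, 4)
mask[0, :2] = 0  # mark the first two positions of the first sequence as padding
expanded = compiled_fn(mask, dtype=torch.float32)
print(expanded.shape)  # expected: torch.Size([2, 1, 4, 4])
```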