Paligemma: fix static cache test (#33941)
* fix

* not flaky anymore + style
zucchini-nlp authored Oct 5, 2024
1 parent 38f9f10 commit 612065e
Showing 5 changed files with 8 additions and 15 deletions.
4 changes: 1 addition & 3 deletions examples/modular-transformers/modeling_dummy.py
@@ -881,9 +881,7 @@ def forward(
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError(
-                "You must specify exactly one of input_ids or inputs_embeds"
-            )
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

         if self.gradient_checkpointing and self.training and use_cache:
             logger.warning_once(
4 changes: 1 addition & 3 deletions examples/modular-transformers/modeling_my_new_model2.py
@@ -758,9 +758,7 @@ def forward(
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError(
-                "You must specify exactly one of input_ids or inputs_embeds"
-            )
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

         if self.gradient_checkpointing and self.training and use_cache:
             logger.warning_once(
11 changes: 4 additions & 7 deletions src/transformers/models/paligemma/modeling_paligemma.py
@@ -57,8 +57,8 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
     min_dtype: float,
     cache_position: torch.Tensor,
     batch_size: int,
-    is_training: bool,
-    token_type_ids: torch.Tensor,
+    is_training: bool = False,
+    token_type_ids: torch.Tensor = None,
 ):
     """
     Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
@@ -94,7 +94,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         if is_training:
             causal_mask = torch.triu(causal_mask, diagonal=1)
         else:
-            causal_mask = torch.zeros_like(causal_mask)
+            causal_mask[:, :sequence_length] = 0.0

         causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
         causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
@@ -378,7 +378,7 @@ def _update_causal_mask(
             if is_training:
                 causal_mask = torch.triu(causal_mask, diagonal=1)
             else:
-                causal_mask = torch.zeros_like(causal_mask)
+                causal_mask[:, :sequence_length] = 0.0

             causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
             causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
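For context, a minimal sketch of what the changed branch does, assuming (as in the surrounding code) that `causal_mask` starts out filled with `min_dtype` and that `target_length` is the pre-allocated static-cache size, which can be larger than the current `sequence_length`. The old `torch.zeros_like` un-masked every cache column, including the unused static-cache slots; the slice assignment only un-masks the columns backed by real tokens. The toy sizes below are hypothetical, and the sketch ignores the `cache_position` masking applied afterwards:

```python
import torch

# Hypothetical toy sizes: 4 real tokens, a static cache pre-allocated to 8 slots.
sequence_length, target_length = 4, 8
min_dtype = torch.finfo(torch.float32).min

# As in the surrounding function, the mask starts out fully masked.
base = torch.full((sequence_length, target_length), fill_value=min_dtype)

# Old behaviour: every column becomes attendable, including the 4 empty cache slots.
old_mask = torch.zeros_like(base)

# New behaviour: only columns that correspond to real tokens are un-masked;
# the padded static-cache columns keep min_dtype and stay masked.
new_mask = base.clone()
new_mask[:, :sequence_length] = 0.0

print((old_mask == 0).all().item())                               # True
print((new_mask[:, sequence_length:] == min_dtype).all().item())  # True
```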
@@ -593,7 +593,6 @@ def prepare_inputs_for_generation(

         dtype = self.get_output_embeddings().weight.dtype
         min_dtype = torch.finfo(dtype).min
-        is_training = token_type_ids is not None and kwargs.get("labels", None) is not None

         model_inputs["attention_mask"] = _prepare_4d_causal_attention_mask_with_cache_position(
             attention_mask,
@@ -604,8 +603,6 @@ def prepare_inputs_for_generation(
             min_dtype=min_dtype,
             cache_position=cache_position,
             batch_size=batch_size,
-            is_training=is_training,
-            token_type_ids=token_type_ids,
         )

         model_inputs["token_type_ids"] = token_type_ids
3 changes: 2 additions & 1 deletion tests/models/paligemma/test_modeling_paligemma.py
@@ -159,7 +159,8 @@ def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         config, pixel_values = config_and_inputs
         input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
-        attention_mask = input_ids.ne(1).to(torch_device)
+        attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
+
         # set the 16 first tokens to be image, and ensure that no other tokens are image tokens
         # do not change this unless you modified image size or patch size
         input_ids[input_ids == config.image_token_index] = self.pad_token_id
1 change: 0 additions & 1 deletion tests/test_modeling_common.py
@@ -4868,7 +4868,6 @@ def test_custom_4d_attention_mask(self):
         normalized_1 = F.softmax(out_shared_prefix_last_tokens)
         torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)

-    @is_flaky(max_attempts=10)  # TODO @raushan: this test is VERY flaky on some VLMs, like paligemma
     def test_static_cache_matches_dynamic(self):
         """
         Tests that generating with static cache give almost same results as with dynamic cache.
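For reference, a rough sketch of the kind of comparison this test performs, not the test itself. `<your-checkpoint>` is a placeholder for any causal LM checkpoint whose architecture supports the static cache, and greedy decoding is assumed so the two runs are directly comparable:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "<your-checkpoint>"  # placeholder: any model that supports the static cache
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("The quick brown fox", return_tensors="pt")

# Greedy generation with the default dynamic cache.
dynamic_out = model.generate(**inputs, max_new_tokens=20, do_sample=False)

# Same greedy generation, but with a pre-allocated static cache.
static_out = model.generate(
    **inputs, max_new_tokens=20, do_sample=False, cache_implementation="static"
)

# With a correctly built attention mask the two token sequences should match.
print(torch.equal(dynamic_out, static_out))
```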