
Commit

T5 compile compatibility (#34089)
* this worked in normal generation, needs more tests

* fix almost all tests in t5

* nit

* longt5, umt5, mt5

* style

* udop, pix2struct

* more models

* fix some tests

* fix onnx tests

* tracing tests fixed

* compile enabled and tested for t5 models

* fix small bug in slow tests

* [run-slow] t5

* uncomment

* style

* update with new generation refactoring

* nit

* fix copies

* this is the fix, had to change t5 to fix copies

* update

* [run-slow] t5

* [run-slow] t5

* update

* add test for encoder only T5

* clean up after rebase

* fix pop2piano

* add comment

* style

* fix copies after rebase

* fix copies, missed this one
zucchini-nlp authored Oct 22, 2024
1 parent 5077bc0 commit 73d65e6
Showing 22 changed files with 2,743 additions and 1,178 deletions.
6 changes: 1 addition & 5 deletions src/transformers/cache_utils.py
@@ -1475,11 +1475,7 @@ def from_legacy_cache(
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
         # check if empty list because in case of static cache it will be a tensors and we can't check `if not torch.Tensor`
-        if self.self_attention_cache.key_cache == []:
-            return 0
-        if len(self.self_attention_cache.key_cache) > 1 and self.self_attention_cache.key_cache[layer_idx] == []:
-            return 0
-        return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
+        return self.self_attention_cache.get_seq_length(layer_idx)

     def reset(self):
         if hasattr(self.self_attention_cache, "reset"):
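The simplification above works because each concrete cache class already knows how to report its own length, so `EncoderDecoderCache` can delegate instead of special-casing static (pre-allocated tensor) caches itself. A minimal sketch of the behavior callers rely on; the construction below is illustrative and not taken from this commit:

from transformers.cache_utils import DynamicCache, EncoderDecoderCache

# An encoder-decoder cache wraps a self-attention cache and a cross-attention cache.
cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
print(cache.get_seq_length(0))  # 0 -- nothing cached yet; the call is forwarded to DynamicCache.get_seq_length

The same call is expected to work when the self-attention cache is a `StaticCache`, which is what the compile-friendly generation path further down sets up.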
8 changes: 6 additions & 2 deletions src/transformers/generation/utils.py
@@ -1535,8 +1535,12 @@ def _prepare_generation_config(
     def _get_initial_cache_position(self, input_ids, model_kwargs):
         """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
         # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange`
-        if "inputs_embeds" in model_kwargs:
+        if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder:
             cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
+        elif "decoder_inputs_embeds" in model_kwargs and self.config.is_encoder_decoder:
+            cache_position = (
+                torch.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
+            )
         else:
             cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
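For reference, the cumsum trick used in this hunk produces the same values as `torch.arange`, but derives the length from a tensor shape that `torch.compile` can trace without a data-dependent `arange`. A small illustration with made-up shapes (not part of the diff):

import torch

decoder_inputs_embeds = torch.zeros(2, 5, 512)  # (batch, seq_len, hidden) -- hypothetical shapes
cache_position = torch.ones_like(decoder_inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) - 1
assert torch.equal(cache_position, torch.arange(5))  # positions 0..4 for the pre-fill step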

@@ -1633,7 +1637,7 @@ def get_layer_device_map(execution_device_map: Optional[dict] = None):

         cache_kwargs = {
             "config": self.config.get_text_config(),
-            "max_batch_size": batch_size,
+            "batch_size": batch_size,
             "max_cache_len": max_cache_len,
             "device": device,
             "dtype": cache_dtype,
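Taken together with the modeling changes below, these generation fixes are what allow an encoder-decoder model such as T5 to run with a static cache and `torch.compile`. A hedged end-to-end sketch, assuming the standard static-cache recipe from the generation docs and the public `google-t5/t5-small` checkpoint (this snippet is not part of the commit):

import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
model.forward = torch.compile(model.forward, mode="reduce-overhead")  # optional; the static cache also works eagerly

inputs = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
print(tokenizer.decode(out[0], skip_special_tokens=True))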
7 changes: 6 additions & 1 deletion src/transformers/models/longt5/configuration_longt5.py
@@ -79,7 +79,12 @@ class LongT5Config(PretrainedConfig):

     model_type = "longt5"
     keys_to_ignore_at_inference = ["past_key_values"]
-    attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
+    attribute_map = {
+        "hidden_size": "d_model",
+        "num_attention_heads": "num_heads",
+        "num_hidden_layers": "num_layers",
+        "head_dim": "d_kv",
+    }

     def __init__(
         self,
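The new `"head_dim": "d_kv"` entry lets generic cache and attention utilities read a `head_dim` attribute from T5-family configs and transparently get the value stored as `d_kv`; the same mapping is added to the MT5, T5, and UMT5 configs further down. An illustrative check of the effect (not from the diff):

from transformers import LongT5Config

config = LongT5Config(d_kv=64)
# With the added attribute_map entry, the generic name resolves to the T5-style field.
assert config.head_dim == config.d_kv == 64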
455 changes: 299 additions & 156 deletions src/transformers/models/longt5/modeling_longt5.py

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion src/transformers/models/mt5/configuration_mt5.py
@@ -72,7 +72,12 @@ class MT5Config(PretrainedConfig):

     model_type = "mt5"
     keys_to_ignore_at_inference = ["past_key_values"]
-    attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
+    attribute_map = {
+        "hidden_size": "d_model",
+        "num_attention_heads": "num_heads",
+        "num_hidden_layers": "num_layers",
+        "head_dim": "d_kv",
+    }

     def __init__(
         self,
449 changes: 299 additions & 150 deletions src/transformers/models/mt5/modeling_mt5.py

Large diffs are not rendered by default.

407 changes: 275 additions & 132 deletions src/transformers/models/pix2struct/modeling_pix2struct.py

Large diffs are not rendered by default.

451 changes: 298 additions & 153 deletions src/transformers/models/pop2piano/modeling_pop2piano.py

Large diffs are not rendered by default.


7 changes: 6 additions & 1 deletion src/transformers/models/t5/configuration_t5.py
@@ -73,7 +73,12 @@ class T5Config(PretrainedConfig):

     model_type = "t5"
     keys_to_ignore_at_inference = ["past_key_values"]
-    attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
+    attribute_map = {
+        "hidden_size": "d_model",
+        "num_attention_heads": "num_heads",
+        "num_hidden_layers": "num_layers",
+        "head_dim": "d_kv",
+    }

     def __init__(
         self,
452 changes: 302 additions & 150 deletions src/transformers/models/t5/modeling_t5.py

Large diffs are not rendered by default.

434 changes: 289 additions & 145 deletions src/transformers/models/udop/modeling_udop.py

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion src/transformers/models/umt5/configuration_umt5.py
@@ -72,7 +72,12 @@ class UMT5Config(PretrainedConfig):

     model_type = "umt5"
     keys_to_ignore_at_inference = ["past_key_values"]
-    attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
+    attribute_map = {
+        "hidden_size": "d_model",
+        "num_attention_heads": "num_heads",
+        "num_hidden_layers": "num_layers",
+        "head_dim": "d_kv",
+    }

     def __init__(
         self,
