
Commit

mbart fixed
gante committed Dec 15, 2023
1 parent 83d2108 commit 2144004
Showing 28 changed files with 521 additions and 451 deletions.
77 changes: 40 additions & 37 deletions src/transformers/models/bart/modeling_bart.py
@@ -187,6 +187,13 @@ def _prepare_key_values(
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Cache] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if past_key_value is not None and self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)

bsz = hidden_states.shape[0]

# if key_value_states are provided this layer is used as a cross-attention layer for the decoder
@@ -229,12 +236,6 @@ def forward(
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
"""Input shape: Batch x Time x Channel"""
if past_key_value is not None and self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)

# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
@@ -340,13 +341,6 @@ def forward(
if output_attentions:
raise ValueError("BartFlashAttention2 attention does not support output_attentions")

if past_key_value is not None and self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)

bsz, q_len, _ = hidden_states.size()

# get query proj
@@ -521,13 +515,6 @@ def forward(
output_attentions=output_attentions,
)

if past_key_value is not None and self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)

# get query proj
query_states = self.q_proj(hidden_states)

@@ -647,7 +634,6 @@ class BartDecoderLayer(nn.Module):
def __init__(self, config: BartConfig, layer_idx: int):
super().__init__()
self.embed_dim = config.d_model
self.layer_idx = layer_idx

self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
@@ -1281,21 +1267,27 @@ def forward(
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
blocks) that can be used to speed up sequential decoding. This typically consists in the
`past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or
`config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. This is also known as
the legacy cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed,
the legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
@@ -2168,18 +2160,29 @@ def forward(
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the
`past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or
`config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. This is also known as
the legacy cache format.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed,
the legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
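The updated `past_key_values` docstring above accepts either a [`~cache_utils.Cache`] instance or the legacy tuple-of-tuples format. As a minimal sketch of the legacy format for an encoder-decoder layer (all sizes below are made up for illustration; only the shapes follow the docstring):

```python
import torch

# Illustrative sizes only; not taken from any model config.
batch_size, num_heads, head_dim, n_layers = 2, 16, 64, 12
decoded_len, encoder_len = 5, 37

# Legacy cache format per the docstring: one tuple per decoder layer holding the
# self-attention key/value states plus the cross-attention key/value states.
legacy_past_key_values = tuple(
    (
        torch.zeros(batch_size, num_heads, decoded_len, head_dim),  # self-attn keys
        torch.zeros(batch_size, num_heads, decoded_len, head_dim),  # self-attn values
        torch.zeros(batch_size, num_heads, encoder_len, head_dim),  # cross-attn keys
        torch.zeros(batch_size, num_heads, encoder_len, head_dim),  # cross-attn values
    )
    for _ in range(n_layers)
)

# With a cache supplied, only the newest token needs to be fed to the decoder.
next_decoder_input_ids = torch.zeros(batch_size, 1, dtype=torch.long)  # shape (batch_size, 1)
```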
35 changes: 13 additions & 22 deletions src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -1223,6 +1223,13 @@ def _prepare_key_values(
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Cache] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if past_key_value is not None and self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)

bsz = hidden_states.shape[0]

# if key_value_states are provided this layer is used as a cross-attention layer for the decoder
@@ -1265,12 +1272,6 @@ def forward(
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
"""Input shape: Batch x Time x Channel"""
if past_key_value is not None and self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)

# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
@@ -1468,7 +1469,7 @@ def forward(
encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True,
) -> torch.Tensor:
@@ -1494,12 +1495,9 @@
hidden_states = self.self_attn_layer_norm(hidden_states)

# Self Attention
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states, self_attn_weights, past_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
@@ -1508,28 +1506,21 @@
hidden_states = residual + hidden_states

# Cross-Attention Block
cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)

# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states, cross_attn_weights, past_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value,
past_key_value=past_key_value,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states

# add cross-attn to positions 3,4 of present_key_value tuple
present_key_value = present_key_value + cross_attn_present_key_value

# Fully Connected
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
@@ -1545,7 +1536,7 @@
outputs += (self_attn_weights, cross_attn_weights)

if use_cache:
outputs += (present_key_value,)
outputs += (past_key_value,)

return outputs

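The `BigBirdPegasusDecoderLayer` change above drops the positional bookkeeping (self-attention cache at positions 1,2 and cross-attention cache at positions 3,4 of a per-layer tuple) and instead threads one `Cache` object through both attention calls. The sketch below shows the `DynamicCache.update` behaviour such a layer relies on; it uses only the public `cache_utils` API from transformers v4.36 with made-up tensor shapes, and is not this commit's internal helper.

```python
import torch
from transformers.cache_utils import DynamicCache  # public API since transformers v4.36

cache = DynamicCache()
layer_idx = 0  # the reason the attention classes now require a layer index

# First decoding step: the new token's key/value states are stored for this layer.
k0 = torch.randn(1, 16, 1, 64)  # (batch, num_heads, new_tokens, head_dim)
v0 = torch.randn(1, 16, 1, 64)
keys, values = cache.update(k0, v0, layer_idx)
print(keys.shape)  # torch.Size([1, 16, 1, 64])

# Second step: the new states are appended and the full cached sequence is returned.
k1 = torch.randn(1, 16, 1, 64)
v1 = torch.randn(1, 16, 1, 64)
keys, values = cache.update(k1, v1, layer_idx)
print(keys.shape)  # torch.Size([1, 16, 2, 64])
```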
13 changes: 7 additions & 6 deletions src/transformers/models/biogpt/modeling_biogpt.py
@@ -132,6 +132,13 @@ def _prepare_key_values(
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Cache] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if past_key_value is not None and self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)

bsz = hidden_states.shape[0]

# if key_value_states are provided this layer is used as a cross-attention layer for the decoder
@@ -174,12 +181,6 @@ def forward(
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
"""Input shape: Batch x Time x Channel"""
if past_key_value is not None and self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)

# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
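BioGPT is decoder-only, so its legacy cache carries two tensors per layer, and the `cache_utils` conversion helpers can translate between the two documented formats. A short example under that assumption follows; whether this branch's `BioGptModel.forward` accepts a `Cache` directly is exactly what the diff changes, but the helpers shown below are part of released transformers v4.36:

```python
import torch
from transformers.cache_utils import DynamicCache  # transformers >= v4.36

# Decoder-only legacy format: (keys, values) per layer. Shapes are illustrative.
legacy = tuple(
    (torch.randn(1, 16, 5, 64), torch.randn(1, 16, 5, 64))
    for _ in range(4)
)

cache = DynamicCache.from_legacy_cache(legacy)  # legacy tuples -> Cache object
assert cache.get_seq_length() == 5              # number of cached tokens so far
roundtrip = cache.to_legacy_cache()             # Cache object -> legacy tuples
assert roundtrip[0][0].shape == legacy[0][0].shape
```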
(diff for the remaining changed files not shown)
