Add SDPA support for ESM.
wzf03 committed Nov 28, 2024
1 parent 5523e38 commit 72750f6
Showing 3 changed files with 192 additions and 21 deletions.
1 change: 1 addition & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -232,6 +232,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
* [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
* [EncoderDecoder](https://huggingface.co/docs/transformers/model_doc/encoder_decoder#transformers.EncoderDecoderModel)
* [ESM](https://huggingface.co/docs/transformers/model_doc/esm#transformers.ESMModel)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model)
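As context for the new list entry: the SDPA implementation added by this commit is selected at load time through the `attn_implementation` argument. A minimal usage sketch follows; the checkpoint name and protein sequence are illustrative and not part of the diff.

import torch
from transformers import AutoTokenizer, EsmModel

checkpoint = "facebook/esm2_t6_8M_UR50D"  # example checkpoint, not taken from the commit
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# attn_implementation="sdpa" routes attention through torch.nn.functional.scaled_dot_product_attention
model = EsmModel.from_pretrained(checkpoint, attn_implementation="sdpa").eval()

inputs = tokenizer("MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVAT", return_tensors="pt")
with torch.no_grad():
    hidden = model(**inputs).last_hidden_state
print(hidden.shape)  # (batch_size, sequence_length, hidden_size)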
202 changes: 182 additions & 20 deletions src/transformers/models/esm/modeling_esm.py
@@ -19,10 +19,15 @@

import torch
import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask_for_sdpa,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
@@ -31,7 +36,7 @@
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
from ...utils import get_torch_version, logging
from .configuration_esm import EsmConfig


@@ -88,7 +93,6 @@ def __init__(self, dim: int):
super().__init__()
# Generate and save the inverse frequency buffer (non trainable)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
inv_freq = inv_freq
self.register_buffer("inv_freq", inv_freq)

self._seq_len_cached = None
@@ -102,12 +106,20 @@ def _update_cos_sin_tables(self, x, seq_dimension=2):
# or if we're on a new device (possibly due to tracing for instance)
if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
self._seq_len_cached = seq_len
t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
t = torch.arange(x.shape[seq_dimension], device=x.device).float()

# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = torch.outer(t.float(), self.inv_freq.float())
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

cos = emb.cos()[None, None, :, :]
sin = emb.sin()[None, None, :, :]

self._cos_cached = emb.cos()[None, None, :, :]
self._sin_cached = emb.sin()[None, None, :, :]
self._cos_cached = cos.to(dtype=x.dtype)
self._sin_cached = sin.to(dtype=x.dtype)

return self._cos_cached, self._sin_cached

@@ -370,7 +382,7 @@ def forward(
if head_mask is not None:
attention_probs = attention_probs * head_mask

context_layer = torch.matmul(attention_probs.to(value_layer.dtype), value_layer)
context_layer = torch.matmul(attention_probs, value_layer)

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@@ -383,6 +395,113 @@ def forward(
return outputs


class EsmSdpaSelfAttention(EsmSelfAttention):
def __init__(self, config, position_embedding_type=None):
super().__init__(config, position_embedding_type)
self.attention_dropout_prob = config.attention_probs_dropout_prob
self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
if self.position_embedding_type not in ["absolute", "rotary"] or output_attentions or head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
logger.warning_once(
"EsmSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"non-absolute or non-rotary `position_embedding_type` or `output_attentions=True` or `head_mask`. "
"Falling back to the manual attention implementation, but specifying the manual implementation will "
"be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument "
'`attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)

bsz, tgt_len, _ = hidden_states.size()

query_layer = self.transpose_for_scores(self.query(hidden_states))

# If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
# mask needs to be such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None

current_states = encoder_hidden_states if is_cross_attention else hidden_states
attention_mask = encoder_attention_mask if is_cross_attention else attention_mask

# Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
key_layer, value_layer = past_key_value
else:
key_layer = self.transpose_for_scores(self.key(current_states))
value_layer = self.transpose_for_scores(self.value(current_states))
if past_key_value is not None and not is_cross_attention:
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

# Scale the query for rotary embeddings
query_layer = query_layer * self.attention_head_size**-0.5

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)

if self.position_embedding_type == "rotary":
query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
# attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
# Reference: https://github.com/pytorch/pytorch/issues/112577
if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
query_layer = query_layer.contiguous()
key_layer = key_layer.contiguous()
value_layer = value_layer.contiguous()

# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
# a causal mask in case tgt_len == 1.
is_causal = (
True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
)

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=attention_mask,
dropout_p=self.attention_dropout_prob if self.training else 0.0,
is_causal=is_causal,
scale=1.0, # Scale is already applied to query_layer
)

attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)

outputs = (attn_output,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs


class EsmSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
Expand All @@ -396,10 +515,17 @@ def forward(self, hidden_states, input_tensor):
return hidden_states


ESM_SELF_ATTENTION_CLASSES = {
"eager": EsmSelfAttention,
"sdpa": EsmSdpaSelfAttention,
}


class EsmAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.self = EsmSelfAttention(config)
# self.self = EsmSelfAttention(config)
self.self = ESM_SELF_ATTENTION_CLASSES[config._attn_implementation](config)
self.output = EsmSelfOutput(config)
self.pruned_heads = set()
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -680,6 +806,7 @@ class EsmPreTrainedModel(PreTrainedModel):
base_model_prefix = "esm"
supports_gradient_checkpointing = True
_no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"]
_supports_sdpa = True

# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module):
@@ -787,6 +914,9 @@ def __init__(self, config, add_pooling_layer=True):
in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
)

self.attn_implementation = config._attn_implementation
self.position_embedding_type = config.position_embedding_type

# Initialize weights and apply final processing
self.post_init()

@@ -875,9 +1005,40 @@ def forward(
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)

use_sdpa_attention_masks = (
self.attn_implementation == "sdpa"
and self.position_embedding_type in ["absolute", "rotary"]
and head_mask is None
and not output_attentions
)

# Expand the attention mask
if use_sdpa_attention_masks and attention_mask.dim() == 2:
# Expand the attention mask for SDPA.
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
if self.config.is_decoder:
extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
input_shape,
embedding_output,
past_key_values_length,
)
else:
extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
attention_mask, embedding_output.dtype, tgt_len=seq_length
)
else:
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -886,7 +1047,15 @@ def forward(
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)

if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2:
# Expand the attention mask for SDPA.
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
)
else:
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None

@@ -897,13 +1066,6 @@ def forward(
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
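Because EsmSdpaSelfAttention pre-scales the query and passes scale=1.0 to scaled_dot_product_attention, the SDPA path is expected to reproduce the eager results up to numerical noise. A rough parity check could look like the sketch below; the checkpoint and tolerance are illustrative, not taken from the commit.

import torch
from transformers import AutoTokenizer, EsmModel

checkpoint = "facebook/esm2_t6_8M_UR50D"  # example checkpoint only
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Two sequences of different lengths so that padding exercises the SDPA attention-mask preparation.
inputs = tokenizer(["MKTVRQERLKSIVRILERSK", "GYNIVATPRGYVLAGG"], padding=True, return_tensors="pt")

eager = EsmModel.from_pretrained(checkpoint, attn_implementation="eager").eval()
sdpa = EsmModel.from_pretrained(checkpoint, attn_implementation="sdpa").eval()

with torch.no_grad():
    out_eager = eager(**inputs).last_hidden_state
    out_sdpa = sdpa(**inputs).last_hidden_state

# The two implementations should agree up to kernel-level floating-point differences.
print(torch.max(torch.abs(out_eager - out_sdpa)))
assert torch.allclose(out_eager, out_sdpa, atol=1e-4)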
10 changes: 9 additions & 1 deletion tests/models/esm/test_modeling_esmfold.py
@@ -16,8 +16,10 @@

import unittest

from parameterized import parameterized

from transformers import EsmConfig, is_torch_available
from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
from transformers.testing_utils import TestCasePlus, require_torch, require_torch_sdpa, slow, torch_device

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
@@ -266,6 +268,12 @@ def test_torchscript_simple(self):
def test_multi_gpu_data_parallel_forward(self):
pass

@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@unittest.skip("ESMFold doesn't support outputting hidden states in the normal way, which this test requires.")
def test_eager_matches_sdpa_inference(self):
pass


@require_torch
class EsmModelIntegrationTest(TestCasePlus):
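The skipped test above is the generic SDPA-vs-eager equivalence check from ModelTesterMixin, parameterized over float16, bfloat16, and float32; it is skipped here only because ESMFold cannot return hidden states the way the mixin expects. For the plain ESM model the check boils down to comparing the two implementations per dtype, roughly as in the sketch below (checkpoint, token-id range, and tolerances are illustrative, not the actual test code).

import torch
from transformers import EsmModel

checkpoint = "facebook/esm2_t6_8M_UR50D"  # example checkpoint only
device = "cuda" if torch.cuda.is_available() else "cpu"

for dtype, atol in [(torch.float32, 1e-5), (torch.float16, 5e-3), (torch.bfloat16, 1e-2)]:
    if dtype is not torch.float32 and device == "cpu":
        continue  # half-precision comparisons are only meaningful on an accelerator
    eager = EsmModel.from_pretrained(checkpoint, attn_implementation="eager", torch_dtype=dtype).to(device).eval()
    sdpa = EsmModel.from_pretrained(checkpoint, attn_implementation="sdpa", torch_dtype=dtype).to(device).eval()
    input_ids = torch.randint(4, 25, (2, 16), device=device)  # arbitrary ids inside the ESM vocabulary
    with torch.no_grad():
        diff = (eager(input_ids).last_hidden_state - sdpa(input_ids).last_hidden_state).abs().max()
    assert diff < atol, f"{dtype}: max difference {diff}"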
