
Commit

Patched FlexAttention, SDPA, Eager with Local Attention
tomaarsen committed Dec 17, 2024
1 parent f312eef commit 1e367df
Showing 2 changed files with 76 additions and 38 deletions.
57 changes: 38 additions & 19 deletions src/transformers/models/modernbert/modeling_modernbert.py
@@ -306,6 +306,7 @@ def eager_attention_forward(
qkv: torch.Tensor,
attention_mask: torch.Tensor,
position_ids: Optional[torch.LongTensor],
local_attention: Tuple[int, int],
bs: int,
dim: int,
output_attentions: Optional[bool] = False,
@@ -320,9 +321,21 @@ def eager_attention_forward(
scale = module.head_dim**-0.5
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale

if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask
if attention_mask is not None:
expanded_mask = _prepare_4d_attention_mask(attention_mask, attn_weights.dtype, tgt_len=key.shape[-2])

if local_attention != (-1, -1):
# Create position indices
rows = torch.arange(expanded_mask.shape[2]).unsqueeze(0)
# Calculate distance between positions
distance = torch.abs(rows - rows.T)

# Create sliding window mask (1 for positions within window, 0 outside)
window_mask = (distance <= local_attention[0]).unsqueeze(0).unsqueeze(0).to(attention_mask.device)
# Combine with existing mask
expanded_mask.masked_fill_(window_mask.logical_not(), float("-inf"))

attn_weights = attn_weights + expanded_mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
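
For reference, a minimal standalone sketch of the sliding-window bias built in this hunk for the eager path; the sequence length, window size, and padding pattern are illustrative rather than taken from the model:

import torch

# Illustrative sketch of the eager-path mask: padding and window both map to -inf.
seq_len, window = 8, 2                        # window plays the role of local_attention[0]
attention_mask = torch.ones(1, seq_len)       # 1 = real token, 0 = padding
attention_mask[0, -2:] = 0                    # pad the last two positions

# Padding mask expanded to [bs, 1, seq_len, seq_len]
expanded_mask = torch.zeros(1, 1, seq_len, seq_len)
expanded_mask.masked_fill_(attention_mask[:, None, None, :] == 0, float("-inf"))

# Distance between query and key positions
rows = torch.arange(seq_len).unsqueeze(0)
distance = torch.abs(rows - rows.T)

# Positions outside the window are also pushed to -inf
window_mask = (distance <= window).unsqueeze(0).unsqueeze(0)
expanded_mask.masked_fill_(window_mask.logical_not(), float("-inf"))

# expanded_mask is then added to the raw attention scores before the softmax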
@@ -385,6 +398,7 @@ def flex_attention_forward(
rotary_emb: ModernBertUnpaddedRotaryEmbedding,
cu_seqlens: torch.Tensor,
block_mask: "BlockMask",
local_attention: Tuple[int, int],
max_seqlen: int,
bs: int,
dim: int,
@@ -401,7 +415,7 @@ def flex_attention_forward(
query,
key,
value,
block_mask=block_mask,
block_mask=block_mask if local_attention != (-1, -1) else None,
enable_gqa=False,
scale=None,
return_lse=False,
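
For the FlexAttention path, a hedged sketch of how a sliding-window BlockMask can be built with PyTorch's torch.nn.attention.flex_attention helpers; the sizes are illustrative, and the model's own create_attention_mask additionally folds in the sequence boundaries derived from cu_seqlens:

import torch
from torch.nn.attention.flex_attention import create_block_mask

# Illustrative sizes; assumes a PyTorch build with FlexAttention support (2.5+).
window_size, total_tokens = 64, 256
device = "cuda" if torch.cuda.is_available() else "cpu"

def sliding_window(b, h, q_idx, kv_idx):
    # Keep only key positions within `window_size` tokens of the query position.
    return (q_idx - kv_idx).abs() <= window_size

block_mask = create_block_mask(
    sliding_window, B=None, H=None,
    Q_LEN=total_tokens, KV_LEN=total_tokens, device=device,
)

# Mirrors the change above: only local-attention layers keep the mask,
# global layers (local_attention == (-1, -1)) pass block_mask=None.
local_attention = (window_size, window_size)
mask_arg = block_mask if local_attention != (-1, -1) else None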
@@ -416,6 +430,7 @@ def sdpa_attention_forward(
qkv: torch.Tensor,
attention_mask: torch.Tensor,
position_ids: Optional[torch.LongTensor],
local_attention: Tuple[int, int],
bs: int,
dim: int,
**_kwargs,
@@ -427,7 +442,22 @@ def sdpa_attention_forward(
query, key = apply_rotary_pos_emb(query, key, cos, sin)

if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attention_mask = attention_mask[:, None, None, :].expand(
attention_mask.shape[0], 1, attention_mask.shape[1], attention_mask.shape[1]
)

if local_attention != (-1, -1):
# Create position indices
rows = torch.arange(attention_mask.shape[2]).unsqueeze(0)
# Calculate distance between positions
distance = torch.abs(rows - rows.T)

# Create sliding window mask (1 for positions within window, 0 outside)
window_mask = (distance <= local_attention[0]).unsqueeze(0).unsqueeze(0).to(attention_mask.device)
# Combine with existing mask
attention_mask = torch.logical_and(attention_mask, window_mask)

attention_mask = attention_mask.to(torch.bool)

attn_output = F.scaled_dot_product_attention(
query,
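
For the SDPA path, a hedged standalone sketch of the boolean mask assembled in this hunk; batch size, window, and padding pattern are illustrative:

import torch
import torch.nn.functional as F

bs, num_heads, seq_len, head_dim, window = 2, 4, 16, 8, 4
attention_mask = torch.ones(bs, seq_len, dtype=torch.long)   # 1 = token, 0 = padding
attention_mask[1, -3:] = 0                                    # pad the second example

# [bs, seq_len] -> [bs, 1, seq_len, seq_len]
mask = attention_mask[:, None, None, :].expand(bs, 1, seq_len, seq_len)

# Sliding-window constraint: keys must lie within `window` positions of the query
rows = torch.arange(seq_len).unsqueeze(0)
window_mask = (torch.abs(rows - rows.T) <= window).unsqueeze(0).unsqueeze(0)
mask = torch.logical_and(mask, window_mask).to(torch.bool)

q = torch.randn(bs, num_heads, seq_len, head_dim)
k = torch.randn(bs, num_heads, seq_len, head_dim)
v = torch.randn(bs, num_heads, seq_len, head_dim)

# True = attend, False = masked out; the mask broadcasts over the head dimension
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)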
@@ -893,7 +923,6 @@ def offsets_to_sequence_ids_tensor(cls, offsets):
counts = offsets[1:] - offsets[:-1]
return torch.repeat_interleave(torch.arange(len(counts), device=device, dtype=torch.int32), counts)

@torch.compile(dynamic=False)
def create_attention_mask(self, sequence_ids, cu_seqlens, window_size):
"""
Creates a block mask combining sequence masking and local/or global attention masking.
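
The docstring above describes combining sequence (document) masking with the local or global window; a hedged sketch of such a combined mask_mod for FlexAttention, with an illustrative packing of two sequences:

import torch

def document_sliding_window(sequence_ids, window_size):
    # Tokens may attend only within their own packed sequence AND within the window.
    def mask_mod(b, h, q_idx, kv_idx):
        same_sequence = sequence_ids[q_idx] == sequence_ids[kv_idx]
        in_window = (q_idx - kv_idx).abs() <= window_size
        return same_sequence & in_window
    return mask_mod

# Two packed sequences of lengths 5 and 3, i.e. cu_seqlens = [0, 5, 8]
sequence_ids = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1])
mask_mod = document_sliding_window(sequence_ids, window_size=2)
# mask_mod would then be handed to create_block_mask, as in create_attention_mask.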
@@ -1053,23 +1082,12 @@ def forward(

hidden_states = self.embeddings(input_ids)

# expand attention_mask
if self.config._attn_implementation != "flash_attention_2" and attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)

# create block mask
block_mask = None
if self.config._attn_implementation == "flex_attention":
sequence_ids = self.offsets_to_sequence_ids_tensor(cu_seqlens)

if self.config.local_attention != (-1, -1):
window_size = self.config.local_attention // 2
else:
window_size = max_seqlen

window_size = self.config.local_attention // 2
block_mask = self.create_attention_mask(sequence_ids, cu_seqlens, window_size)
else:
block_mask = None

for encoder_layer in self.layers:
if output_hidden_states:
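
A short sketch of the flex_attention branch of forward() shown above, with illustrative values; it derives per-token sequence ids from cu_seqlens and halves the configured window width before building the block mask:

import torch

# Illustrative values; mirrors the flex_attention branch of forward() above.
cu_seqlens = torch.tensor([0, 5, 8], dtype=torch.int32)   # two packed sequences
counts = cu_seqlens[1:] - cu_seqlens[:-1]
sequence_ids = torch.repeat_interleave(
    torch.arange(len(counts), dtype=torch.int32), counts
)                                                          # -> [0,0,0,0,0,1,1,1]

local_attention = 128                 # assumed config value (full window width)
window_size = local_attention // 2    # half-width passed to create_attention_mask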
@@ -1082,6 +1100,7 @@ def forward(
attention_mask,
position_ids,
cu_seqlens,
block_mask,
max_seqlen,
output_attentions,
)
57 changes: 38 additions & 19 deletions src/transformers/models/modernbert/modular_modernbert.py
@@ -518,6 +518,7 @@ def eager_attention_forward(
qkv: torch.Tensor,
attention_mask: torch.Tensor,
position_ids: Optional[torch.LongTensor],
local_attention: Tuple[int, int],
bs: int,
dim: int,
output_attentions: Optional[bool] = False,
@@ -532,9 +533,21 @@ def eager_attention_forward(
scale = module.head_dim**-0.5
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale

if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask
if attention_mask is not None:
expanded_mask = _prepare_4d_attention_mask(attention_mask, attn_weights.dtype, tgt_len=key.shape[-2])

if local_attention != (-1, -1):
# Create position indices
rows = torch.arange(expanded_mask.shape[2]).unsqueeze(0)
# Calculate distance between positions
distance = torch.abs(rows - rows.T)

# Create sliding window mask (1 for positions within window, 0 outside)
window_mask = (distance <= local_attention[0]).unsqueeze(0).unsqueeze(0).to(attention_mask.device)
# Combine with existing mask
expanded_mask.masked_fill_(window_mask.logical_not(), float("-inf"))

attn_weights = attn_weights + expanded_mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
@@ -597,6 +610,7 @@ def flex_attention_forward(
rotary_emb: ModernBertUnpaddedRotaryEmbedding,
cu_seqlens: torch.Tensor,
block_mask: "BlockMask",
local_attention: Tuple[int, int],
max_seqlen: int,
bs: int,
dim: int,
@@ -613,7 +627,7 @@ def flex_attention_forward(
query,
key,
value,
block_mask=block_mask,
block_mask=block_mask if local_attention != (-1, -1) else None,
enable_gqa=False,
scale=None,
return_lse=False,
@@ -628,6 +642,7 @@ def sdpa_attention_forward(
qkv: torch.Tensor,
attention_mask: torch.Tensor,
position_ids: Optional[torch.LongTensor],
local_attention: Tuple[int, int],
bs: int,
dim: int,
**_kwargs,
@@ -639,7 +654,22 @@ def sdpa_attention_forward(
query, key = apply_rotary_pos_emb(query, key, cos, sin)

if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attention_mask = attention_mask[:, None, None, :].expand(
attention_mask.shape[0], 1, attention_mask.shape[1], attention_mask.shape[1]
)

if local_attention != (-1, -1):
# Create position indices
rows = torch.arange(attention_mask.shape[2]).unsqueeze(0)
# Calculate distance between positions
distance = torch.abs(rows - rows.T)

# Create sliding window mask (1 for positions within window, 0 outside)
window_mask = (distance <= local_attention[0]).unsqueeze(0).unsqueeze(0).to(attention_mask.device)
# Combine with existing mask
attention_mask = torch.logical_and(attention_mask, window_mask)

attention_mask = attention_mask.to(torch.bool)

attn_output = F.scaled_dot_product_attention(
query,
@@ -1033,7 +1063,6 @@ def offsets_to_sequence_ids_tensor(cls, offsets):
counts = offsets[1:] - offsets[:-1]
return torch.repeat_interleave(torch.arange(len(counts), device=device, dtype=torch.int32), counts)

@torch.compile(dynamic=False)
def create_attention_mask(self, sequence_ids, cu_seqlens, window_size):
"""
Creates a block mask combining sequence masking and local/or global attention masking.
@@ -1193,23 +1222,12 @@ def forward(

hidden_states = self.embeddings(input_ids)

# expand attention_mask
if self.config._attn_implementation != "flash_attention_2" and attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)

# create block mask
block_mask = None
if self.config._attn_implementation == "flex_attention":
sequence_ids = self.offsets_to_sequence_ids_tensor(cu_seqlens)

if self.config.local_attention != (-1, -1):
window_size = self.config.local_attention // 2
else:
window_size = max_seqlen

window_size = self.config.local_attention // 2
block_mask = self.create_attention_mask(sequence_ids, cu_seqlens, window_size)
else:
block_mask = None

for encoder_layer in self.layers:
if output_hidden_states:
@@ -1222,6 +1240,7 @@ def forward(
attention_mask,
position_ids,
cu_seqlens,
block_mask,
max_seqlen,
output_attentions,
)
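
For context, a hedged usage sketch of selecting one of the patched attention paths when loading the model; the checkpoint id and keyword values are illustrative and not part of this commit:

from transformers import AutoModel

# Hypothetical checkpoint id; attn_implementation chooses among the patched
# "eager", "sdpa" and "flex_attention" code paths (plus "flash_attention_2").
model = AutoModel.from_pretrained(
    "answerdotai/ModernBERT-base",
    attn_implementation="sdpa",
)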