Skip to content

Commit

Permalink
Merge branch 'main' into jerry/oras
Browse files Browse the repository at this point in the history
  • Loading branch information
jerrychen109 authored Jan 18, 2024
2 parents 513c9b0 + 19ee086 commit a0fc510
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 131 deletions.
7 changes: 4 additions & 3 deletions llmfoundry/models/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
# SPDX-License-Identifier: Apache-2.0

from llmfoundry.models.layers.attention import (
ATTN_CLASS_REGISTRY, MultiheadAttention, MultiQueryAttention,
attn_bias_shape, build_alibi_bias, build_attn_bias, flash_attn_fn,
scaled_multihead_dot_product_attention, triton_flash_attn_fn)
ATTN_CLASS_REGISTRY, GroupedQueryAttention, MultiheadAttention,
MultiQueryAttention, attn_bias_shape, build_alibi_bias, build_attn_bias,
flash_attn_fn, scaled_multihead_dot_product_attention, triton_flash_attn_fn)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
Expand All @@ -17,6 +17,7 @@
'triton_flash_attn_fn',
'MultiheadAttention',
'MultiQueryAttention',
'GroupedQueryAttention',
'attn_bias_shape',
'build_attn_bias',
'build_alibi_bias',
Expand Down
39 changes: 22 additions & 17 deletions llmfoundry/models/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,12 +228,17 @@ def flash_attn_fn(
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
attention_mask_in_length: Optional[torch.Tensor] = None,
should_repeat_kv_for_gqa: Optional[bool] = True,
sliding_window_size: int = -1,
alibi_slopes: Optional[torch.Tensor] = None,
flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:
if key_padding_mask is not None:
raise ValueError('key_padding_mask should be None for flash attn.')
del key_padding_mask
if flash_attn_padding_info is None:
raise ValueError('flash_attn_padding_info is required for flash attn.')
try:
from flash_attn import bert_padding, flash_attn_interface # type: ignore # yapf: disable # isort: skip
except:
Expand Down Expand Up @@ -267,25 +272,24 @@ def flash_attn_fn(

batch_size, seqlen = query.shape[:2]

if attention_mask_in_length is None:
if key_padding_mask is None:
key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
query_padding_mask = key_padding_mask[:, -query.size(1):]
unpadding_function = bert_padding.unpad_input
else:
key_padding_mask = attention_mask_in_length
query_padding_mask = attention_mask_in_length
unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
indices_q = flash_attn_padding_info['indices_q']
indices_k = flash_attn_padding_info['indices_k']
indices_v = flash_attn_padding_info['indices_v']
cu_seqlens_q = flash_attn_padding_info['cu_seqlens_q']
cu_seqlens_k = flash_attn_padding_info['cu_seqlens_k']
max_seqlen_q = flash_attn_padding_info['max_seqlen_q']
max_seqlen_k = flash_attn_padding_info['max_seqlen_k']

query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
query, query_padding_mask)
query_unpad = bert_padding.index_first_axis(
rearrange(query, 'b s ... -> (b s) ...'), indices_q)
query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)

key_unpad, _, cu_seqlens_k, max_seqlen_k = unpadding_function(
key, key_padding_mask)
key_unpad = bert_padding.index_first_axis(
rearrange(key, 'b s ... -> (b s) ...'), indices_k)
key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)

value_unpad, _, _, _ = unpadding_function(value, key_padding_mask)
value_unpad = bert_padding.index_first_axis(
rearrange(value, 'b s ... -> (b s) ...'), indices_v)
value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)

if (kv_n_heads < n_heads) and (not is_flash_v2_installed()) and (
Expand Down Expand Up @@ -599,8 +603,8 @@ def forward(
rotary_emb_w_meta_info: Optional[dict] = None,
is_causal: bool = True,
needs_weights: bool = False,
attention_mask_in_length: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
torch.Tensor, torch.Tensor]]]:
qkv = self.Wqkv(x)
Expand Down Expand Up @@ -666,11 +670,12 @@ def forward(

extra_attn_kwargs = {}
if self.attn_impl == 'flash':
key_padding_mask = None
extra_attn_kwargs = {
'attention_mask_in_length': attention_mask_in_length,
'should_repeat_kv_for_gqa': not is_flash_v2_installed(),
'sliding_window_size': self.sliding_window_size,
'alibi_slopes': alibi_slopes,
'flash_attn_padding_info': flash_attn_padding_info,
}

context, attn_weights, past_key_value = self.attn_fn(
Expand Down
4 changes: 2 additions & 2 deletions llmfoundry/models/layers/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ def forward(
attention_mask: Optional[torch.ByteTensor] = None,
is_causal: bool = True,
output_attentions: bool = False,
attention_mask_in_length: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
torch.Tensor, torch.Tensor]]]:
a = self.norm_1(x)
Expand All @@ -135,8 +135,8 @@ def forward(
attention_mask=attention_mask,
is_causal=is_causal,
needs_weights=output_attentions,
attention_mask_in_length=attention_mask_in_length,
alibi_slopes=alibi_slopes,
flash_attn_padding_info=flash_attn_padding_info,
)
x = x + self.resid_attn_dropout(b)
m = x
Expand Down
103 changes: 84 additions & 19 deletions llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
"""

from __future__ import annotations

import math
import warnings
from typing import (Any, Dict, List, Mapping, MutableMapping, Optional, Tuple,
Expand All @@ -24,15 +26,23 @@
from composer.models import HuggingFaceModel
from composer.utils import dist

from llmfoundry.models.layers.attention import is_flash_v2_installed
from llmfoundry.models.layers.attention import (is_flash_v1_installed,
is_flash_v2_installed)

if is_flash_v2_installed():
try: # This try...except is needed because transformers requires it despite the 'if' statement above
from flash_attn import bert_padding
from flash_attn.layers.rotary import \
RotaryEmbedding as DAILRotaryEmbedding
except Exception as e:
raise e

if is_flash_v1_installed():
try: # This try...except is needed because transformers requires it despite the 'if' statement above
from flash_attn import bert_padding
except Exception as e:
raise e

from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from transformers import PreTrainedModel, PreTrainedTokenizerBase
Expand Down Expand Up @@ -216,6 +226,44 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int,
return attention_mask_in_length


def gen_flash_attn_padding_info(
bsz: int,
S: int,
past_key_len: int,
device: torch.device,
attention_mask_in_length: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None):
flash_attn_padding_info = {}
if attention_mask_in_length is None:
key_padding_mask = attention_mask
if key_padding_mask is None:
key_padding_mask = torch.ones((bsz, past_key_len + S),
dtype=torch.bool,
device=device)
query_padding_mask = key_padding_mask[:, -S:]
unpadding_function = bert_padding.unpad_input
else:
key_padding_mask = attention_mask_in_length
query_padding_mask = attention_mask_in_length
unpadding_function = bert_padding.unpad_input_for_concatenated_sequences

_, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
torch.empty(bsz, S, 1, device=device), query_padding_mask)
_, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function(
torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
_, indices_v, _, _ = unpadding_function(
torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)

flash_attn_padding_info['indices_q'] = indices_q
flash_attn_padding_info['indices_k'] = indices_k
flash_attn_padding_info['indices_v'] = indices_v
flash_attn_padding_info['cu_seqlens_q'] = cu_seqlens_q
flash_attn_padding_info['cu_seqlens_k'] = cu_seqlens_k
flash_attn_padding_info['max_seqlen_q'] = max_seqlen_q
flash_attn_padding_info['max_seqlen_k'] = max_seqlen_k
return flash_attn_padding_info


def apply_sequence_id(attn_bias: torch.Tensor, sequence_id: torch.LongTensor,
max_seq_len: int) -> torch.Tensor:
seq_len = sequence_id.shape[-1]
Expand Down Expand Up @@ -246,6 +294,14 @@ class MPTPreTrainedModel(PreTrainedModel):
_no_split_modules = ['MPTBlock']


def _fsdp_wrap_fn(
self: Union[MPTModel, MPTForCausalLM],
module: nn.Module,
) -> bool:
# FSDP Wrap function for MPT Models
return isinstance(module, MPTBlock)


class MPTModel(MPTPreTrainedModel):

def __init__(self, config: MPTConfig):
Expand Down Expand Up @@ -515,10 +571,12 @@ def forward(
raise ValueError(
'You cannot specify both input_ids and inputs_embeds.')
elif input_ids is not None:
bsz = input_ids.size(0)
S = input_ids.size(1)
x = self.wte(input_ids)
input_device = input_ids.device
elif inputs_embeds is not None:
bsz = inputs_embeds.size(0)
S = inputs_embeds.size(1)
x = inputs_embeds
input_device = inputs_embeds.device
Expand All @@ -530,22 +588,23 @@ def forward(
), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'

rotary_emb_w_meta_info = None
if self.learned_pos_emb or self.rope:
past_position = 0
if past_key_values is not None:
if len(past_key_values) != self.config.n_layers:
raise ValueError(
f'past_key_values must provide a past_key_value for each attention '
+
f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).'
)
# For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim).
# For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq).
# Here we shift position embedding using the `seq` dim of the past key
past_position = past_key_values[0][0].size(1)
if self.attn_impl == 'torch':
past_position = past_key_values[0][0].size(3)

past_position = 0
if past_key_values is not None:
if len(past_key_values) != self.config.n_layers:
raise ValueError(
f'past_key_values must provide a past_key_value for each attention '
+
f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).'
)
# For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim).
# For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq).
# Here we shift position embedding using the `seq` dim of the past key
past_position = past_key_values[0][0].size(1)
if self.attn_impl == 'torch':
past_position = past_key_values[0][0].size(3)

if self.learned_pos_emb or self.rope:
if self.learned_pos_emb and (S + past_position >
self.config.max_seq_len):
raise ValueError(
Expand Down Expand Up @@ -623,6 +682,12 @@ def forward(

all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
flash_attn_padding_info = {}
if self.attn_impl == 'flash':
flash_attn_padding_info = gen_flash_attn_padding_info(
bsz, S, past_position, x.device, attention_mask_in_length,
attention_mask)

for b_idx, block in enumerate(self.blocks):
if output_hidden_states:
assert all_hidden_states is not None # pyright
Expand All @@ -637,8 +702,8 @@ def forward(
attention_mask=attention_mask,
is_causal=self.is_causal,
output_attentions=bool(output_attentions),
attention_mask_in_length=attention_mask_in_length,
alibi_slopes=alibi_slopes,
flash_attn_padding_info=flash_attn_padding_info,
)
if presents is not None:
presents += (present,)
Expand Down Expand Up @@ -673,7 +738,7 @@ def param_init_fn(self, module: nn.Module) -> None:

# FSDP Wrap function
def fsdp_wrap_fn(self, module: nn.Module) -> bool:
return isinstance(module, MPTBlock)
return _fsdp_wrap_fn(self, module)

# Activation Checkpointing
def activation_checkpointing_fn(self, module: nn.Module) -> bool:
Expand Down Expand Up @@ -834,7 +899,7 @@ def param_init_fn(self, module: nn.Module) -> None:

# FSDP Wrap function
def fsdp_wrap_fn(self, module: nn.Module) -> bool:
return isinstance(module, MPTBlock)
return _fsdp_wrap_fn(self, module)

# Activation Checkpointing
def activation_checkpointing_fn(self, module: nn.Module) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion scripts/data_prep/convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def validate_and_get_cluster_info(cluster_id: str,
stripped_runtime = re.sub(
r'[a-zA-Z]', '',
res.spark_version.split('-scala')[0].replace('x-snapshot', ''))
runtime_version = re.sub(r'.-+$', '', stripped_runtime)
runtime_version = re.sub(r'[.-]*$', '', stripped_runtime)
if version.parse(runtime_version) < version.parse(
MINIMUM_SQ_CONNECT_DBR_VERSION):
raise ValueError(
Expand Down
Loading

0 comments on commit a0fc510

Please sign in to comment.