Commit dd97111

Merge remote-tracking branch 'origin/mamba_update' into mamba_update
jmercat committed Jul 24, 2024
2 parents a1d6394 + 11d2028 commit dd97111
Showing 12 changed files with 96 additions and 44 deletions.
36 changes: 1 addition & 35 deletions open_lm/attention.py
@@ -2,7 +2,6 @@

import torch
from torch.nn import functional as F
import xformers.ops as xops


def get_rectangular_causal_mask(shape, q_seq_len, k_seq_len, device, dtype):
@@ -63,31 +62,6 @@ def apply_attention_mask_(bias, attention_mask, queries_dtype):
    bias.mul_(~torch.all(bias == min_dtype, dim=-1, keepdim=True))


def xformers_attn(queries, keys, values, is_causal, attention_mask=None):
    # xformers assumes q, k, v are [batch, seq_len, heads, embed_dim]
    # We assume that queries match the last part of the key / value sequences
    # see (https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.fmha.attn_bias.LowerTriangularFromBottomRightMask)
    # we would like to replace the mask generation with: mask = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask()
    # sadly we cannot use this because it needs xformers>=0.0.23, which is not compatible with torch<2.1.1, while llm-foundry requires torch<2.1.1

    # If queries have shape [batch, 1, heads, dim] it means there is only one query in the sequence.
    # In this case, there is no notion of causal masking, so we can just set the mask to None.
    # This is actually needed to get the desired behavior with seq_len=1.
    bias = None
    if is_causal and queries.shape[1] == keys.shape[1] and attention_mask is None:
        bias = xops.LowerTriangularMask()
    elif is_causal and (queries.shape[1] > 1 or attention_mask is not None):
        # Build a causal mask that assumes queries are at the end of the sequence.
        batch, q_seq_len, heads, _ = queries.shape
        k_seq_len = keys.shape[1]
        bias = get_rectangular_causal_mask((batch, heads), q_seq_len, k_seq_len, queries.device, queries.dtype)
        if attention_mask is not None:
            apply_attention_mask_(bias, attention_mask, queries_dtype=queries.dtype)
    elif not is_causal and attention_mask is not None:
        raise NotImplementedError("attention_mask with is_causal=False is not yet implemented.")
    return xops.memory_efficient_attention(queries, keys, values, attn_bias=bias)


def torch_attn(queries, keys, values, is_causal, attention_mask=None):
    # Need to call contiguous in torch >=2.1, otherwise later calls to .view() fail.
    # Possibly related: https://github.com/pytorch/pytorch/issues/110213 - behavior of scaled_dot_product_attention
@@ -196,15 +170,7 @@ def get_attn_func(
    alpha=None,
):
    if attn_name == "auto":
        return xformers_attn if torch.cuda.is_available() else torch_attn
    elif attn_name == "xformers_attn":
        return xformers_attn
    elif attn_name == "xformers_attn_variable_length":
        # Upon changing the input sequence length, xformers attention changes
        # the stride dimension of the output tensor. This makes future calls to
        # .view() that collapse the last two dimensions fail. One thus needs to
        # call .contiguous() on the output tensor. [#188]
        return lambda *args, **kwargs: xformers_attn(*args, **kwargs).contiguous()
        return torch_attn
    elif attn_name == "torch_attn":
        return torch_attn
    elif attn_name == "custom_attn":
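With xformers removed, torch_attn (shown above) becomes the only built-in attention path. For orientation, here is a minimal sketch of a torch-only attention call built on F.scaled_dot_product_attention, illustrating the general approach and the .contiguous() call mentioned in the torch_attn comment; the function name and details are assumptions, not the exact open_lm implementation.

import torch
from torch.nn import functional as F

def sdpa_attn_sketch(queries, keys, values, is_causal=True):
    # Inputs follow the [batch, seq_len, heads, head_dim] layout used above;
    # scaled_dot_product_attention expects [batch, heads, seq_len, head_dim].
    q, k, v = (t.transpose(1, 2) for t in (queries, keys, values))
    # is_causal here assumes queries and keys cover the same positions
    # (q_seq_len == k_seq_len); the rectangular case needs an explicit attn_mask.
    out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
    # .contiguous() avoids the stride issue noted above for torch >= 2.1,
    # so later .view() calls that collapse heads and head_dim succeed.
    return out.transpose(1, 2).contiguous()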
36 changes: 30 additions & 6 deletions open_lm/model.py
@@ -11,11 +11,9 @@
from torch import nn
from torch.utils.checkpoint import checkpoint

import xformers.ops as xops

from huggingface_hub import PyTorchModelHubMixin

from open_lm.attention import get_attn_func, xformers_attn, torch_attn
from open_lm.attention import get_attn_func, torch_attn
from open_lm.norms import get_norm_class
from open_lm.positional_embedding.head_rotary import HeadRotaryWithCast
from open_lm.positional_embedding.rotary import RotaryWithCast
@@ -91,7 +89,7 @@ class Params:
    post_embed_norm: bool = False
    weight_tying: bool = False
    norm_type: nn.Module = nn.LayerNorm
    attn_func: Callable = xformers_attn if torch.cuda.is_available() else torch_attn
    attn_func: Callable = torch_attn
    apply_qk_norm: bool = False
    moe_loss_weight: float = 0.1
    moe_capacity_factor: float = 1.25
@@ -267,7 +265,7 @@ def __init__(self, layer_id, args: Params):
        if args.ffn_type == "swiglu":
            # this follows llama / lit llama -- go to multiple of 256
            self.hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256)
            self.feed_forward = xops.SwiGLU(args.dim, self.hidden_dim, args.dim, bias=False)
            self.feed_forward = SwiGLUTorch(args.dim, self.hidden_dim, args.dim, bias=False)
        elif args.ffn_type == "swiglu_torch":
            # this follows llama / lit llama -- go to multiple of 256
            self.hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256)
@@ -496,6 +494,33 @@ def create_params(args):
        moe_top_k=cfg.get("moe_top_k", args.moe_top_k),
    )

if MambaLMHeadModel is not None:
    # This is a copy-paste of the Mamba SSM code with the addition of inputs_embeds
    class MixerModelOpenLM(MixerModel):
        def forward(self, input_ids=None, inputs_embeds=None, inference_params=None, **kwargs):
            assert input_ids is not None or inputs_embeds is not None
            hidden_states = self.embedding(input_ids) if inputs_embeds is None else inputs_embeds
            residual = None
            for layer in self.layers:
                hidden_states, residual = layer(
                    hidden_states, residual, inference_params=inference_params
                )
            if not self.fused_add_norm:
                residual = (hidden_states + residual) if residual is not None else hidden_states
                hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
            else:
                # Set prenorm=False here since we don't need the residual
                fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
                hidden_states = fused_add_norm_fn(
                    hidden_states,
                    self.norm_f.weight,
                    self.norm_f.bias,
                    eps=self.norm_f.eps,
                    residual=residual,
                    prenorm=False,
                    residual_in_fp32=self.residual_in_fp32,
                )
            return hidden_states

if MambaLMHeadModel is not None:
    # This is a copy-paste of the Mamba SSM code with the addition of inputs_embeds
@@ -552,7 +577,6 @@ def __init__(
                residual_in_fp32=residual_in_fp32,
                **factory_kwargs,
            )

        def forward(self, input_ids=None, inputs_embeds=None, inference_params=None, **kwargs):
            hidden_state = self.backbone(input_ids, inputs_embeds, inference_params)
            lm_logits = self.lm_head(hidden_state)
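In model.py the swiglu branch now builds SwiGLUTorch instead of xops.SwiGLU, keeping the llama-style rule that rounds the hidden size up to a multiple of 256. Below is a minimal sketch of a torch-only SwiGLU block with the same (in_dim, hidden_dim, out_dim, bias) call signature; the class and attribute names are illustrative assumptions rather than the actual SwiGLUTorch code.

import torch
from torch import nn
from torch.nn import functional as F

class SwiGLUSketch(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, bias=False):
        super().__init__()
        # One matmul produces both the gate and the value branch.
        self.w12 = nn.Linear(in_dim, 2 * hidden_dim, bias=bias)
        self.w3 = nn.Linear(hidden_dim, out_dim, bias=bias)

    def forward(self, x):
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)

# Hidden-size rule from the diff: roughly 8/3 * dim, rounded up to a multiple of 256.
dim = 2048
hidden_dim = 256 * ((int(2 * 4 * dim / 3) + 256 - 1) // 256)  # 5632 when dim == 2048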
9 changes: 9 additions & 0 deletions open_lm/model_configs/11m.json
@@ -0,0 +1,9 @@
{
"hidden_dim": 192,
"n_layers": 8,
"n_heads": 4,
"seq_len": 1024,
"vocab_size": 66816,
"post_embed_norm": false,
"weight_tying": false
}
9 changes: 9 additions & 0 deletions open_lm/model_configs/154m.json
@@ -0,0 +1,9 @@
{
"hidden_dim": 576,
"n_layers": 24,
"n_heads": 8,
"seq_len": 1024,
"vocab_size": 66816,
"post_embed_norm": false,
"weight_tying": false
}
9 changes: 9 additions & 0 deletions open_lm/model_configs/1b.json
@@ -0,0 +1,9 @@
{
"hidden_dim": 2048,
"n_layers": 24,
"n_heads": 16,
"seq_len": 1024,
"vocab_size": 66816,
"post_embed_norm": false,
"weight_tying": false
}
9 changes: 9 additions & 0 deletions open_lm/model_configs/411m.json
@@ -0,0 +1,9 @@
{
"hidden_dim": 1024,
"n_layers": 24,
"n_heads": 8,
"seq_len": 1024,
"vocab_size": 66816,
"post_embed_norm": false,
"weight_tying": false
}
9 changes: 9 additions & 0 deletions open_lm/model_configs/79m.json
@@ -0,0 +1,9 @@
{
"hidden_dim": 512,
"n_layers": 8,
"n_heads": 4,
"seq_len": 1024,
"vocab_size": 66816,
"post_embed_norm": false,
"weight_tying": false
}
9 changes: 9 additions & 0 deletions open_lm/model_configs/7b.json
@@ -0,0 +1,9 @@
{
"hidden_dim": 4096,
"n_layers": 32,
"n_heads": 32,
"seq_len": 1024,
"vocab_size": 66816,
"post_embed_norm": false,
"weight_tying": false
}
10 changes: 10 additions & 0 deletions open_lm/model_configs/open_lm_1b_swiglutorch.json
@@ -0,0 +1,10 @@
{
"hidden_dim": 2048,
"n_layers": 24,
"n_heads": 16,
"seq_len": 2048,
"vocab_size": 50432,
"post_embed_norm": false,
"weight_tying": false,
"ffn_type": "swiglu_torch"
}
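The new JSON files above are plain model configs. A small sketch of how one can be inspected is below; the mapping of these keys into the Params dataclass (for example hidden_dim onto the model width) happens in create_params and is not reproduced here.

import json

with open("open_lm/model_configs/open_lm_1b_swiglutorch.json") as f:
    cfg = json.load(f)

# Fields shared by all of the new configs, plus the explicit ffn_type in this one.
print(cfg["hidden_dim"], cfg["n_layers"], cfg["n_heads"], cfg["ffn_type"])  # 2048 24 16 swiglu_torch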
2 changes: 1 addition & 1 deletion open_lm/params.py
@@ -106,7 +106,7 @@ def add_model_args(parser):
"--attn-name",
type=str,
default="auto",
choices=["auto", "xformers_attn", "xformers_attn_variable_length", "torch_attn", "custom_attn"],
choices=["auto", "torch_attn", "custom_attn"],
help="type of attention to use",
)
parser.add_argument(
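With the xformers choices dropped from --attn-name, both "auto" and "torch_attn" resolve to the same function. A hedged sketch of that resolution is below, assuming get_attn_func's remaining keyword arguments keep their defaults.

from open_lm.attention import get_attn_func, torch_attn

# "auto" no longer depends on CUDA availability; per the diff it simply returns torch_attn.
assert get_attn_func("auto") is torch_attn
assert get_attn_func("torch_attn") is torch_attn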
1 change: 0 additions & 1 deletion open_lm/utils/transformers/hf_model.py
@@ -17,7 +17,6 @@ def is_attention_mask_right(attention_mask):
    sum_values = torch.sum(attention_mask, dim=1)
    # Check if the sum of the mask is equal to the first zero index (meaning that the rest of the sequence after the first 0 is also 0)
    is_valid_sequence = (sum_values % attention_mask.shape[1] == first_zero_index).all()

    return is_valid_sequence


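The hf_model.py change only drops a blank line, but the surrounding check is worth spelling out: an attention mask is accepted only when every row is a run of 1s followed by 0s. A self-contained sketch of that logic is below; the first_zero_index computation is an assumption, since it sits outside the context shown in this diff.

import torch

def is_right_padded(attention_mask):
    # Position of the first 0 in each row; argmax returns 0 for an all-ones row.
    first_zero_index = (attention_mask == 0).int().argmax(dim=1)
    sum_values = torch.sum(attention_mask, dim=1)
    # A row of the form 1...1 0...0 has sum equal to the index of its first zero;
    # the modulo also maps an all-ones row (sum == seq_len) onto index 0.
    return (sum_values % attention_mask.shape[1] == first_zero_index).all()

print(is_right_padded(torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])))  # tensor(True)
print(is_right_padded(torch.tensor([[1, 0, 1, 0, 0]])))                   # tensor(False)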
1 change: 0 additions & 1 deletion requirements.txt
@@ -1,5 +1,4 @@
torch
xformers>=0.0.22
tiktoken
wandb
webdataset
