From 32c596d679503fcbcb26687e55dcb049b7dcfe25 Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Wed, 10 Apr 2024 16:56:10 -0700 Subject: [PATCH 01/33] allow inputs_embeds in model input args --- open_lm/model.py | 1 - open_lm/utils/transformers/hf_model.py | 9 +++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/open_lm/model.py b/open_lm/model.py index 0c979c40..b1f0ba5c 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -395,7 +395,6 @@ def forward(self, input_ids=None, inputs_embeds=None, past_key_values=None, use_ x = inputs_embeds else: raise ValueError("Either input_ids or inputs_embeds must be provided.") - x = self.post_embed_norm(x) if past_key_values is None: diff --git a/open_lm/utils/transformers/hf_model.py b/open_lm/utils/transformers/hf_model.py index 83353a19..c66e77d1 100644 --- a/open_lm/utils/transformers/hf_model.py +++ b/open_lm/utils/transformers/hf_model.py @@ -106,6 +106,15 @@ def forward( attention_mask=attention_mask, ) loss = None + if labels is not None: + loss = nn.CrossEntropyLoss()(logits.view(-1, logits.size(-1)), labels.view(-1)) + + output = CausalLMOutputWithPast( + logits=logits, + past_key_values=past_key_values, + loss=loss + ) + loss = None if labels is not None: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() From 9e8abdf5f481680005b8f796bedc81730e8c3893 Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Thu, 11 Apr 2024 13:25:39 -0700 Subject: [PATCH 02/33] linted --- open_lm/utils/transformers/hf_model.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/open_lm/utils/transformers/hf_model.py b/open_lm/utils/transformers/hf_model.py index c66e77d1..ae8bd4dc 100644 --- a/open_lm/utils/transformers/hf_model.py +++ b/open_lm/utils/transformers/hf_model.py @@ -105,15 +105,7 @@ def forward( use_cache=use_cache, attention_mask=attention_mask, ) - loss = None - if labels is not None: - loss = nn.CrossEntropyLoss()(logits.view(-1, logits.size(-1)), labels.view(-1)) - output = CausalLMOutputWithPast( - logits=logits, - past_key_values=past_key_values, - loss=loss - ) loss = None if labels is not None: shift_logits = logits[..., :-1, :].contiguous() @@ -123,7 +115,11 @@ def forward( shift_labels = shift_labels.view(-1).to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) - output = CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values, loss=loss) + output = CausalLMOutputWithPast( + logits=logits, + past_key_values=past_key_values, + loss=loss + ) return output def prepare_inputs_for_generation( From 7aa33c32e68caf3977f72474f379062b4b37dee9 Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Mon, 15 Apr 2024 12:45:16 -0700 Subject: [PATCH 03/33] few fixes llm-foundry removed from requirements and added in requirements_test mosaicml added in requirements version bump try except around llm-foundry import --- open_lm/utils/llm_foundry_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_lm/utils/llm_foundry_wrapper.py b/open_lm/utils/llm_foundry_wrapper.py index 166d42a2..b356ce8d 100644 --- a/open_lm/utils/llm_foundry_wrapper.py +++ b/open_lm/utils/llm_foundry_wrapper.py @@ -50,5 +50,5 @@ def __init__(self, model, tokenizer): shift_labels=True, ) - def generate(self, input_ids=None, inputs_embeds=None, **kwargs): + def generate(self, input_ids=None, **kwargs): return super().generate(input_ids=input_ids, **kwargs) From 77394868bfe88feab2a179b74a580c63e87be5f6 Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: 
Wed, 17 Apr 2024 17:09:35 -0700 Subject: [PATCH 04/33] update Mamba compatibility --- open_lm/model.py | 38 ++++++++++++++++++++------ open_lm/model_configs/mamba_7b.json | 9 ++++-- open_lm/utils/transformers/hf_model.py | 7 +---- 3 files changed, 37 insertions(+), 17 deletions(-) diff --git a/open_lm/model.py b/open_lm/model.py index b1f0ba5c..9b6b429a 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -99,6 +99,19 @@ class Params: positional_embedding_type: str = "rotary" ffn_type: str = "swiglu" +@dataclass +class MambaParams: + d_model: int = None + n_layer: int = None + vocab_size: int = None + seq_len: int = None + ssm_cfg: dict = None + rms_norm: bool = None + residual_in_fp32: bool = None + fused_add_norm: bool = None + pad_vocab_size_multiple: int = None + tie_embeddings: bool = None + weight_tying: bool = None def get_pos_embed(args: Params): head_dim = args.dim // args.n_heads @@ -439,12 +452,19 @@ def create_params(args): # If a parameter is not in the model config, we use the args parameter if "mamba" in args.model: - return { - "d_model": cfg["d_model"], - "n_layer": cfg["n_layer"], - "vocab_size": cfg["vocab_size"], - "seq_len": cfg["seq_len"], - } + return MambaParams( + d_model=cfg["d_model"], + n_layer=cfg["n_layer"], + vocab_size=cfg["vocab_size"], + seq_len=cfg["seq_len"], + ssm_cfg={}, + rms_norm=cfg["rms_norm"], + residual_in_fp32=cfg["residual_in_fp32"], + fused_add_norm=cfg["fused_add_norm"], + pad_vocab_size_multiple=cfg["pad_vocab_size_multiple"], + tie_embeddings=cfg.get("weight_tying", False), + weight_tying=cfg.get("weight_tying", False), + ) else: return Params( dim=cfg["hidden_dim"], @@ -481,10 +501,10 @@ def __init__(self, params): ) super().__init__() - self.seq_len = params.pop("seq_len") - self.vocab_size = params["vocab_size"] + self.vocab_size = params.vocab_size + self.seq_len = params.seq_len - self.model = MambaLMHeadModel(**params) + self.model = MambaLMHeadModel(params) def reset_parameters(self): return diff --git a/open_lm/model_configs/mamba_7b.json b/open_lm/model_configs/mamba_7b.json index 61d67c0f..e9494ec8 100644 --- a/open_lm/model_configs/mamba_7b.json +++ b/open_lm/model_configs/mamba_7b.json @@ -2,5 +2,10 @@ "d_model": 4096, "n_layer": 64, "vocab_size": 50432, - "seq_len": 2048 -} + "seq_len": 2048, + "ssm_cfg": {}, + "rms_norm": true, + "residual_in_fp32": true, + "fused_add_norm": true, + "pad_vocab_size_multiple": 8 +} \ No newline at end of file diff --git a/open_lm/utils/transformers/hf_model.py b/open_lm/utils/transformers/hf_model.py index ae8bd4dc..83353a19 100644 --- a/open_lm/utils/transformers/hf_model.py +++ b/open_lm/utils/transformers/hf_model.py @@ -105,7 +105,6 @@ def forward( use_cache=use_cache, attention_mask=attention_mask, ) - loss = None if labels is not None: shift_logits = logits[..., :-1, :].contiguous() @@ -115,11 +114,7 @@ def forward( shift_labels = shift_labels.view(-1).to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) - output = CausalLMOutputWithPast( - logits=logits, - past_key_values=past_key_values, - loss=loss - ) + output = CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values, loss=loss) return output def prepare_inputs_for_generation( From d4681816a12f4cd4b2a1acfce9ad32c9e7cb2436 Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Wed, 17 Apr 2024 17:29:47 -0700 Subject: [PATCH 05/33] MambaParams default values --- open_lm/model.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/open_lm/model.py b/open_lm/model.py index 
9b6b429a..b937aa0b 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -3,7 +3,7 @@ import re from copy import deepcopy from pathlib import Path -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Callable import torch @@ -101,17 +101,17 @@ class Params: @dataclass class MambaParams: - d_model: int = None - n_layer: int = None - vocab_size: int = None - seq_len: int = None - ssm_cfg: dict = None - rms_norm: bool = None - residual_in_fp32: bool = None - fused_add_norm: bool = None - pad_vocab_size_multiple: int = None - tie_embeddings: bool = None - weight_tying: bool = None + d_model: int = 2560 + n_layer: int = 64 + vocab_size: int = 50277 + seq_len: int = 2048 + ssm_cfg: dict = field(default_factory=dict) + rms_norm: bool = True + residual_in_fp32: bool = True + fused_add_norm: bool = True + pad_vocab_size_multiple: int = 8 + tie_embeddings: bool = True + weight_tying: bool = False def get_pos_embed(args: Params): head_dim = args.dim // args.n_heads From b14ef5093a6ec9e9addc4762062d624e33ed36f1 Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Wed, 17 Apr 2024 18:41:29 -0700 Subject: [PATCH 06/33] add mamba dataclass args --- open_lm/model.py | 2 ++ open_lm/utils/transformers/hf_config.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/open_lm/model.py b/open_lm/model.py index b937aa0b..74e88148 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -99,6 +99,7 @@ class Params: positional_embedding_type: str = "rotary" ffn_type: str = "swiglu" + @dataclass class MambaParams: d_model: int = 2560 @@ -113,6 +114,7 @@ class MambaParams: tie_embeddings: bool = True weight_tying: bool = False + def get_pos_embed(args: Params): head_dim = args.dim // args.n_heads if args.positional_embedding_type == "rotary": diff --git a/open_lm/utils/transformers/hf_config.py b/open_lm/utils/transformers/hf_config.py index edf3839f..fa04515a 100644 --- a/open_lm/utils/transformers/hf_config.py +++ b/open_lm/utils/transformers/hf_config.py @@ -40,5 +40,5 @@ def __init__( def set_params(self, params: Params): self.tie_word_embeddings = params.weight_tying - for field in fields(Params): + for field in fields(params): setattr(self, field.name, getattr(params, field.name)) From 35e2488d61cf5ae32c8f5903ed5df4c604f44961 Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Wed, 17 Apr 2024 18:42:23 -0700 Subject: [PATCH 07/33] small fix in hf_config --- open_lm/utils/transformers/hf_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_lm/utils/transformers/hf_config.py b/open_lm/utils/transformers/hf_config.py index fa04515a..edf3839f 100644 --- a/open_lm/utils/transformers/hf_config.py +++ b/open_lm/utils/transformers/hf_config.py @@ -40,5 +40,5 @@ def __init__( def set_params(self, params: Params): self.tie_word_embeddings = params.weight_tying - for field in fields(params): + for field in fields(Params): setattr(self, field.name, getattr(params, field.name)) From 762f0a91611bc7199758743dd442bedffd0c52bc Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Wed, 17 Apr 2024 18:56:49 -0700 Subject: [PATCH 08/33] set the correct kind of model in OpenLMModel --- open_lm/utils/transformers/hf_config.py | 2 +- open_lm/utils/transformers/hf_model.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/open_lm/utils/transformers/hf_config.py b/open_lm/utils/transformers/hf_config.py index edf3839f..4e49c10e 100644 --- a/open_lm/utils/transformers/hf_config.py +++ b/open_lm/utils/transformers/hf_config.py @@ -40,5 
+40,5 @@ def __init__( def set_params(self, params: Params): self.tie_word_embeddings = params.weight_tying - for field in fields(Params): + for field in fields(params.__class__): setattr(self, field.name, getattr(params, field.name)) diff --git a/open_lm/utils/transformers/hf_model.py b/open_lm/utils/transformers/hf_model.py index 83353a19..f8d4688e 100644 --- a/open_lm/utils/transformers/hf_model.py +++ b/open_lm/utils/transformers/hf_model.py @@ -3,7 +3,7 @@ from transformers import PreTrainedModel from transformers.modeling_outputs import CausalLMOutputWithPast from open_lm.utils.transformers.hf_config import OpenLMConfig -from open_lm.model import Transformer, create_params +from open_lm.model import Transformer, create_params, Params, MambaParams, Mamba import torch import torch.nn as nn from typing import Union, Tuple, Optional, List @@ -23,7 +23,10 @@ def __init__(self, config): super().__init__(config) self.supports_gradient_checkpointing = True - self.model = Transformer(params) + if isinstance(params, Params): + self.model = Transformer(params) + elif isinstance(params, MambaParams): + self.model = Mamba(params) @property def gradient_checkpointing(self): From 7fb747f19d170ae84a1916fd15dcaef296cfbbfe Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Mon, 22 Apr 2024 18:15:13 -0700 Subject: [PATCH 09/33] get_input_embeddings for mamba --- .dockerignore | 15 +++++++++++++++ open_lm/utils/transformers/hf_model.py | 10 ++++++++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/.dockerignore b/.dockerignore index 52288e50..3f6e7733 100644 --- a/.dockerignore +++ b/.dockerignore @@ -2,3 +2,18 @@ venv wandb logs checkpoints +data +results +preproc_data +logs +wandb +*.pt +.pytest_cache +.vscode +.git +tmp +attention_logs +not_val_shard +attention_logs +training/eval_data + diff --git a/open_lm/utils/transformers/hf_model.py b/open_lm/utils/transformers/hf_model.py index f8d4688e..94a49790 100644 --- a/open_lm/utils/transformers/hf_model.py +++ b/open_lm/utils/transformers/hf_model.py @@ -50,10 +50,16 @@ def __init__(self, config): self.post_init() def get_input_embeddings(self): - return self.model.tok_embeddings + if isinstance(self.model, Mamba): + return self.model.backbone.embedding + else: + return self.model.tok_embeddings def set_input_embeddings(self, value): - self.model.tok_embeddings = value + if isinstance(self.model, Mamba): + self.model.backbone.embedding = value + else: + self.model.tok_embeddings = value def get_output_embeddings(self): return self.model.get_output_embeddings() From aebb9d5b2b47d046d6a67c9f168e46e116dafcef Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Tue, 23 Apr 2024 14:26:04 -0700 Subject: [PATCH 10/33] [wip] add inputs_embeddings in mamba class --- open_lm/model.py | 64 +++++++++++++++++++++++++- open_lm/utils/transformers/hf_model.py | 2 +- 2 files changed, 63 insertions(+), 3 deletions(-) diff --git a/open_lm/model.py b/open_lm/model.py index 74e88148..1832d782 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -31,7 +31,7 @@ MoEArgs = None try: # optional import - from mamba_ssm import MambaLMHeadModel + from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MixerModel except ImportError: MambaLMHeadModel = None @@ -492,6 +492,65 @@ def create_params(args): moe_top_k=cfg.get("moe_top_k", args.moe_top_k), ) +# This is a copy-paste of the Mamba SSM code with the addition of inputs_embeds +class MixerModelOpenLM(MixerModel): + def forward(self, input_ids, inputs_embeds, inference_params=None): + hidden_states = 
self.embedding(input_ids) if inputs_embeds is None else inputs_embeds + residual = None + for layer in self.layers: + hidden_states, residual = layer( + hidden_states, residual, inference_params=inference_params + ) + if not self.fused_add_norm: + residual = (hidden_states + residual) if residual is not None else hidden_states + hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) + else: + # Set prenorm=False here since we don't need the residual + fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn + hidden_states = fused_add_norm_fn( + hidden_states, + self.norm_f.weight, + self.norm_f.bias, + eps=self.norm_f.eps, + residual=residual, + prenorm=False, + residual_in_fp32=self.residual_in_fp32, + ) + return hidden_states + + +# This is a copy-paste of the Mamba SSM code with the usage of MixerModelOpenLM instead of MixerModel + +class MambaLMHeadModelOpenLM(MambaLMHeadModel): + def __init__( + self, + config, + initializer_cfg=None, + device=None, + dtype=None, + ) -> None: + super().__init__(config, initializer_cfg, device, dtype) + d_model = config.d_model + n_layer = config.n_layer + vocab_size = config.vocab_size + ssm_cfg = config.ssm_cfg + rms_norm = config.rms_norm + residual_in_fp32 = config.residual_in_fp32 + fused_add_norm = config.fused_add_norm + factory_kwargs = {"device": device, "dtype": dtype} + self.backbone = MixerModelOpenLM( + d_model=d_model, + n_layer=n_layer, + vocab_size=vocab_size, + ssm_cfg=ssm_cfg, + rms_norm=rms_norm, + initializer_cfg=initializer_cfg, + fused_add_norm=fused_add_norm, + residual_in_fp32=residual_in_fp32, + **factory_kwargs, + ) + + class Mamba(nn.Module): # Experimental architecture, please "pip install mamba-ssm" @@ -505,8 +564,9 @@ def __init__(self, params): super().__init__() self.vocab_size = params.vocab_size self.seq_len = params.seq_len + self.dim = params.d_model - self.model = MambaLMHeadModel(params) + self.model = MambaLMHeadModelOpenLM(params) def reset_parameters(self): return diff --git a/open_lm/utils/transformers/hf_model.py b/open_lm/utils/transformers/hf_model.py index 94a49790..bb1bfae6 100644 --- a/open_lm/utils/transformers/hf_model.py +++ b/open_lm/utils/transformers/hf_model.py @@ -57,7 +57,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): if isinstance(self.model, Mamba): - self.model.backbone.embedding = value + self.model.model.backbone.embedding = value else: self.model.tok_embeddings = value From 3ca3986a3846c4a6426d36d15ca78473498e8b0f Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Tue, 23 Apr 2024 14:33:51 -0700 Subject: [PATCH 11/33] [wip] add inputs_embeddings in mamba class --- open_lm/utils/transformers/hf_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_lm/utils/transformers/hf_model.py b/open_lm/utils/transformers/hf_model.py index bb1bfae6..032f0e1b 100644 --- a/open_lm/utils/transformers/hf_model.py +++ b/open_lm/utils/transformers/hf_model.py @@ -51,7 +51,7 @@ def __init__(self, config): def get_input_embeddings(self): if isinstance(self.model, Mamba): - return self.model.backbone.embedding + return self.model.model.backbone.embedding else: return self.model.tok_embeddings From c21129e11f9f640c032dc91d02ce5e1110535e89 Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Tue, 23 Apr 2024 14:46:24 -0700 Subject: [PATCH 12/33] [wip] add inputs_embeddings in mamba class --- open_lm/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/open_lm/model.py b/open_lm/model.py 
index 1832d782..38a4a696 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -571,9 +571,9 @@ def __init__(self, params): def reset_parameters(self): return - def forward(self, x): - out = self.model(x).logits - return out, None, None + def forward(self, input_ids, inputs_embeds=None, inference_params=None): + out = self.model(input_ids, inputs_embeds, inference_params) + return out.logits, None, None def create_model(args): From f5ae4d450a9913362cd44d92fd762a84753a7bb8 Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Tue, 23 Apr 2024 15:17:57 -0700 Subject: [PATCH 13/33] [wip] Mamba kwargs --- open_lm/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_lm/model.py b/open_lm/model.py index 38a4a696..acf3378f 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -571,7 +571,7 @@ def __init__(self, params): def reset_parameters(self): return - def forward(self, input_ids, inputs_embeds=None, inference_params=None): + def forward(self, input_ids, inputs_embeds=None, inference_params=None, **kwargs): out = self.model(input_ids, inputs_embeds, inference_params) return out.logits, None, None From f89e36c0e28de9ebcd1fd7e52836c0cdbf3cecbd Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Tue, 23 Apr 2024 15:31:14 -0700 Subject: [PATCH 14/33] [wip] Mamba optional inputs --- open_lm/model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/open_lm/model.py b/open_lm/model.py index acf3378f..bab27620 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -492,9 +492,11 @@ def create_params(args): moe_top_k=cfg.get("moe_top_k", args.moe_top_k), ) + # This is a copy-paste of the Mamba SSM code with the addition of inputs_embeds class MixerModelOpenLM(MixerModel): - def forward(self, input_ids, inputs_embeds, inference_params=None): + def forward(self, input_ids=None, inputs_embeds=None, inference_params=None): + assert input_ids is not None or inputs_embeds is not None hidden_states = self.embedding(input_ids) if inputs_embeds is None else inputs_embeds residual = None for layer in self.layers: From a8fda98cf41b2276e9b5920d493aebfb24e0a188 Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Tue, 23 Apr 2024 15:57:27 -0700 Subject: [PATCH 15/33] [wip] Mamba inputs propagated --- open_lm/model.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/open_lm/model.py b/open_lm/model.py index bab27620..04a88afa 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -32,6 +32,7 @@ try: # optional import from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MixerModel + from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn except ImportError: MambaLMHeadModel = None @@ -495,7 +496,7 @@ def create_params(args): # This is a copy-paste of the Mamba SSM code with the addition of inputs_embeds class MixerModelOpenLM(MixerModel): - def forward(self, input_ids=None, inputs_embeds=None, inference_params=None): + def forward(self, input_ids=None, inputs_embeds=None, inference_params=None, **kwargs): assert input_ids is not None or inputs_embeds is not None hidden_states = self.embedding(input_ids) if inputs_embeds is None else inputs_embeds residual = None @@ -522,7 +523,6 @@ def forward(self, input_ids=None, inputs_embeds=None, inference_params=None): # This is a copy-paste of the Mamba SSM code with the usage of MixerModelOpenLM instead of MixerModel - class MambaLMHeadModelOpenLM(MambaLMHeadModel): def __init__( self, @@ -551,7 +551,8 @@ def __init__( residual_in_fp32=residual_in_fp32, **factory_kwargs, ) - 
+ def forward(self, input_ids=None, inputs_embeds=None, inference_params=None, **kwargs): + return self.backbone(input_ids, inputs_embeds, inference_params) class Mamba(nn.Module): @@ -574,7 +575,7 @@ def reset_parameters(self): return def forward(self, input_ids, inputs_embeds=None, inference_params=None, **kwargs): - out = self.model(input_ids, inputs_embeds, inference_params) + out = self.model(input_ids, inputs_embeds, inference_params, **kwargs) return out.logits, None, None From 8a88e0df469f86da10012a29c337041062a70768 Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Tue, 23 Apr 2024 16:07:51 -0700 Subject: [PATCH 16/33] [wip] Mamba output fix --- open_lm/model.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/open_lm/model.py b/open_lm/model.py index 04a88afa..1a2ed9da 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -552,7 +552,9 @@ def __init__( **factory_kwargs, ) def forward(self, input_ids=None, inputs_embeds=None, inference_params=None, **kwargs): - return self.backbone(input_ids, inputs_embeds, inference_params) + hidden_state = self.backbone(input_ids, inputs_embeds, inference_params) + lm_logits = self.lm_head(hidden_state) + return lm_logits, hidden_state, inference_params class Mamba(nn.Module): @@ -575,8 +577,8 @@ def reset_parameters(self): return def forward(self, input_ids, inputs_embeds=None, inference_params=None, **kwargs): - out = self.model(input_ids, inputs_embeds, inference_params, **kwargs) - return out.logits, None, None + logits, hidden_state, inference_params = self.model(input_ids, inputs_embeds, inference_params, **kwargs) + return logits, hidden_state, inference_params def create_model(args): From 813cf7258e3511629a5429f023dfbb8c0076093a Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Wed, 24 Apr 2024 16:18:44 -0700 Subject: [PATCH 17/33] Filter keys keywords in load model (remove some layers from the state dict based on names) --- open_lm/main.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/open_lm/main.py b/open_lm/main.py index 7c80f558..d7247e19 100644 --- a/open_lm/main.py +++ b/open_lm/main.py @@ -102,7 +102,7 @@ def get_state_dict(name): return sd -def load_model(args, model, different_seed=False): +def load_model(args, model, different_seed=False, filter_keys=None): checkpoint = pt_load(args.resume, map_location="cpu") if "epoch" in checkpoint: if not different_seed and "shard_shuffle_seed" in checkpoint: @@ -126,6 +126,8 @@ def load_model(args, model, different_seed=False): sd = {k[len("module.") :]: v for k, v in sd.items()} if "_orig_mod" in next(iter(sd.items()))[0]: sd = {k.replace("_orig_mod.", ""): v for k, v in sd.items()} + if filter_keys is not None: + sd = {k: v for k, v in sd.items() if not any([x in k for x in filter_keys])} if args.fsdp: model.load_state_dict(sd) elif args.distributed: From 81b6aac405da2619fbac913e2f5ee9cfa784c81b Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Wed, 24 Apr 2024 16:54:43 -0700 Subject: [PATCH 18/33] Mamba import optional and its dependencies too --- open_lm/model.py | 170 +++++++++++++++++++++++------------------------ 1 file changed, 85 insertions(+), 85 deletions(-) diff --git a/open_lm/model.py b/open_lm/model.py index 1a2ed9da..47db88c1 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -493,92 +493,92 @@ def create_params(args): moe_top_k=cfg.get("moe_top_k", args.moe_top_k), ) - -# This is a copy-paste of the Mamba SSM code with the addition of inputs_embeds -class MixerModelOpenLM(MixerModel): - def forward(self, input_ids=None, 
inputs_embeds=None, inference_params=None, **kwargs): - assert input_ids is not None or inputs_embeds is not None - hidden_states = self.embedding(input_ids) if inputs_embeds is None else inputs_embeds - residual = None - for layer in self.layers: - hidden_states, residual = layer( - hidden_states, residual, inference_params=inference_params - ) - if not self.fused_add_norm: - residual = (hidden_states + residual) if residual is not None else hidden_states - hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) - else: - # Set prenorm=False here since we don't need the residual - fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn - hidden_states = fused_add_norm_fn( - hidden_states, - self.norm_f.weight, - self.norm_f.bias, - eps=self.norm_f.eps, - residual=residual, - prenorm=False, - residual_in_fp32=self.residual_in_fp32, - ) - return hidden_states - - -# This is a copy-paste of the Mamba SSM code with the usage of MixerModelOpenLM instead of MixerModel -class MambaLMHeadModelOpenLM(MambaLMHeadModel): - def __init__( - self, - config, - initializer_cfg=None, - device=None, - dtype=None, - ) -> None: - super().__init__(config, initializer_cfg, device, dtype) - d_model = config.d_model - n_layer = config.n_layer - vocab_size = config.vocab_size - ssm_cfg = config.ssm_cfg - rms_norm = config.rms_norm - residual_in_fp32 = config.residual_in_fp32 - fused_add_norm = config.fused_add_norm - factory_kwargs = {"device": device, "dtype": dtype} - self.backbone = MixerModelOpenLM( - d_model=d_model, - n_layer=n_layer, - vocab_size=vocab_size, - ssm_cfg=ssm_cfg, - rms_norm=rms_norm, - initializer_cfg=initializer_cfg, - fused_add_norm=fused_add_norm, - residual_in_fp32=residual_in_fp32, - **factory_kwargs, - ) - def forward(self, input_ids=None, inputs_embeds=None, inference_params=None, **kwargs): - hidden_state = self.backbone(input_ids, inputs_embeds, inference_params) - lm_logits = self.lm_head(hidden_state) - return lm_logits, hidden_state, inference_params - - -class Mamba(nn.Module): - # Experimental architecture, please "pip install mamba-ssm" - # https://arxiv.org/abs/2312.00752 - def __init__(self, params): - if MambaLMHeadModel is None: - raise ImportError( - "MambaLMHeadModel is not available. Please install the 'mamba_ssm' package by running 'pip install mamba-ssm'." 
+if MambaLMHeadModel is not None: + # This is a copy-paste of the Mamba SSM code with the addition of inputs_embeds + class MixerModelOpenLM(MixerModel): + def forward(self, input_ids=None, inputs_embeds=None, inference_params=None, **kwargs): + assert input_ids is not None or inputs_embeds is not None + hidden_states = self.embedding(input_ids) if inputs_embeds is None else inputs_embeds + residual = None + for layer in self.layers: + hidden_states, residual = layer( + hidden_states, residual, inference_params=inference_params + ) + if not self.fused_add_norm: + residual = (hidden_states + residual) if residual is not None else hidden_states + hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) + else: + # Set prenorm=False here since we don't need the residual + fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn + hidden_states = fused_add_norm_fn( + hidden_states, + self.norm_f.weight, + self.norm_f.bias, + eps=self.norm_f.eps, + residual=residual, + prenorm=False, + residual_in_fp32=self.residual_in_fp32, + ) + return hidden_states + + + # This is a copy-paste of the Mamba SSM code with the usage of MixerModelOpenLM instead of MixerModel + class MambaLMHeadModelOpenLM(MambaLMHeadModel): + def __init__( + self, + config, + initializer_cfg=None, + device=None, + dtype=None, + ) -> None: + super().__init__(config, initializer_cfg, device, dtype) + d_model = config.d_model + n_layer = config.n_layer + vocab_size = config.vocab_size + ssm_cfg = config.ssm_cfg + rms_norm = config.rms_norm + residual_in_fp32 = config.residual_in_fp32 + fused_add_norm = config.fused_add_norm + factory_kwargs = {"device": device, "dtype": dtype} + self.backbone = MixerModelOpenLM( + d_model=d_model, + n_layer=n_layer, + vocab_size=vocab_size, + ssm_cfg=ssm_cfg, + rms_norm=rms_norm, + initializer_cfg=initializer_cfg, + fused_add_norm=fused_add_norm, + residual_in_fp32=residual_in_fp32, + **factory_kwargs, ) - - super().__init__() - self.vocab_size = params.vocab_size - self.seq_len = params.seq_len - self.dim = params.d_model - - self.model = MambaLMHeadModelOpenLM(params) - - def reset_parameters(self): - return - - def forward(self, input_ids, inputs_embeds=None, inference_params=None, **kwargs): - logits, hidden_state, inference_params = self.model(input_ids, inputs_embeds, inference_params, **kwargs) - return logits, hidden_state, inference_params + def forward(self, input_ids=None, inputs_embeds=None, inference_params=None, **kwargs): + hidden_state = self.backbone(input_ids, inputs_embeds, inference_params) + lm_logits = self.lm_head(hidden_state) + return lm_logits, hidden_state, inference_params + + + class Mamba(nn.Module): + # Experimental architecture, please "pip install mamba-ssm" + # https://arxiv.org/abs/2312.00752 + def __init__(self, params): + if MambaLMHeadModel is None: + raise ImportError( + "MambaLMHeadModel is not available. Please install the 'mamba_ssm' package by running 'pip install mamba-ssm'." 
+ ) + + super().__init__() + self.vocab_size = params.vocab_size + self.seq_len = params.seq_len + self.dim = params.d_model + + self.model = MambaLMHeadModelOpenLM(params) + + def reset_parameters(self): + return + + def forward(self, input_ids, inputs_embeds=None, inference_params=None, **kwargs): + logits, hidden_state, inference_params = self.model(input_ids, inputs_embeds, inference_params, **kwargs) + return logits, hidden_state, inference_params def create_model(args): From 8b5d641c7f20112c858a85b11e67e5f5e743a51d Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Wed, 24 Apr 2024 17:02:33 -0700 Subject: [PATCH 19/33] Mamba import optional and its dependencies too --- open_lm/model.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/open_lm/model.py b/open_lm/model.py index 47db88c1..52440e26 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -579,6 +579,22 @@ def reset_parameters(self): def forward(self, input_ids, inputs_embeds=None, inference_params=None, **kwargs): logits, hidden_state, inference_params = self.model(input_ids, inputs_embeds, inference_params, **kwargs) return logits, hidden_state, inference_params +else: + class Mamba(nn.Module): + # Experimental architecture, please "pip install mamba-ssm" + # https://arxiv.org/abs/2312.00752 + def __init__(self, params): + raise ImportError( + "MambaLMHeadModel is not available. Please install the 'mamba_ssm' package by running 'pip install mamba-ssm'." + ) + + def reset_parameters(self): + return + + def forward(self, input_ids, inputs_embeds=None, inference_params=None, **kwargs): + raise ImportError( + "MambaLMHeadModel is not available. Please install the 'mamba_ssm' package by running 'pip install mamba-ssm'." + ) def create_model(args): From 9525c5c87400778d1871b96ce7c7cbfcc8c72a8d Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Wed, 24 Apr 2024 18:37:59 -0700 Subject: [PATCH 20/33] check for state_dict in checkpoint when no epoch --- open_lm/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/open_lm/main.py b/open_lm/main.py index d7247e19..8dbeb3d0 100644 --- a/open_lm/main.py +++ b/open_lm/main.py @@ -139,6 +139,8 @@ def load_model(args, model, different_seed=False, filter_keys=None): # loading a bare (model only) checkpoint for fine-tune or evaluation start_epoch, global_step = 0, 0 pretrained_seed = None + if "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] model.load_state_dict(checkpoint) logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})") return start_epoch, global_step, pretrained_seed From 7c62e96dbcc14e667ac92a3aa46e750ca56b7013 Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Thu, 25 Apr 2024 11:40:23 -0700 Subject: [PATCH 21/33] rotary inv_freq not a buffer anymore --- open_lm/positional_embedding/rotary.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/open_lm/positional_embedding/rotary.py b/open_lm/positional_embedding/rotary.py index b48ed890..4837d2db 100644 --- a/open_lm/positional_embedding/rotary.py +++ b/open_lm/positional_embedding/rotary.py @@ -48,7 +48,7 @@ def __init__(self, dim_model: int, seq_len: int, *_, **__): super().__init__() # Generate and save the inverse frequency buffer (non trainable) self.dim_model = dim_model - self.register_buffer("inv_freq", torch.zeros(self.dim_model // 2)) + self.inv_freq = torch.zeros(self.dim_model // 2) self._cos_cached = None self._sin_cached = None @@ -71,7 +71,7 @@ def _update_cos_sin_tables(self, seq_len: int = None, device: torch.device = Non if seq_len > 
self._seq_len_cached or self._cos_cached.device != device or self._cos_cached.dtype != dtype: self._seq_len_cached = seq_len t = torch.arange(seq_len, device=device, dtype=torch.float32) - freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(dtype)) + freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(dtype).to(device)) emb = torch.cat((freqs, freqs), dim=-1).to(device) self._cos_cached = emb.cos()[None, :, None, :].to(dtype) From bc64b0ffcd30cd7ccbb58b6d00f1a35139dd2f4e Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Thu, 25 Apr 2024 11:54:15 -0700 Subject: [PATCH 22/33] rotary inv_freq not a buffer anymore -> remove it from state dicts when loading --- open_lm/main.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/open_lm/main.py b/open_lm/main.py index 8dbeb3d0..16d73218 100644 --- a/open_lm/main.py +++ b/open_lm/main.py @@ -121,6 +121,8 @@ def load_model(args, model, different_seed=False, filter_keys=None): # resuming a train checkpoint w/ epoch and optimizer state start_epoch = checkpoint["epoch"] sd = checkpoint["state_dict"] + # remove inv_freq from the state dict if it exists + sd = {k: v for k, v in sd.items() if "inv_freq" not in k} global_step = checkpoint.get("step", None) if next(iter(sd.items()))[0].startswith("module"): sd = {k[len("module.") :]: v for k, v in sd.items()} @@ -141,6 +143,7 @@ def load_model(args, model, different_seed=False, filter_keys=None): pretrained_seed = None if "state_dict" in checkpoint: checkpoint = checkpoint["state_dict"] + checkpoint = {k: v for k, v in checkpoint.items() if "inv_freq" not in k} model.load_state_dict(checkpoint) logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})") return start_epoch, global_step, pretrained_seed From d2d08b68c9064fd8d942ec0784a10f43e0349158 Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Thu, 25 Apr 2024 13:46:58 -0700 Subject: [PATCH 23/33] loosen panda dep --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 898d184f..b333a8ad 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ xformers>=0.0.22 tiktoken wandb webdataset -pandas==2.1.4 +pandas fsspec tqdm jsonlines From 3e0ce49d9e2253ca014f4eaea36652c73920df55 Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Mon, 29 Apr 2024 17:38:33 -0700 Subject: [PATCH 24/33] avoid masking if the attention_mask is masking only the right padding --- .gitignore | 2 ++ open_lm/utils/transformers/hf_model.py | 29 ++++++++++++++++++++++++-- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 24f90e56..d4cc71ab 100644 --- a/.gitignore +++ b/.gitignore @@ -143,3 +143,5 @@ tests/assets/source_*/* secrets.env checkpoints/ experiments/ +external/ +preproc_data/ diff --git a/open_lm/utils/transformers/hf_model.py b/open_lm/utils/transformers/hf_model.py index 032f0e1b..ae4990b4 100644 --- a/open_lm/utils/transformers/hf_model.py +++ b/open_lm/utils/transformers/hf_model.py @@ -10,6 +10,17 @@ import os +def is_attention_mask_right(attention_mask): + # Get the first zero index in each sequence + first_zero_index = torch.min(attention_mask, dim=1).indices + # Sum each sequence mask + sum_values = torch.sum(attention_mask, dim=1) + # Check if the sum of the mask is equal to the first zero index (meaning that the rest of the sequence after the first 0 is also 0) + is_valid_sequence = (sum_values % attention_mask.shape[1] == first_zero_index).all() + + return is_valid_sequence + + class OpenLMModel(PreTrainedModel): config_class = 
OpenLMConfig @@ -107,6 +118,14 @@ def forward( ```""" assert position_ids is None, "Position IDs are not supported" # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + + if is_attention_mask_right(attention_mask): + # The masking can be done on the loss only + loss_mask = attention_mask + attention_mask = None + else: + loss_mask = None + logits, _, past_key_values = self.model( input_ids=input_ids, inputs_embeds=inputs_embeds, @@ -118,10 +137,16 @@ def forward( if labels is not None: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() - loss_fct = nn.CrossEntropyLoss() shift_logits = shift_logits.view(-1, shift_logits.size(-1)) shift_labels = shift_labels.view(-1).to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) + if loss_mask is not None: + shift_mask = loss_mask[..., 1:].contiguous() + loss_fct = nn.CrossEntropyLoss(reduction="none") + loss = loss_fct(shift_logits, shift_labels) + loss = loss[shift_mask.view(-1)].sum()/shift_mask.sum() + else: + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(shift_logits, shift_labels) output = CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values, loss=loss) return output From 741d71726d040425369c84edb6873415353553fb Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Tue, 30 Apr 2024 16:07:59 -0700 Subject: [PATCH 25/33] handle none mask in hf_model --- open_lm/utils/transformers/hf_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_lm/utils/transformers/hf_model.py b/open_lm/utils/transformers/hf_model.py index ae4990b4..5e1c6b7f 100644 --- a/open_lm/utils/transformers/hf_model.py +++ b/open_lm/utils/transformers/hf_model.py @@ -119,7 +119,7 @@ def forward( assert position_ids is None, "Position IDs are not supported" # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - if is_attention_mask_right(attention_mask): + if attention_mask is not None and is_attention_mask_right(attention_mask): # The masking can be done on the loss only loss_mask = attention_mask attention_mask = None From 5186928fac626a2208b36e08ebdcbcc7700fcca1 Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Wed, 1 May 2024 12:53:42 -0700 Subject: [PATCH 26/33] Filter keys keywords in load model (remove some layers from the state dict based on names) --- open_lm/utils/transformers/hf_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_lm/utils/transformers/hf_model.py b/open_lm/utils/transformers/hf_model.py index 5e1c6b7f..d77284dc 100644 --- a/open_lm/utils/transformers/hf_model.py +++ b/open_lm/utils/transformers/hf_model.py @@ -143,7 +143,7 @@ def forward( shift_mask = loss_mask[..., 1:].contiguous() loss_fct = nn.CrossEntropyLoss(reduction="none") loss = loss_fct(shift_logits, shift_labels) - loss = loss[shift_mask.view(-1)].sum()/shift_mask.sum() + loss = loss[shift_mask.view(-1)].sum() / shift_mask.sum() else: loss_fct = nn.CrossEntropyLoss() loss = loss_fct(shift_logits, shift_labels) From 9bac9844c705ffc122fe00be5eac5bddca87da3a Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Wed, 1 May 2024 13:08:23 -0700 Subject: [PATCH 27/33] Fix hf_model loss --- open_lm/utils/transformers/hf_model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/open_lm/utils/transformers/hf_model.py b/open_lm/utils/transformers/hf_model.py index d77284dc..b7891136 100644 --- a/open_lm/utils/transformers/hf_model.py +++ b/open_lm/utils/transformers/hf_model.py @@ -140,10 +140,11 @@ def 
forward( shift_logits = shift_logits.view(-1, shift_logits.size(-1)) shift_labels = shift_labels.view(-1).to(shift_logits.device) if loss_mask is not None: - shift_mask = loss_mask[..., 1:].contiguous() + shift_mask = loss_mask[..., :-1].contiguous() loss_fct = nn.CrossEntropyLoss(reduction="none") loss = loss_fct(shift_logits, shift_labels) - loss = loss[shift_mask.view(-1)].sum() / shift_mask.sum() + shift_mask = torch.logical_and(shift_mask.view(-1), shift_labels != -100) + loss = loss[shift_mask.view(-1)].sum()/shift_mask.sum() else: loss_fct = nn.CrossEntropyLoss() loss = loss_fct(shift_logits, shift_labels) From b5b1cfbf506eb4cced7865448964de053672686e Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Wed, 1 May 2024 13:33:07 -0700 Subject: [PATCH 28/33] Prevent wrong config to fail --- open_lm/model.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/open_lm/model.py b/open_lm/model.py index 52440e26..8773eda0 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -67,8 +67,11 @@ def _rescan_model_configs(model_config_paths=None): for cf in config_files: with open(cf, "r") as f: - model_cfg = json.load(f) - _MODEL_CONFIGS[cf.stem] = model_cfg + try: + model_cfg = json.load(f) + _MODEL_CONFIGS[cf.stem] = model_cfg + except json.JSONDecodeError: + print(f"Error loading model config {cf}") _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))} From f301fb6a6c33b545d071426a3762ddda3aa1107f Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Sat, 4 May 2024 13:33:16 -0700 Subject: [PATCH 29/33] fix mask shift --- open_lm/utils/transformers/hf_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_lm/utils/transformers/hf_model.py b/open_lm/utils/transformers/hf_model.py index b7891136..d50936ec 100644 --- a/open_lm/utils/transformers/hf_model.py +++ b/open_lm/utils/transformers/hf_model.py @@ -140,7 +140,7 @@ def forward( shift_logits = shift_logits.view(-1, shift_logits.size(-1)) shift_labels = shift_labels.view(-1).to(shift_logits.device) if loss_mask is not None: - shift_mask = loss_mask[..., :-1].contiguous() + shift_mask = loss_mask[..., 1:].contiguous() loss_fct = nn.CrossEntropyLoss(reduction="none") loss = loss_fct(shift_logits, shift_labels) shift_mask = torch.logical_and(shift_mask.view(-1), shift_labels != -100) From a14133f4b075738861f98b6819161500958cd637 Mon Sep 17 00:00:00 2001 From: sedrick-keh-tri Date: Thu, 11 Jul 2024 17:33:26 +0000 Subject: [PATCH 30/33] remove xformers from requirements --- open_lm/attention.py | 5 ++++- open_lm/model.py | 5 ++++- requirements.txt | 1 - 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/open_lm/attention.py b/open_lm/attention.py index e0e8aba5..97fe4353 100644 --- a/open_lm/attention.py +++ b/open_lm/attention.py @@ -2,7 +2,10 @@ import torch from torch.nn import functional as F -import xformers.ops as xops +try: + import xformers.ops as xops +except ImportError: + print("xops not installed. Will error when using xformers_attn or swiglu (vs swiglu_torch)") def get_rectangular_causal_mask(shape, q_seq_len, k_seq_len, device, dtype): diff --git a/open_lm/model.py b/open_lm/model.py index 8773eda0..6f4815ea 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -11,7 +11,10 @@ from torch import nn from torch.utils.checkpoint import checkpoint -import xformers.ops as xops +try: + import xformers.ops as xops +except ImportError: + print("xops not installed. 
Will error when using xformers_attn or swiglu (vs swiglu_torch)") from huggingface_hub import PyTorchModelHubMixin diff --git a/requirements.txt b/requirements.txt index b333a8ad..f6726894 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ torch -xformers>=0.0.22 tiktoken wandb webdataset From 3a3998ea0d05bc8bf16866bbd972746ad851a5db Mon Sep 17 00:00:00 2001 From: sedrick-keh-tri Date: Thu, 11 Jul 2024 17:45:04 +0000 Subject: [PATCH 31/33] full purge --- open_lm/attention.py | 39 +-------------------------------------- open_lm/model.py | 11 +++-------- open_lm/params.py | 2 +- 3 files changed, 5 insertions(+), 47 deletions(-) diff --git a/open_lm/attention.py b/open_lm/attention.py index 97fe4353..51fa69de 100644 --- a/open_lm/attention.py +++ b/open_lm/attention.py @@ -2,10 +2,6 @@ import torch from torch.nn import functional as F -try: - import xformers.ops as xops -except ImportError: - print("xops not installed. Will error when using xformers_attn or swiglu (vs swiglu_torch)") def get_rectangular_causal_mask(shape, q_seq_len, k_seq_len, device, dtype): @@ -66,31 +62,6 @@ def apply_attention_mask_(bias, attention_mask, queries_dtype): bias.mul_(~torch.all(bias == min_dtype, dim=-1, keepdim=True)) -def xformers_attn(queries, keys, values, is_causal, attention_mask=None): - # xformers assumes q, k, v are [batch, seq_len, heads, embed_dim] - # We assume that queries match the last part of the key / value sequences - # see (https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.fmha.attn_bias.LowerTriangularFromBottomRightMask) - # we would like to replace the mask generation with: mask = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask() - # sadly we cannot us this because it needs xformers>=0.0.23 and this is not compatible with torch<2.1.1 while llm-foundry requires torch<2.1.1 - - # If queries have shape [batch, 1, heads, dim] it means there is only one query in the sequence. - # In this case, there is no notion of causal masking, so we can just set the mask to None. - # This is actually needed to get the desired behavior with seq_len=1. - bias = None - if is_causal and queries.shape[1] == keys.shape[1] and attention_mask is None: - bias = xops.LowerTriangularMask() - elif is_causal and (queries.shape[1] > 1 or attention_mask is not None): - # Build causal mask that assumes queries are in the end of the sequence. - batch, q_seq_len, heads, _ = queries.shape - k_seq_len = keys.shape[1] - bias = get_rectangular_causal_mask((batch, heads), q_seq_len, k_seq_len, queries.device, queries.dtype) - if attention_mask is not None: - apply_attention_mask_(bias, attention_mask, queries_dtype=queries.dtype) - elif not is_causal and attention_mask is not None: - raise NotImplementedError("attention_mask with is_causal=False is not yet implemented.") - return xops.memory_efficient_attention(queries, keys, values, attn_bias=bias) - - def torch_attn(queries, keys, values, is_causal, attention_mask=None): # Need to call contiguous in torch >=2.1, otherwise later calls to .view() fail. 
# Possibly related: https://github.com/pytorch/pytorch/issues/110213 - behavior of scaled_dot_product_attention @@ -199,15 +170,7 @@ def get_attn_func( alpha=None, ): if attn_name == "auto": - return xformers_attn if torch.cuda.is_available() else torch_attn - elif attn_name == "xformers_attn": - return xformers_attn - elif attn_name == "xformers_attn_variable_length": - # Upon changing the input sequence length, xformers attention changes - # the stride dimension of the output tensor. This makes future calls to - # .view() that collapses last two dimensions fail. One thus needs to - # call .contiguous() on the output tensor. [#188] - return lambda *args, **kwargs: xformers_attn(*args, **kwargs).contiguous() + return torch_attn elif attn_name == "torch_attn": return torch_attn elif attn_name == "custom_attn": diff --git a/open_lm/model.py b/open_lm/model.py index 6f4815ea..a420e029 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -11,14 +11,9 @@ from torch import nn from torch.utils.checkpoint import checkpoint -try: - import xformers.ops as xops -except ImportError: - print("xops not installed. Will error when using xformers_attn or swiglu (vs swiglu_torch)") - from huggingface_hub import PyTorchModelHubMixin -from open_lm.attention import get_attn_func, xformers_attn, torch_attn +from open_lm.attention import get_attn_func, torch_attn from open_lm.norms import get_norm_class from open_lm.positional_embedding.head_rotary import HeadRotaryWithCast from open_lm.positional_embedding.rotary import RotaryWithCast @@ -94,7 +89,7 @@ class Params: post_embed_norm: bool = False weight_tying: bool = False norm_type: nn.Module = nn.LayerNorm - attn_func: Callable = xformers_attn if torch.cuda.is_available() else torch_attn + attn_func: Callable = torch_attn apply_qk_norm: bool = False moe_loss_weight: float = 0.1 moe_capacity_factor: float = 1.25 @@ -270,7 +265,7 @@ def __init__(self, layer_id, args: Params): if args.ffn_type == "swiglu": # this follows llama / lit llama -- go to multiple of 256 self.hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256) - self.feed_forward = xops.SwiGLU(args.dim, self.hidden_dim, args.dim, bias=False) + self.feed_forward = SwiGLUTorch(args.dim, self.hidden_dim, args.dim, bias=False) elif args.ffn_type == "swiglu_torch": # this follows llama / lit llama -- go to multiple of 256 self.hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256) diff --git a/open_lm/params.py b/open_lm/params.py index 0a7a3f64..0747e529 100644 --- a/open_lm/params.py +++ b/open_lm/params.py @@ -106,7 +106,7 @@ def add_model_args(parser): "--attn-name", type=str, default="auto", - choices=["auto", "xformers_attn", "xformers_attn_variable_length", "torch_attn", "custom_attn"], + choices=["auto", "torch_attn", "custom_attn"], help="type of attention to use", ) parser.add_argument( From 87bbb852f85087f415049767b09d2f7f88a5e89a Mon Sep 17 00:00:00 2001 From: sedrick-keh-tri Date: Sun, 14 Jul 2024 23:10:54 +0000 Subject: [PATCH 32/33] mbm configs --- open_lm/model_configs/11m.json | 9 +++++++++ open_lm/model_configs/154m.json | 9 +++++++++ open_lm/model_configs/1b.json | 9 +++++++++ open_lm/model_configs/411m.json | 9 +++++++++ open_lm/model_configs/79m.json | 9 +++++++++ open_lm/model_configs/7b.json | 9 +++++++++ 6 files changed, 54 insertions(+) create mode 100644 open_lm/model_configs/11m.json create mode 100644 open_lm/model_configs/154m.json create mode 100644 open_lm/model_configs/1b.json create mode 100644 open_lm/model_configs/411m.json create mode 100644 
open_lm/model_configs/79m.json
 create mode 100644 open_lm/model_configs/7b.json

diff --git a/open_lm/model_configs/11m.json b/open_lm/model_configs/11m.json
new file mode 100644
index 00000000..22aa9d6c
--- /dev/null
+++ b/open_lm/model_configs/11m.json
@@ -0,0 +1,9 @@
+{
+    "hidden_dim": 192,
+    "n_layers": 8,
+    "n_heads": 4,
+    "seq_len": 1024,
+    "vocab_size": 66816,
+    "post_embed_norm": false,
+    "weight_tying": false
+}
\ No newline at end of file
diff --git a/open_lm/model_configs/154m.json b/open_lm/model_configs/154m.json
new file mode 100644
index 00000000..8c72165a
--- /dev/null
+++ b/open_lm/model_configs/154m.json
@@ -0,0 +1,9 @@
+{
+    "hidden_dim": 576,
+    "n_layers": 24,
+    "n_heads": 8,
+    "seq_len": 1024,
+    "vocab_size": 66816,
+    "post_embed_norm": false,
+    "weight_tying": false
+}
\ No newline at end of file
diff --git a/open_lm/model_configs/1b.json b/open_lm/model_configs/1b.json
new file mode 100644
index 00000000..0a97ddcc
--- /dev/null
+++ b/open_lm/model_configs/1b.json
@@ -0,0 +1,9 @@
+{
+    "hidden_dim": 2048,
+    "n_layers": 24,
+    "n_heads": 16,
+    "seq_len": 1024,
+    "vocab_size": 66816,
+    "post_embed_norm": false,
+    "weight_tying": false
+}
\ No newline at end of file
diff --git a/open_lm/model_configs/411m.json b/open_lm/model_configs/411m.json
new file mode 100644
index 00000000..9c029ca7
--- /dev/null
+++ b/open_lm/model_configs/411m.json
@@ -0,0 +1,9 @@
+{
+    "hidden_dim": 1024,
+    "n_layers": 24,
+    "n_heads": 8,
+    "seq_len": 1024,
+    "vocab_size": 66816,
+    "post_embed_norm": false,
+    "weight_tying": false
+}
\ No newline at end of file
diff --git a/open_lm/model_configs/79m.json b/open_lm/model_configs/79m.json
new file mode 100644
index 00000000..80601e3f
--- /dev/null
+++ b/open_lm/model_configs/79m.json
@@ -0,0 +1,9 @@
+{
+    "hidden_dim": 512,
+    "n_layers": 8,
+    "n_heads": 4,
+    "seq_len": 1024,
+    "vocab_size": 66816,
+    "post_embed_norm": false,
+    "weight_tying": false
+}
\ No newline at end of file
diff --git a/open_lm/model_configs/7b.json b/open_lm/model_configs/7b.json
new file mode 100644
index 00000000..5c7b0f51
--- /dev/null
+++ b/open_lm/model_configs/7b.json
@@ -0,0 +1,9 @@
+{
+    "hidden_dim": 4096,
+    "n_layers": 32,
+    "n_heads": 32,
+    "seq_len": 1024,
+    "vocab_size": 66816,
+    "post_embed_norm": false,
+    "weight_tying": false
+}
\ No newline at end of file

From 11d202846b8fb64f793548b97c08903b14e8e47d Mon Sep 17 00:00:00 2001
From: sedrick-keh-tri <133716510+sedrick-keh-tri@users.noreply.github.com>
Date: Fri, 19 Jul 2024 14:11:26 -0700
Subject: [PATCH 33/33] Create open_lm_1b_swiglutorch.json

---
 open_lm/model_configs/open_lm_1b_swiglutorch.json | 10 ++++++++++
 1 file changed, 10 insertions(+)
 create mode 100644 open_lm/model_configs/open_lm_1b_swiglutorch.json

diff --git a/open_lm/model_configs/open_lm_1b_swiglutorch.json b/open_lm/model_configs/open_lm_1b_swiglutorch.json
new file mode 100644
index 00000000..d2f9b814
--- /dev/null
+++ b/open_lm/model_configs/open_lm_1b_swiglutorch.json
@@ -0,0 +1,10 @@
+{
+    "hidden_dim": 2048,
+    "n_layers": 24,
+    "n_heads": 16,
+    "seq_len": 2048,
+    "vocab_size": 50432,
+    "post_embed_norm": false,
+    "weight_tying": false,
+    "ffn_type": "swiglu_torch"
+}
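
Usage sketch for the Mamba-side changes (patches 04-08): mamba configs are now parsed into a MambaParams dataclass instead of a raw dict, which lets OpenLMConfig.set_params iterate fields(params.__class__) for either Params or MambaParams. The snippet below builds the dataclass by hand the same way create_params does; the inline dict mirrors open_lm/model_configs/mamba_7b.json and is only spelled out so the example stands alone.

from open_lm.model import MambaParams

cfg = {  # contents of open_lm/model_configs/mamba_7b.json, inlined for a self-contained example
    "d_model": 4096, "n_layer": 64, "vocab_size": 50432, "seq_len": 2048,
    "ssm_cfg": {}, "rms_norm": True, "residual_in_fp32": True,
    "fused_add_norm": True, "pad_vocab_size_multiple": 8,
}

params = MambaParams(
    d_model=cfg["d_model"],
    n_layer=cfg["n_layer"],
    vocab_size=cfg["vocab_size"],
    seq_len=cfg["seq_len"],
    ssm_cfg=cfg["ssm_cfg"],
    rms_norm=cfg["rms_norm"],
    residual_in_fp32=cfg["residual_in_fp32"],
    fused_add_norm=cfg["fused_add_norm"],
    pad_vocab_size_multiple=cfg["pad_vocab_size_multiple"],
    tie_embeddings=cfg.get("weight_tying", False),
    weight_tying=cfg.get("weight_tying", False),
)

Instantiating the actual Mamba module from these params still requires the optional mamba-ssm dependency; without it the Mamba stub raises an ImportError by design (patch 19).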
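
Checkpoint-loading sketch (patches 17 and 20-22): load_model gained a filter_keys argument that drops state-dict entries whose key contains any of the given substrings, and rotary inv_freq entries are stripped on load because inv_freq is no longer a registered buffer. The toy state dict below is made up, but the two dict comprehensions mirror the ones applied in open_lm.main.load_model when resuming a training checkpoint.

import torch

sd = {  # toy stand-in for checkpoint["state_dict"]
    "tok_embeddings.weight": torch.zeros(4, 2),
    "layers.0.attention.pos_embed.inv_freq": torch.zeros(2),
    "layers.0.attention.in_proj.weight": torch.zeros(6, 2),
    "output.weight": torch.zeros(4, 2),
}
filter_keys = ["output"]  # e.g. drop the LM head when the vocabulary changes

sd = {k: v for k, v in sd.items() if "inv_freq" not in k}
sd = {k: v for k, v in sd.items() if not any(x in k for x in filter_keys)}

print(sorted(sd))  # ['layers.0.attention.in_proj.weight', 'tok_embeddings.weight']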
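
Loss-masking sketch (patches 24, 27 and 29): when the attention mask only covers right padding, the HF wrapper now skips attention masking entirely and applies the mask to the loss instead. Below is a minimal standalone version of that logic; the tensor shapes in the smoke test are arbitrary and not taken from the repository.

import torch
import torch.nn as nn


def is_attention_mask_right(attention_mask: torch.Tensor) -> torch.Tensor:
    # True when every row is masked only on its right side, i.e. once a 0
    # appears, everything after it is also 0 (pure right padding).
    first_zero_index = torch.min(attention_mask, dim=1).indices
    sum_values = torch.sum(attention_mask, dim=1)
    return (sum_values % attention_mask.shape[1] == first_zero_index).all()


def masked_shifted_ce_loss(logits, labels, loss_mask=None):
    # Next-token prediction: position i predicts token i+1.
    shift_logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1))
    shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
    if loss_mask is not None:
        shift_mask = loss_mask[..., 1:].contiguous()
        loss = nn.CrossEntropyLoss(reduction="none")(shift_logits, shift_labels)
        # Drop padded positions and ignore-index (-100) labels before averaging.
        shift_mask = torch.logical_and(shift_mask.view(-1).bool(), shift_labels != -100)
        return loss[shift_mask].sum() / shift_mask.sum()
    return nn.CrossEntropyLoss()(shift_logits, shift_labels)


# Smoke test with made-up shapes: batch of 2, seq_len 4, vocab 11.
logits = torch.randn(2, 4, 11)
labels = torch.randint(0, 11, (2, 4))
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])
if is_attention_mask_right(attention_mask):
    print(masked_shifted_ce_loss(logits, labels, loss_mask=attention_mask))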
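
Attention-dispatch note (patches 30-31): with xformers removed from the requirements, the xformers_attn choices disappear from --attn-name, "auto" resolves to the pure-torch path on every device, and the default "swiglu" ffn_type now builds SwiGLUTorch instead of xops.SwiGLU. A quick check of the new dispatch behaviour; it assumes the attn_name keyword seen in the function body, since the full signature is not shown in the diff.

from open_lm.attention import get_attn_func, torch_attn

assert get_attn_func(attn_name="auto") is torch_attn
assert get_attn_func(attn_name="torch_attn") is torch_attn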