diff --git a/open_lm/model.py b/open_lm/model.py
index 8773eda0..a953f935 100644
--- a/open_lm/model.py
+++ b/open_lm/model.py
@@ -496,17 +496,16 @@ def create_params(args):
         moe_top_k=cfg.get("moe_top_k", args.moe_top_k),
     )
 
+
 if MambaLMHeadModel is not None:
-    # This is a copy-paste of the Mamba SSM code with the addition of inputs_embeds
+    # This is a copy-paste of the Mamba SSM code with the addition of inputs_embeds
     class MixerModelOpenLM(MixerModel):
         def forward(self, input_ids=None, inputs_embeds=None, inference_params=None, **kwargs):
             assert input_ids is not None or inputs_embeds is not None
             hidden_states = self.embedding(input_ids) if inputs_embeds is None else inputs_embeds
             residual = None
             for layer in self.layers:
-                hidden_states, residual = layer(
-                    hidden_states, residual, inference_params=inference_params
-                )
+                hidden_states, residual = layer(hidden_states, residual, inference_params=inference_params)
             if not self.fused_add_norm:
                 residual = (hidden_states + residual) if residual is not None else hidden_states
                 hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
@@ -524,8 +523,7 @@ def forward(self, input_ids=None, inputs_embeds=None, inference_params=None, **k
             )
         return hidden_states
 
-
-    # This is a copy-paste of the Mamba SSM code with the usage of MixerModelOpenLM instead of MixerModel
+    # This is a copy-paste of the Mamba SSM code with the usage of MixerModelOpenLM instead of MixerModel
     class MambaLMHeadModelOpenLM(MambaLMHeadModel):
         def __init__(
             self,
@@ -554,12 +552,12 @@ def __init__(
                 residual_in_fp32=residual_in_fp32,
                 **factory_kwargs,
             )
+
         def forward(self, input_ids=None, inputs_embeds=None, inference_params=None, **kwargs):
             hidden_state = self.backbone(input_ids, inputs_embeds, inference_params)
             lm_logits = self.lm_head(hidden_state)
             return lm_logits, hidden_state, inference_params
 
-
     class Mamba(nn.Module):
         # Experimental architecture, please "pip install mamba-ssm"
         # https://arxiv.org/abs/2312.00752
@@ -582,7 +580,9 @@ def reset_parameters(self):
         def forward(self, input_ids, inputs_embeds=None, inference_params=None, **kwargs):
             logits, hidden_state, inference_params = self.model(input_ids, inputs_embeds, inference_params, **kwargs)
             return logits, hidden_state, inference_params
+
 else:
+
     class Mamba(nn.Module):
         # Experimental architecture, please "pip install mamba-ssm"
         # https://arxiv.org/abs/2312.00752
diff --git a/open_lm/utils/transformers/hf_model.py b/open_lm/utils/transformers/hf_model.py
index d50936ec..8a464f0c 100644
--- a/open_lm/utils/transformers/hf_model.py
+++ b/open_lm/utils/transformers/hf_model.py
@@ -17,7 +17,7 @@ def is_attention_mask_right(attention_mask):
     sum_values = torch.sum(attention_mask, dim=1)
     # Check if the sum of the mask is equal to the first zero index (meaning that the rest of the sequence after the first 0 is also 0)
     is_valid_sequence = (sum_values % attention_mask.shape[1] == first_zero_index).all()
-
+
     return is_valid_sequence
 
 
@@ -144,7 +144,7 @@ def forward(
             loss_fct = nn.CrossEntropyLoss(reduction="none")
             loss = loss_fct(shift_logits, shift_labels)
             shift_mask = torch.logical_and(shift_mask.view(-1), shift_labels != -100)
-            loss = loss[shift_mask.view(-1)].sum()/shift_mask.sum()
+            loss = loss[shift_mask.view(-1)].sum() / shift_mask.sum()
         else:
             loss_fct = nn.CrossEntropyLoss()
             loss = loss_fct(shift_logits, shift_labels)
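
Note (not part of the patch): the reformatted loss line in hf_model.py averages the token-level cross-entropy only over positions that are both attended and not set to the ignore index (-100). Below is a minimal standalone sketch of that normalization; the shapes, values, and the exact shift/mask slicing are illustrative and may not match hf_model.py line for line.

    import torch
    import torch.nn as nn

    vocab_size, batch, seq_len = 8, 2, 5
    torch.manual_seed(0)

    logits = torch.randn(batch, seq_len, vocab_size)
    labels = torch.randint(0, vocab_size, (batch, seq_len))
    labels[1, -1] = -100  # one position marked with the ignore index
    attention_mask = torch.tensor([[1, 1, 1, 1, 0], [1, 1, 1, 1, 1]])  # right-padded mask

    # Shift so each position predicts the next token (causal-LM convention).
    shift_logits = logits[:, :-1, :].reshape(-1, vocab_size)
    shift_labels = labels[:, 1:].reshape(-1)
    shift_mask = attention_mask[:, 1:].reshape(-1).bool()

    # Same normalization as the patched line: per-token CE, summed over valid
    # positions, divided by the number of valid positions.
    loss_fct = nn.CrossEntropyLoss(reduction="none")
    loss = loss_fct(shift_logits, shift_labels)
    shift_mask = torch.logical_and(shift_mask, shift_labels != -100)
    loss = loss[shift_mask].sum() / shift_mask.sum()
    print(loss)  # scalar mean loss over attended, non-ignored tokens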