Commit a1d6394
black formatting and rebase on main
jmercat committed Jul 24, 2024
1 parent b401e75 commit a1d6394
Showing 2 changed files with 9 additions and 9 deletions.
open_lm/model.py: 14 changes (7 additions, 7 deletions)
@@ -496,17 +496,16 @@ def create_params(args):
         moe_top_k=cfg.get("moe_top_k", args.moe_top_k),
     )

+
 if MambaLMHeadModel is not None:
-    # This is a copy-paste of the Mamba SSM code with the addition of inputs_embeds
+    # This is a copy-paste of the Mamba SSM code with the addition of inputs_embeds
     class MixerModelOpenLM(MixerModel):
         def forward(self, input_ids=None, inputs_embeds=None, inference_params=None, **kwargs):
             assert input_ids is not None or inputs_embeds is not None
             hidden_states = self.embedding(input_ids) if inputs_embeds is None else inputs_embeds
             residual = None
             for layer in self.layers:
-                hidden_states, residual = layer(
-                    hidden_states, residual, inference_params=inference_params
-                )
+                hidden_states, residual = layer(hidden_states, residual, inference_params=inference_params)
             if not self.fused_add_norm:
                 residual = (hidden_states + residual) if residual is not None else hidden_states
                 hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
@@ -524,8 +523,7 @@ def forward(self, input_ids=None, inputs_embeds=None, inference_params=None, **kwargs):
                 )
             return hidden_states

-
-    # This is a copy-paste of the Mamba SSM code with the usage of MixerModelOpenLM instead of MixerModel
+    # This is a copy-paste of the Mamba SSM code with the usage of MixerModelOpenLM instead of MixerModel
     class MambaLMHeadModelOpenLM(MambaLMHeadModel):
         def __init__(
             self,
@@ -554,12 +552,12 @@ def __init__(
                 residual_in_fp32=residual_in_fp32,
                 **factory_kwargs,
             )
+
         def forward(self, input_ids=None, inputs_embeds=None, inference_params=None, **kwargs):
             hidden_state = self.backbone(input_ids, inputs_embeds, inference_params)
             lm_logits = self.lm_head(hidden_state)
             return lm_logits, hidden_state, inference_params

-
     class Mamba(nn.Module):
         # Experimental architecture, please "pip install mamba-ssm"
         # https://arxiv.org/abs/2312.00752
@@ -582,7 +580,9 @@ def reset_parameters(self):
         def forward(self, input_ids, inputs_embeds=None, inference_params=None, **kwargs):
             logits, hidden_state, inference_params = self.model(input_ids, inputs_embeds, inference_params, **kwargs)
             return logits, hidden_state, inference_params
+
 else:
+
     class Mamba(nn.Module):
         # Experimental architecture, please "pip install mamba-ssm"
         # https://arxiv.org/abs/2312.00752
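The inputs_embeds handling added in MixerModelOpenLM above follows a common pattern: accept either token ids or precomputed embeddings. A minimal, self-contained sketch of that pattern (the toy module, vocabulary size, and dimensions are illustrative only, not open_lm's actual Mamba configuration):

import torch
import torch.nn as nn


class ToyBackbone(nn.Module):
    def __init__(self, vocab_size=100, d_model=16):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, input_ids=None, inputs_embeds=None):
        # Same idea as the wrapper above: accept either token ids or precomputed embeddings.
        assert input_ids is not None or inputs_embeds is not None
        hidden_states = self.embedding(input_ids) if inputs_embeds is None else inputs_embeds
        return self.proj(hidden_states)


backbone = ToyBackbone()
ids = torch.randint(0, 100, (1, 8))
out_from_ids = backbone(input_ids=ids)
out_from_embeds = backbone(inputs_embeds=backbone.embedding(ids))
assert torch.allclose(out_from_ids, out_from_embeds)  # both paths produce the same output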
open_lm/utils/transformers/hf_model.py: 4 changes (2 additions, 2 deletions)
@@ -17,7 +17,7 @@ def is_attention_mask_right(attention_mask):
     sum_values = torch.sum(attention_mask, dim=1)
     # Check if the sum of the mask is equal to the first zero index (meaning that the rest of the sequence after the first 0 is also 0)
     is_valid_sequence = (sum_values % attention_mask.shape[1] == first_zero_index).all()

     return is_valid_sequence
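As a quick illustration of the check above, a standalone sketch (the surrounding helper derives first_zero_index from the mask outside this hunk, so it is recomputed here for illustration):

import torch

# One right-padded row (valid) and one row with a hole in the middle (invalid).
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 0, 1, 1, 0]])
# Index of the first zero per row; argmax over the boolean mask returns the first True.
first_zero_index = (attention_mask == 0).int().argmax(dim=1)
sum_values = torch.sum(attention_mask, dim=1)
is_valid_sequence = (sum_values % attention_mask.shape[1] == first_zero_index).all()
print(is_valid_sequence)  # tensor(False): the second row is not right-padded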


@@ -144,7 +144,7 @@ def forward(
                 loss_fct = nn.CrossEntropyLoss(reduction="none")
                 loss = loss_fct(shift_logits, shift_labels)
                 shift_mask = torch.logical_and(shift_mask.view(-1), shift_labels != -100)
-                loss = loss[shift_mask.view(-1)].sum()/shift_mask.sum()
+                loss = loss[shift_mask.view(-1)].sum() / shift_mask.sum()
             else:
                 loss_fct = nn.CrossEntropyLoss()
                 loss = loss_fct(shift_logits, shift_labels)
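For context, a self-contained sketch of the masked, token-averaged loss in the hunk above (the tensor shapes and example values are made up; the shift_mask handling and the -100 ignore label follow the diff):

import torch
import torch.nn as nn

# Toy per-token data: 6 positions over a 10-token vocabulary.
shift_logits = torch.randn(6, 10)
shift_labels = torch.tensor([1, 2, 3, -100, 4, 5])  # -100 marks positions to ignore
shift_mask = torch.tensor([True, True, True, True, True, False])  # e.g. padding at the end

loss_fct = nn.CrossEntropyLoss(reduction="none")  # per-token losses; ignore_index defaults to -100
loss = loss_fct(shift_logits, shift_labels)
# Keep only positions that are both attended to and not labeled -100, then average over them.
shift_mask = torch.logical_and(shift_mask.view(-1), shift_labels != -100)
loss = loss[shift_mask.view(-1)].sum() / shift_mask.sum()
print(loss)  # scalar mean loss over the 4 kept positions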
