From ee4a7ef0e4d4aa170b9a4f1f0bfd8c2f1469009c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 15 Feb 2024 16:19:20 +0900 Subject: [PATCH] Style nd updates --- .../models/mamba/modeling_mamba.py | 64 +++++++------------ 1 file changed, 22 insertions(+), 42 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index bdc39e10ee4eee..65f928f418599c 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -17,7 +17,7 @@ import math from dataclasses import dataclass from pathlib import Path -from typing import List, Optional, Tuple, Union, Dict, Any +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -349,20 +349,9 @@ def __init__(self, config, batch_size, conv_dtype=torch.float32, ssm_dtype=torc expand = config.expand d_conv = config.conv_kernel - self.conv_states = { i: torch.zeros( - batch_size, d_model * expand, d_conv, device=device, dtype=conv_dtype - ) for i in range(config.num_hidden_layers)} - self.ssm_states = { i: torch.zeros( - batch_size, d_model * expand, d_state, device=device, dtype=ssm_dtype - )for i in range(config.num_hidden_layers)} + self.conv_states = { i: torch.zeros(batch_size, d_model * expand, d_conv, device=device, dtype=conv_dtype) for i in range(config.num_hidden_layers)} + self.ssm_states = { i: torch.zeros(batch_size, d_model * expand, d_state, device=device, dtype=ssm_dtype)for i in range(config.num_hidden_layers)} - def update_conv_state(self, hidden_states): - self.conv_state.copy_(torch.roll(self.conv_state, shifts=-1, dims=-1)) # Update state (B D W) - self.conv_state[:, :, -1] = hidden_states - return self.conv_state - - def update_ssm_state(self, ssm_state): - self.ssm_state.copy_(ssm_state) class MambaSlowMixer(MambaMixer): @@ -391,34 +380,17 @@ def forward(self, hidden_states, inference_params=None): if inference_params.seqlen_offset > 0: conv_state = inference_params.conv_states[self.layer_idx] conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) - conv_state[:, :, -1] = hidden_states[:,:,0] - # out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state) - # return out, conv_state, ssm_state + conv_state[:, :, -1].copy_(hidden_states[:,:,0]) + hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1).unsqueeze(-1) else: - conv_state = hidden_states inference_params.conv_states[self.layer_idx].copy_(nn.functional.pad(hidden_states, (self.d_conv - hidden_states.shape[-1], 0))) - - ssm_state = inference_params.ssm_states[self.layer_idx] - - # conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) - # conv_state[:, :, -1] = hidden_states - - # when you have the first iter, use conv_state - hidden_states = self.act(self.conv1d(conv_state)[..., :seq_len]) - - # x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) - # if self.conv1d.bias is not None: - # x = x + self.conv1d.bias - # x = self.act(x).to(dtype=dtype) - + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # 3. State Space Model sequence transformation # 3.a. input varying initialization of time_step, B and C x_dbl = self.x_proj(hidden_states.transpose(1,2)) time_step, B, C = torch.split(x_dbl, [self.time_step_rank, self.d_state, self.d_state], dim=-1) discrete_time_step = self.dt_proj(time_step) - - # discrete_time_step = discrete_time_step.transpose(0,1) A = -torch.exp(self.A_log.float()) # (d_inner, d_state) # 3.b. discretize time_step, B and C: zero-order hold from (B,L,D) to (B,L,D,N) @@ -429,6 +401,7 @@ def forward(self, hidden_states, inference_params=None): deltaB_u = (discrete_time_step[:, :, :, None] * hidden_states[:, :, :, None]) * B[:, None, :, :] # 3.c perform the recurrence y ← SSM(A, B, C)(x) + ssm_state = inference_params.ssm_states[self.layer_idx] ys = [] for i in range(seq_len): ssm_state.copy_(ssm_state * dA[:, :, i, :] + deltaB_u[:, :, i, :]) @@ -436,13 +409,11 @@ def forward(self, hidden_states, inference_params=None): y = torch.matmul(ssm_state, C[:,i,:].unsqueeze(-1)) ys.append(y[:,:,0]) y = torch.stack(ys, dim=-1) # shape (b, l, d) - y = y + (hidden_states * self.D.to(hidden_states.dtype)[None,:,None]) y = y * self.act(gate) # (B D) - # 4. Final linear projection attn_outputs = self.out_proj(y.transpose(1,2)) - return attn_outputs, conv_state, ssm_state + return attn_outputs, None, ssm_state class MambaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -502,8 +473,18 @@ def _init_weights(self, module): nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, std=self.config.initializer_range) - - + # + # # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max + # dt = torch.exp( + # torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + # + math.log(dt_min) + # ).clamp(min=dt_init_floor) + # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + # inv_dt = dt + torch.log(-torch.expm1(-dt)) + # with torch.no_grad(): + # self.dt_proj.bias.copy_(inv_dt) + # # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit + # self.dt_proj.bias._no_reinit = True @dataclass @@ -690,16 +671,15 @@ def forward( hidden_states, conv_state, ssm_state = self._gradient_checkpointing_func(layer.__call__, hidden_states, inference_params) else: hidden_states, conv_state, ssm_state = layer(hidden_states, inference_params=inference_params) - # inference_params.update_conv_state(conv_state) - # inference_params.update_ssm_state(ssm_state) - inference_params.seqlen_offset += inputs_embeds.shape[1] inference_params.ssm_states[idx].copy_(ssm_state) + # inference_params.conv_states[idx].copy_(conv_state) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if output_attentions: all_last_states = all_last_states + (ssm_state,) + inference_params.seqlen_offset += inputs_embeds.shape[1] if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,)