
Commit

Style and updates
ArthurZucker committed Feb 15, 2024
1 parent e9a80ad commit ee4a7ef
Showing 1 changed file with 22 additions and 42 deletions.
64 changes: 22 additions & 42 deletions src/transformers/models/mamba/modeling_mamba.py
@@ -17,7 +17,7 @@
import math
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple, Union, Dict, Any
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
@@ -349,20 +349,9 @@ def __init__(self, config, batch_size, conv_dtype=torch.float32, ssm_dtype=torc
expand = config.expand
d_conv = config.conv_kernel

self.conv_states = { i: torch.zeros(
batch_size, d_model * expand, d_conv, device=device, dtype=conv_dtype
) for i in range(config.num_hidden_layers)}
self.ssm_states = { i: torch.zeros(
batch_size, d_model * expand, d_state, device=device, dtype=ssm_dtype
)for i in range(config.num_hidden_layers)}
self.conv_states = { i: torch.zeros(batch_size, d_model * expand, d_conv, device=device, dtype=conv_dtype) for i in range(config.num_hidden_layers)}
self.ssm_states = { i: torch.zeros(batch_size, d_model * expand, d_state, device=device, dtype=ssm_dtype)for i in range(config.num_hidden_layers)}

def update_conv_state(self, hidden_states):
self.conv_state.copy_(torch.roll(self.conv_state, shifts=-1, dims=-1)) # Update state (B D W)
self.conv_state[:, :, -1] = hidden_states
return self.conv_state

def update_ssm_state(self, ssm_state):
self.ssm_state.copy_(ssm_state)
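
For reference, the rolling conv-cache update that `update_conv_state` performs (and that `MambaSlowMixer.forward` below does inline on `inference_params.conv_states[self.layer_idx]`) can be sketched in isolation as follows; the sizes and names are illustrative only, not part of the committed API.

    import torch

    # Illustrative sizes: batch B=2, channels D=4, conv kernel width W=3.
    conv_state = torch.zeros(2, 4, 3)   # cached window, shape (B, D, W)
    new_column = torch.randn(2, 4)      # features for the newest token, shape (B, D)

    # Shift the window one step to the left and write the newest column last,
    # i.e. torch.roll(..., shifts=-1, dims=-1) followed by an in-place write.
    conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
    conv_state[:, :, -1] = new_column

    # During single-token decoding, the depthwise convolution over this window
    # reduces to a weighted sum over the kernel width, as in the
    # `torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)` line in the forward pass below.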

class MambaSlowMixer(MambaMixer):

@@ -391,34 +380,17 @@ def forward(self, hidden_states, inference_params=None):
if inference_params.seqlen_offset > 0:
conv_state = inference_params.conv_states[self.layer_idx]
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
conv_state[:, :, -1] = hidden_states[:,:,0]
# out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state)
# return out, conv_state, ssm_state
conv_state[:, :, -1].copy_(hidden_states[:,:,0])
hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1).unsqueeze(-1)
else:
conv_state = hidden_states
inference_params.conv_states[self.layer_idx].copy_(nn.functional.pad(hidden_states, (self.d_conv - hidden_states.shape[-1], 0)))

ssm_state = inference_params.ssm_states[self.layer_idx]

# conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
# conv_state[:, :, -1] = hidden_states

# when you have the first iter, use conv_state
hidden_states = self.act(self.conv1d(conv_state)[..., :seq_len])

# x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
# if self.conv1d.bias is not None:
# x = x + self.conv1d.bias
# x = self.act(x).to(dtype=dtype)

hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])

# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
x_dbl = self.x_proj(hidden_states.transpose(1,2))
time_step, B, C = torch.split(x_dbl, [self.time_step_rank, self.d_state, self.d_state], dim=-1)
discrete_time_step = self.dt_proj(time_step)

# discrete_time_step = discrete_time_step.transpose(0,1)
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)

# 3.b. discretize time_step, B and C: zero-order hold from (B,L,D) to (B,L,D,N)
@@ -429,20 +401,19 @@
deltaB_u = (discrete_time_step[:, :, :, None] * hidden_states[:, :, :, None]) * B[:, None, :, :]

# 3.c perform the recurrence y ← SSM(A, B, C)(x)
ssm_state = inference_params.ssm_states[self.layer_idx]
ys = []
for i in range(seq_len):
ssm_state.copy_(ssm_state * dA[:, :, i, :] + deltaB_u[:, :, i, :])
# [b, d, n] X [b, n] -> [b, d]
y = torch.matmul(ssm_state, C[:,i,:].unsqueeze(-1))
ys.append(y[:,:,0])
y = torch.stack(ys, dim=-1) # shape (b, l, d)

y = y + (hidden_states * self.D.to(hidden_states.dtype)[None,:,None])
y = y * self.act(gate) # (B D)

# 4. Final linear projection
attn_outputs = self.out_proj(y.transpose(1,2))
return attn_outputs, conv_state, ssm_state
return attn_outputs, None, ssm_state
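
For reference, the slow path above amounts to a zero-order-hold discretization followed by a per-timestep recurrence. A minimal standalone sketch with illustrative sizes, omitting the convolution, gating, and output projection:

    import torch

    batch, d_inner, d_state, seq_len = 2, 4, 8, 5              # illustrative sizes
    hidden_states = torch.randn(batch, d_inner, seq_len)       # u, shape (B, D, L)
    discrete_time_step = torch.rand(batch, d_inner, seq_len)   # Δ (post-softplus), shape (B, D, L)
    A = -torch.rand(d_inner, d_state)                          # (D, N), negative entries
    B = torch.randn(batch, seq_len, d_state)                   # (B, L, N)
    C = torch.randn(batch, seq_len, d_state)                   # (B, L, N)

    # Zero-order hold: dA = exp(Δ·A) and dB·u = Δ·B·u, broadcast to (B, D, L, N).
    dA = torch.exp(discrete_time_step[:, :, :, None] * A[None, :, None, :])
    deltaB_u = (discrete_time_step[:, :, :, None] * hidden_states[:, :, :, None]) * B[:, None, :, :]

    # Recurrence: h_t = dA_t * h_{t-1} + dB_t·u_t, then y_t = h_t · C_t.
    ssm_state = torch.zeros(batch, d_inner, d_state)
    ys = []
    for i in range(seq_len):
        ssm_state = ssm_state * dA[:, :, i, :] + deltaB_u[:, :, i, :]
        ys.append(torch.matmul(ssm_state, C[:, i, :].unsqueeze(-1))[:, :, 0])   # (B, D)
    y = torch.stack(ys, dim=-1)                                                 # (B, D, L)

The actual layer additionally adds the learned `D` skip term, multiplies by the SiLU of the gate, and applies `out_proj`, as shown in the diff above.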

class MambaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
@@ -502,8 +473,18 @@ def _init_weights(self, module):
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=self.config.initializer_range)


#
# # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
# dt = torch.exp(
# torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
# + math.log(dt_min)
# ).clamp(min=dt_init_floor)
# # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
# inv_dt = dt + torch.log(-torch.expm1(-dt))
# with torch.no_grad():
# self.dt_proj.bias.copy_(inv_dt)
# # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
# self.dt_proj.bias._no_reinit = True


@dataclass
@@ -690,16 +671,15 @@ def forward(
hidden_states, conv_state, ssm_state = self._gradient_checkpointing_func(layer.__call__, hidden_states, inference_params)
else:
hidden_states, conv_state, ssm_state = layer(hidden_states, inference_params=inference_params)
# inference_params.update_conv_state(conv_state)
# inference_params.update_ssm_state(ssm_state)
inference_params.seqlen_offset += inputs_embeds.shape[1]
inference_params.ssm_states[idx].copy_(ssm_state)
# inference_params.conv_states[idx].copy_(conv_state)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if output_attentions:
all_last_states = all_last_states + (ssm_state,)
inference_params.seqlen_offset += inputs_embeds.shape[1]

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
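
For reference, the cache bookkeeping in the loop above — copy each layer's final SSM state back into the per-layer dict, then advance `seqlen_offset` once per forward call — follows this pattern (the class and sizes here are illustrative stand-ins, not the committed API):

    import torch

    class ToyCache:
        """Illustrative stand-in for the inference-params object: one SSM state per layer."""
        def __init__(self, num_layers, batch, d_inner, d_state):
            self.ssm_states = {i: torch.zeros(batch, d_inner, d_state) for i in range(num_layers)}
            self.seqlen_offset = 0

    num_layers, batch, d_inner, d_state, seq_len = 2, 1, 4, 8, 3
    cache = ToyCache(num_layers, batch, d_inner, d_state)
    hidden_states = torch.randn(batch, seq_len, d_inner)

    for idx in range(num_layers):
        # Stand-in for `layer(hidden_states, inference_params=cache)`, which also
        # returns the layer's final ssm_state for this forward pass.
        ssm_state = torch.randn(batch, d_inner, d_state)
        cache.ssm_states[idx].copy_(ssm_state)

    # Advance the offset once, after all layers have run, so the next call takes
    # the single-token decoding branch in the mixer.
    cache.seqlen_offset += hidden_states.shape[1]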
