diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index c312b9b94351d2..550eeb7f9665e4 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -44,14 +44,22 @@ from mamba_ssm.ops.triton.selective_state_update import selective_state_update from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined else: - selective_state_update = None + mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined, selective_state_update = None, None, None if is_causal_conv1d_available(): from causal_conv1d import causal_conv1d_fn, causal_conv1d_update else: causal_conv1d_update, causal_conv1d_fn = None, None -is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) +is_fast_path_available = all( + ( + selective_state_update, + mamba_chunk_scan_combined, + mamba_split_conv1d_scan_combined, + causal_conv1d_fn, + causal_conv1d_update, + ) +) _CHECKPOINT_FOR_DOC = "mistralai/mamba-codestral-7B-v0.1" _CONFIG_FOR_DOC = "Mamba2Config" @@ -111,6 +119,17 @@ def segment_sum(input_tensor): return tensor_segsum +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + class Mamba2Cache: """ Arguments: @@ -120,51 +139,69 @@ class Mamba2Cache: device: torch.device Attributes: - seqlen_offset: int - dtype: torch.dtype - conv_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, conv_kernel_size] - ssm_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, ssm_state_size] + dtype: (`torch.dtype`): + The default `dtype` used to initializing the cache. + conv_kernel_size: (`int`): + Model's convolution kernel size taken from config. + n_groups: (`int`): + Model's number of groups taken from the config - similar to tensor parallel in Transformer. + state_size: (`int`): + Model's SSM state size taken from config. + num_heads: (`int`): + The number of heads used in the linear attention / SSM. + head_dim: (`int`): + The respective dimension of the heads used in the linear attention / SSM. + intermediate_size: (`int`): + Model's intermediate_size based on (expand * hidden_dim) from config. + conv_states: (`torch.Tensor`): + A tensor of shape `[num_layers, batch_size, conv_kernel_size, intermediate_size + 2 * n_groups * state_size]` that holds convolutional states. + ssm_states: (`torch.Tensor`): + A tensor of shape `[num_layers, batch_size, num_heads, head_dim, state_size]` that holds ssm states. 
""" def __init__( self, config: Mamba2Config, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None ): - self.seqlen_offset = 0 self.dtype = dtype self.conv_kernel_size = config.conv_kernel + self.n_groups = config.n_groups + self.state_size = config.state_size + self.num_heads = config.num_heads + self.head_dim = config.head_dim self.intermediate_size = int(config.expand * config.hidden_size) - self.conv_states = { - i: torch.zeros( - batch_size, - self.intermediate_size + 2 * config.n_groups * config.state_size, - self.conv_kernel_size, - device=device, - dtype=dtype, - ) - for i in range(config.num_hidden_layers) - } - self.ssm_states = { - i: torch.zeros( - batch_size, config.num_heads, config.head_dim, config.state_size, device=device, dtype=dtype - ) - for i in range(config.num_hidden_layers) - } - self.activation = config.hidden_act - self.act = ACT2FN[config.hidden_act] + self.conv_states = torch.zeros( + config.num_hidden_layers, + batch_size, + self.intermediate_size + 2 * self.n_groups * self.state_size, + self.conv_kernel_size, + device=device, + dtype=dtype, + ) + self.ssm_states = torch.zeros( + config.num_hidden_layers, + batch_size, + self.num_heads, + self.head_dim, + self.state_size, + device=device, + dtype=dtype, + ) def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor + self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False ) -> torch.Tensor: - conv_state = self.conv_states[layer_idx] - cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) - - conv_state = conv_state.roll(shifts=-1, dims=-1) - conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device) - self.conv_states[layer_idx].zero_() - self.conv_states[layer_idx] += conv_state + if cache_init: + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) return self.conv_states[layer_idx] + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + return self.ssm_states[layer_idx] + def reset(self): self.conv_states.zero_() self.ssm_states.zero_() @@ -269,19 +306,27 @@ def cuda_kernels_forward( cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, ): - # set up dimensions for reshapes later + # 1. 
Gated MLP's linear projection + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + projected_states = self.in_proj(hidden_states) + # Set up dimensions for reshapes later batch_size, seq_len, _ = hidden_states.shape groups_time_state_size = self.n_groups * self.ssm_state_size - d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads - - # getting projected states from cache if it exists - if cache_params is not None and cache_params.seqlen_offset > 0: - in_projected_states = self.in_proj(hidden_states.squeeze(1)) # (B 2D) - d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2 - split_projection_dim = [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads] - _, _, gate, hidden_states_B_C, dt = torch.split(in_projected_states, split_projection_dim, dim=-1) + d_mlp = ( + projected_states.shape[-1] + - 2 * self.intermediate_size + - 2 * self.n_groups * self.ssm_state_size + - self.num_heads + ) // 2 + + # Single step calculations via cache + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + _, _, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split( + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + # 2. Convolution sequence transformation hidden_states_B_C = causal_conv1d_update( hidden_states_B_C, cache_params.conv_states[self.layer_idx], @@ -295,8 +340,9 @@ def cuda_kernels_forward( [self.intermediate_size, groups_time_state_size, groups_time_state_size], dim=-1, ) - A = -torch.exp(self.A_log.float()) # (nheads,) + # 3. SSM transformation + A = -torch.exp(self.A_log.float()) # (nheads,) A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) dt = dt[:, :, None].expand(-1, -1, self.head_dim) dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) @@ -318,20 +364,18 @@ def cuda_kernels_forward( ) hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim) hidden_states = self.norm(hidden_states, gate) + + # 4. Final linear projection out = self.out_proj(hidden_states)[:, None, ...] - # if no cache is found, calling the kernel + + # Fused calculations or step by step if no initialized cache is found else: - if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: - # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 - dtype = hidden_states.dtype - hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) - # 1. Gated MLP's linear projection - projected_states = self.in_proj(hidden_states) A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} + # 2-4. 
Fused kernel for conv1d, SSM, and the final projection if self.training and cache_params is None: - out, ssm_state = mamba_split_conv1d_scan_combined( + out = mamba_split_conv1d_scan_combined( projected_states, self.conv1d.weight.squeeze(1), self.conv1d.bias, @@ -348,41 +392,50 @@ def cuda_kernels_forward( headdim=self.head_dim, ngroups=self.n_groups, norm_before_gate=False, - return_final_states=True, + return_final_states=False, **dt_limit_kwargs, ) else: - gate, hidden_states_B_C, time_step = torch.split( - projected_states, - [self.intermediate_size, self.conv_dim, self.num_heads], - dim=-1, + _, _, gate, hidden_states_B_C, dt = projected_states.split( + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) - # 1D Convolution - if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: + # 2. Convolution sequence transformation + # Init cache + if cache_params is not None: + hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, + (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), + ) + cache_params.update_conv_state( + layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True + ) + + if self.activation not in ["silu", "swish"]: hidden_states_B_C = self.act( - self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len] - ) # (B, L, self.d_inner + 2 * ngroups * d_state) + self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2) + ) else: hidden_states_B_C = causal_conv1d_fn( x=hidden_states_B_C.transpose(1, 2), weight=self.conv1d.weight.squeeze(1), bias=self.conv1d.bias, activation=self.activation, - ).transpose(1, 2)[:, :seq_len] + ).transpose(1, 2) + + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) hidden_states, B, C = torch.split( hidden_states_B_C, [self.intermediate_size, groups_time_state_size, groups_time_state_size], dim=-1, ) - if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: - # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 - dtype = hidden_states.dtype - hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + # 3. SSM transformation scan_output, ssm_state = mamba_chunk_scan_combined( hidden_states.view(batch_size, seq_len, -1, self.head_dim), - time_step, + dt, A, B.view(batch_size, seq_len, self.n_groups, -1), C.view(batch_size, seq_len, self.n_groups, -1), @@ -395,11 +448,16 @@ def cuda_kernels_forward( dt_softplus=True, **dt_limit_kwargs, ) + + # Init cache if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state) + scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer scan_output = self.norm(scan_output, gate) + + # 4. 
Final linear projection out = self.out_proj(scan_output) return out @@ -407,60 +465,64 @@ def cuda_kernels_forward( def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype - # Gated MLP's linear projection - projected_states = self.in_proj(input_states.squeeze(1)) - d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size- self.num_heads) // 2 - _, _, gate, hidden_states, dt = projected_states.split( + + # 1. Gated MLP's linear projection + input_states = apply_mask_to_padding_states(input_states, attention_mask) + projected_states = self.in_proj(input_states) + d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size-self.num_heads) // 2 + _, _, gate, hidden_states_B_C, dt = projected_states.split( [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) - # Convolution sequence transformation - if cache_params is not None: - ssm_state = cache_params.ssm_states[self.layer_idx].clone() - ssm_state = ssm_state.to(hidden_states.device) - if cache_params.seqlen_offset > 0: - conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size] - conv_state = torch.roll(conv_state, shifts=-1, dims=-1) - # handle batched generation - states are copied through - conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states - cache_params.conv_states[self.layer_idx].copy_(conv_state) - hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1) - if self.use_conv_bias: - hidden_states += self.conv1d.bias - hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] # [batch, 1, intermediate_size] : decoding - else: - hidden_states = hidden_states.transpose(1,2) - conv_state = nn.functional.pad( - hidden_states, - (self.conv_kernel_size - hidden_states.shape[-1], 0) - ) - cache_params.conv_states[self.layer_idx].copy_(conv_state) - hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len] - if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: - dtype = hidden_states.dtype - # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 - hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) - else: - ssm_state = torch.zeros( - (batch_size, self.num_heads, self.head_dim, self.ssm_state_size), - device=hidden_states.device, dtype=dtype + # 2. 
Convolution sequence transformation + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=hidden_states_B_C, cache_init=False) + + # We need to guarantee that anything regarding the cache is on the same device + conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device) + + hidden_states_B_C = torch.sum( + conv_states * self.conv1d.weight.squeeze(1), dim=-1 ) - hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2)) - hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1) + if self.use_conv_bias: + hidden_states_B_C = hidden_states_B_C + self.conv1d.bias + hidden_states_B_C = self.act(hidden_states_B_C) + else: + # Init cache + if cache_params is not None: + hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + ) + cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True) + + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], + dim=-1 + ) + + # 3. SSM transformation A = -torch.exp(self.A_log.float()) # [num_heads] - if cache_params is not None and cache_params.seqlen_offset > 0: + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + # We need to guarantee that anything regarding the cache is on the same device + cache_device = cache_params.ssm_states.device + # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation - dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...] + dt = dt[:, 0, :][:, None, ...] 
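# (Sketch for orientation: this single-step branch is the zero-order-hold
#  discretization of the SSM recurrence applied to one new token, roughly
#      dt  = clamp(softplus(dt + dt_bias), *time_step_limit)   # [bsz, num_heads, head_dim]
#      dA  = exp(dt * A)                                       # [bsz, num_heads, head_dim, state_size]
#      dBx = dt * B * x                                        # broadcast to the same shape
#      h_t = h_{t-1} * dA + dBx                                # stored via cache_params.update_ssm_state(...)
#      y_t = (C * h_t).sum(over state_size) + D * x
#  so each decode step only touches the cached [num_heads, head_dim, state_size]
#  tensor instead of re-running the chunked scan over the whole prefix.)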
dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) # [num_heads] -> [num_heads, head_dim] dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype)) - dt = torch.clamp(dt, self.time_step_min) #, self.time_step_max) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) # [bsz, num_heads, head_dim, state_size] - dA = torch.exp(dt[..., None] * A) + dA = (torch.exp(dt[..., None] * A)).to(device=cache_device) # Discretize B # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> @@ -474,11 +536,12 @@ def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, # Discretize x into dB # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) - dBx = dB * hidden_states[..., None] + dBx = (dB * hidden_states[..., None]).to(device=cache_device) # State calculation - cache_params.ssm_states[self.layer_idx].copy_( - cache_params.ssm_states[self.layer_idx] * dA + dBx + cache_params.update_ssm_state( + layer_idx=self.layer_idx, + new_ssm_state=cache_params.ssm_states[self.layer_idx] * dA + dBx ) # Subsequent output @@ -488,7 +551,7 @@ def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, C = C.reshape(batch_size, -1, C.shape[-1]) # [bsz, num_heads, head_dim] - ssm_states = cache_params.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n] + ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] @@ -505,9 +568,9 @@ def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, else: # begin ssd naive implementation without einsums dt = nn.functional.softplus(dt + self.dt_bias) - dt = torch.clamp(dt, self.time_step_min) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() - B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) @@ -522,7 +585,6 @@ def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, # Rearrange into blocks/chunks hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] - # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] A = A.permute(0, 3, 1, 2) A_cumsum = torch.cumsum(A, dim=-1) @@ -531,45 +593,43 @@ def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, # This is the analog of a causal mask L = torch.exp(segment_sum(A)) - # First, contraction of C and B to get G (attention-weights like) - G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, : ,:] # shape: (b, c, l, s, h, n) + # Contraction of C and B to get G (attention-weights like) + G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # 
shape: (b, c, l, s, h, n) G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) - - # Step 2: Compute M, equivalent to applying attention mask to weights + # Compute M, equivalent to applying attention mask to weights M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] M = M_intermediate.sum(dim=-1) - # Step 3: Compute Y_diag (apply to values) - Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3) + # Compute Y_diag (apply to values) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3) + # 2. Compute the state for each intra-chunk # (right term of low-rank factorization of off-diagonal blocks; B terms) - decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) - B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None] - # permute back B * decay states - states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3) - if cache_params is not None and cache_params.seqlen_offset > 0: - previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...] + B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None] + states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) else: previous_states = torch.zeros_like(states[:, :1]) states = torch.cat([previous_states, states], dim=1) decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) - - states_permuted = states.permute(0, 2, 1, 3, 4) - result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2) - new_states = result.permute(0, 2, 1, 3, 4) + decay_chunk = decay_chunk.transpose(1, 3) + new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1) states, ssm_state = new_states[:, :-1], new_states[:, -1] - # Compute state -> output conversion per chunk + # 4. 
Compute state -> output conversion per chunk # (left term of low-rank factorization of off-diagonal blocks; C terms) state_decay_out = torch.exp(A_cumsum) - # compute Yoff C_times_states = (C[..., None, :] * states[:, :, None, ...]) state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) - # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) y = Y_diag + Y_off # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) @@ -579,8 +639,10 @@ def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, if pad_size > 0: y = y[:, :seq_len, :, :] y = y.reshape(batch_size, seq_len, -1) + + # Init cache if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state) scan_output = self.norm(y, gate) @@ -916,9 +978,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if use_cache: - cache_params.seqlen_offset += inputs_embeds.shape[1] - hidden_states = self.norm_f(hidden_states) if output_hidden_states: @@ -975,10 +1034,6 @@ def prepare_inputs_for_generation( ): # Overwitten -- uses `cache_params` as opposed to `past_key_values` - if inputs_embeds is not None: - past_len = inputs_embeds.shape[1] + input_ids.shape[1] - else: - past_len = input_ids.shape[1] if use_cache: # `cache_position` should have been initialized in `generate` if cache_position is None: @@ -987,33 +1042,18 @@ def prepare_inputs_for_generation( "`model.generate`, you are responsible for passing in a valid `cache_position` if " "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`" ) - # how do we detect that we are in decoding without cache? 
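# (Sketch of the intended flow: decoding is now detected purely from `cache_position`,
#  which `generate` initializes and advances, and which the prefill branch below resets
#  to `arange(0, conv_kernel)`. Once `cache_position[0] > 0`, the conv/ssm caches already
#  hold the prompt, so only the last token is fed and the attention mask is dropped --
#  padded positions were already zeroed out during the prefill pass. For a prompt of
#  length L and kernel size K, illustratively:
#      prefill:      input_ids[:, :L],  cache_position = arange(0, K)
#      decode steps: input_ids[:, -1:], cache_position = [K], [K + 1], ..., attention_mask = None)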
if cache_position[0] > 0: input_ids = input_ids[:, -1][..., None] - attention_mask = attention_mask[:, -1][..., None] + + if attention_mask is not None: + attention_mask = None else: # we initialize the `cache_position` to full size of `conv_states` at prefill stage # considering padding will be applied when input length is shorter, and truncation # will be applied when it is longer, so it will be equivalent to always have it match # the length of `cache_params.conv_states`, which is `config.conv_kernel` - cache_position = torch.arange(0, past_len, device=input_ids.device) - # if the cache is not used, we also do have to extend the attention mask here - # TODO there is likely a cleverer way to do this - extended_mask = torch.ones( - attention_mask.size(0), past_len - attention_mask.shape[1], device=attention_mask.device - ) - attention_mask = torch.cat([attention_mask, extended_mask], dim=1) - cache_params = None - - if attention_mask.shape[1] < past_len: - # we have to update manually the attention mask if - # we are in decoding without cache - # and we don't have position_ids here - # TODO but we should be able to use cache_position though at a later time - extended_mask = torch.ones( - attention_mask.size(0), past_len - attention_mask.shape[1], device=attention_mask.device - ) - attention_mask = torch.cat([attention_mask, extended_mask], dim=1) + cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device) + if inputs_embeds is not None and cache_params is None: model_inputs = {"inputs_embeds": inputs_embeds} else: diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py index 9b3a9563b58ddc..c2ef68f2614ea5 100644 --- a/tests/models/mamba2/test_modeling_mamba2.py +++ b/tests/models/mamba2/test_modeling_mamba2.py @@ -21,6 +21,7 @@ from transformers import AutoTokenizer, Mamba2Config, is_torch_available from transformers.testing_utils import require_read_token, require_torch, require_torch_gpu, slow, torch_device +from transformers.utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -103,6 +104,10 @@ def prepare_config_and_inputs( ): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + # Only left padding is valid + attention_mask = torch.ones(size=(self.batch_size, self.seq_length), device=input_ids.device, dtype=torch.long) + attention_mask[0, :1] = 0 + sequence_labels = None token_labels = None choice_labels = None @@ -118,7 +123,7 @@ def prepare_config_and_inputs( return ( config, input_ids, - None, + attention_mask, sequence_labels, token_labels, choice_labels, @@ -158,6 +163,56 @@ def prepare_config_and_inputs_for_common(self): inputs_dict = {"input_ids": input_ids} return config, inputs_dict + def create_and_check_mamba2_caching(self, config, input_ids, attention_mask, *args): + model = Mamba2Model(config=config) + model.to(torch_device) + model.eval() + + output_whole = model(input_ids, attention_mask=attention_mask).last_hidden_state + + outputs = model( + input_ids[:, :-1], + attention_mask=attention_mask[:, :-1], + use_cache=True, + cache_position=torch.arange(0, config.conv_kernel, device=input_ids.device), + ) + output_one = outputs.last_hidden_state + + # Using the state computed on the first inputs, we will get the same output + outputs = model( + input_ids[:, -1:], + attention_mask=attention_mask[:, -1:], + use_cache=True, + 
cache_params=outputs.cache_params, + cache_position=torch.arange(config.conv_kernel, config.conv_kernel + 1, device=input_ids.device), + ) + output_two = outputs.last_hidden_state + + self.parent.assertTrue( + torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-3, rtol=1e-3) + ) + + def create_and_check_mamba2_slow_vs_fast_forward(self, config, input_ids, *args, gradient_checkpointing=False): + model = Mamba2Model(config) + model.eval() + + if not (is_mamba_2_ssm_available() and is_causal_conv1d_available()): + self.parent.skipTest( + "This test needs the Mamba2 fast path. Skipping as the necessary packages have not been found." + ) + if torch_device != "cuda": + self.parent.skipTest("This test needs the Mamba2 fast path. Skipping as we need a cuda capable device.") + + model.to(torch_device) + if gradient_checkpointing: + model.gradient_checkpointing_enable() + + token_emb = model.embeddings(input_ids) + outputs_fast = model.layers[0].mixer.cuda_kernels_forward(token_emb) + outputs_slow = model.layers[0].mixer.torch_forward(token_emb) + + self.parent.assertTrue(torch.allclose(outputs_fast, outputs_slow, atol=1e-3, rtol=1e-3)) + @unittest.skipIf( not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204" @@ -184,6 +239,14 @@ def setUp(self): self, config_class=Mamba2Config, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"] ) + def test_mamba2_caching(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_mamba2_caching(*config_and_inputs) + + def test_mamba2_slow_vs_fast_forward(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_mamba2_slow_vs_fast_forward(*config_and_inputs) + def test_initialization(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -199,23 +262,6 @@ def test_initialization(self): def test_tied_weights_keys(self): pass - @unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case") - def test_generate_without_input_ids(self): - pass - - @unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case") - @parameterized.expand([("greedy", 1), ("beam search", 2)]) - def test_generate_from_inputs_embeds(self, _, num_beams): - pass - - @unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case") - def test_greedy_generate_dict_outputs_use_cache(self): - pass - - @unittest.skip(reason="To fix, Mamba 2 cache slicing is interacting with beam search") - def test_beam_search_generate_dict_outputs_use_cache(self): - pass - @unittest.skip(reason="A large mamba2 would be necessary (and costly) for that") def test_multi_gpu_data_parallel_forward(self): pass
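
As a companion to the new caching test above, here is a minimal standalone sketch of the prefill + single-step flow that `create_and_check_mamba2_caching` exercises. The tiny `Mamba2Config` values are arbitrary illustration (any configuration with `num_heads * head_dim == expand * hidden_size` behaves the same way); on a machine without the fused kernels this runs through `torch_forward`:

import torch
from transformers import Mamba2Config, Mamba2Model

config = Mamba2Config(
    vocab_size=128, hidden_size=32, num_hidden_layers=2, expand=2,
    num_heads=8, head_dim=8, n_groups=1, state_size=16, conv_kernel=4,
)
model = Mamba2Model(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 6))
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    # Prefill: the cache is created, the conv states are filled with the (padded)
    # prompt and the ssm states hold each layer's final scan state.
    prefill = model(
        input_ids[:, :-1],
        attention_mask=attention_mask[:, :-1],
        use_cache=True,
        cache_position=torch.arange(0, config.conv_kernel),
    )
    # Single decode step: cache_position[0] > 0 selects the cached branch, which
    # rolls the conv state by one position and updates the ssm state in place.
    step = model(
        input_ids[:, -1:],
        attention_mask=attention_mask[:, -1:],
        use_cache=True,
        cache_params=prefill.cache_params,
        cache_position=torch.arange(config.conv_kernel, config.conv_kernel + 1),
    )

# The refactored cache keeps one stacked tensor per kind instead of per-layer dicts:
print(prefill.cache_params.conv_states.shape)  # [num_layers, batch, conv_dim, conv_kernel]            -> [2, 1, 96, 4]
print(prefill.cache_params.ssm_states.shape)   # [num_layers, batch, num_heads, head_dim, state_size]  -> [2, 1, 8, 8, 16]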
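
The `attention_mask` fixture added above only exercises left padding, which is the only padding the masking helper can support: `apply_mask_to_padding_states` zeroes the hidden states of pad positions before they enter the convolution and the SSM scan, so left pads contribute nothing to the state of the real tokens that follow, while with right padding the scan would still step through the (zeroed) pad positions after the last real token and the cached state would no longer match it. A toy run of the helper (reproduced verbatim from the diff) makes this concrete:

import torch

def apply_mask_to_padding_states(hidden_states, attention_mask):
    """
    Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
    """
    if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
        dtype = hidden_states.dtype
        hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)

    return hidden_states

hidden_states = torch.ones(2, 4, 3)           # [batch, seq_len, hidden]
attention_mask = torch.tensor([[0, 1, 1, 1],  # first sequence is left-padded by one token
                               [1, 1, 1, 1]])
masked = apply_mask_to_padding_states(hidden_states, attention_mask)
print(masked[0, 0])  # tensor([0., 0., 0.]) -- the pad position is silenced before the scan
print(masked[1])     # the fully valid sequence is untouched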
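
Similarly, the causal decay mask `L = torch.exp(segment_sum(A))` used by the naive SSD path in `torch_forward` is easiest to read on a tiny example. `segment_sum` is the module-level helper in `modeling_mamba2.py`; the input values below are arbitrary:

import torch
from transformers.models.mamba2.modeling_mamba2 import segment_sum

a = torch.tensor([-0.1, -0.2, -0.3])  # per-position dt * A terms (negative, so exp gives decays in (0, 1))
S = segment_sum(a)
# S[i, j] = a[j+1] + ... + a[i] for j <= i, and -inf above the diagonal:
# [[ 0.0,  -inf, -inf],
#  [-0.2,   0.0, -inf],
#  [-0.5,  -0.3,  0.0]]
L = torch.exp(S)  # lower-triangular products of decays; exp(-inf) = 0 masks out the future
print(L)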