From 8208f5c633394480160118d843ca65ad43580f5c Mon Sep 17 00:00:00 2001
From: pglorio
Date: Tue, 16 Jul 2024 20:17:48 +0000
Subject: [PATCH] Add Zamba2

---
 .gitignore                                     |    1 +
 src/transformers/models/zamba2/__init__.py     |   57 +
 src/transformers/models/zamba2/attention.py    |  222 ++
 .../models/zamba2/configuration_zamba2.py      |  235 +++
 src/transformers/models/zamba2/enums.py        |   10 +
 src/transformers/models/zamba2/hf_utils.py     |   15 +
 .../models/zamba2/mamba2_layer.py              |  368 ++++
 src/transformers/models/zamba2/mamba_block.py  |  703 +++++++
 .../models/zamba2/mamba_config.py              |  136 ++
 src/transformers/models/zamba2/mamba_layer.py  |  308 +++
 src/transformers/models/zamba2/mamba_model.py  |  176 ++
 src/transformers/models/zamba2/mlp.py          |  102 +
 .../models/zamba2/modeling_zamba2.py           | 1747 ++++++++++++++++
 .../models/zamba2/ops/__init__.py              |    0
 .../zamba2/ops/selective_scan_interface.py     |  361 ++++
 .../models/zamba2/ops/triton/__init__.py       |    0
 .../models/zamba2/ops/triton/k_activations.py  |  153 ++
 .../models/zamba2/ops/triton/layer_norm.py     | 1086 ++++++++++
 .../models/zamba2/ops/triton/layernorm.py      |  636 ++++++
 .../zamba2/ops/triton/layernorm_gated.py       |  437 ++++
 .../ops/triton/selective_state_update.py       |  263 +++
 .../models/zamba2/ops/triton/ssd_bmm.py        |  262 +++
 .../zamba2/ops/triton/ssd_chunk_scan.py        | 1825 +++++++++++++++++
 .../zamba2/ops/triton/ssd_chunk_state.py       |  866 ++++++++
 .../models/zamba2/ops/triton/ssd_combined.py   |  959 +++++++++
 .../zamba2/ops/triton/ssd_state_passing.py     |  348 ++++
 .../models/zamba2/requirements.txt             |   40 +
 src/transformers/models/zamba2/rotary.py       |   76 +
 .../models/zamba2/selective_scan_interface.py  |  389 ++++
 src/transformers/models/zamba2/setup.py        |  158 ++
 src/transformers/models/zamba2/switch_mlp.py   |   85 +
 src/transformers/models/zamba2/test.py         |   28 +
 .../models/zamba2/transformers_tests.ipynb     |  363 ++++
 src/transformers/models/zamba2/utils.py        |   71 +
 src/transformers/models/zamba2/zamba2_torch    |    1 +
 .../zamba2/zamba2_torch_generation.ipynb       |  790 +++++++
 36 files changed, 13277 insertions(+)
 create mode 100644 src/transformers/models/zamba2/__init__.py
 create mode 100644 src/transformers/models/zamba2/attention.py
 create mode 100644 src/transformers/models/zamba2/configuration_zamba2.py
 create mode 100644 src/transformers/models/zamba2/enums.py
 create mode 100644 src/transformers/models/zamba2/hf_utils.py
 create mode 100644 src/transformers/models/zamba2/mamba2_layer.py
 create mode 100644 src/transformers/models/zamba2/mamba_block.py
 create mode 100644 src/transformers/models/zamba2/mamba_config.py
 create mode 100644 src/transformers/models/zamba2/mamba_layer.py
 create mode 100644 src/transformers/models/zamba2/mamba_model.py
 create mode 100644 src/transformers/models/zamba2/mlp.py
 create mode 100644 src/transformers/models/zamba2/modeling_zamba2.py
 create mode 100755 src/transformers/models/zamba2/ops/__init__.py
 create mode 100755 src/transformers/models/zamba2/ops/selective_scan_interface.py
 create mode 100755 src/transformers/models/zamba2/ops/triton/__init__.py
 create mode 100755 src/transformers/models/zamba2/ops/triton/k_activations.py
 create mode 100755 src/transformers/models/zamba2/ops/triton/layer_norm.py
 create mode 100755 src/transformers/models/zamba2/ops/triton/layernorm.py
 create mode 100755 src/transformers/models/zamba2/ops/triton/layernorm_gated.py
 create mode 100755 src/transformers/models/zamba2/ops/triton/selective_state_update.py
 create mode 100755 src/transformers/models/zamba2/ops/triton/ssd_bmm.py
 create mode 100755
src/transformers/models/zamba2/ops/triton/ssd_chunk_scan.py create mode 100755 src/transformers/models/zamba2/ops/triton/ssd_chunk_state.py create mode 100755 src/transformers/models/zamba2/ops/triton/ssd_combined.py create mode 100755 src/transformers/models/zamba2/ops/triton/ssd_state_passing.py create mode 100644 src/transformers/models/zamba2/requirements.txt create mode 100644 src/transformers/models/zamba2/rotary.py create mode 100755 src/transformers/models/zamba2/selective_scan_interface.py create mode 100644 src/transformers/models/zamba2/setup.py create mode 100644 src/transformers/models/zamba2/switch_mlp.py create mode 100644 src/transformers/models/zamba2/test.py create mode 100644 src/transformers/models/zamba2/transformers_tests.ipynb create mode 100644 src/transformers/models/zamba2/utils.py create mode 160000 src/transformers/models/zamba2/zamba2_torch create mode 100644 src/transformers/models/zamba2/zamba2_torch_generation.ipynb diff --git a/.gitignore b/.gitignore index 337f2ef2c7..579c5323fe 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ __pycache__/ *.py[cod] *$py.class +src/transformers/models/zamba2/csrc/ # C extensions *.so diff --git a/src/transformers/models/zamba2/__init__.py b/src/transformers/models/zamba2/__init__.py new file mode 100644 index 0000000000..0cf260fec1 --- /dev/null +++ b/src/transformers/models/zamba2/__init__.py @@ -0,0 +1,57 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
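+# This module follows the standard transformers lazy-import layout:
+# `_import_structure` lists the public symbols per submodule, the
+# TYPE_CHECKING branch gives static type checkers real imports, and at
+# runtime `_LazyModule` defers the torch-heavy imports until one of the
+# exported names is first accessed.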
+from typing import TYPE_CHECKING
+
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
+
+
+_import_structure = {
+    "configuration_zamba2": ["Zamba2Config"],
+}
+
+
+try:
+    if not is_torch_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["modeling_zamba2"] = [
+        "Zamba2ForCausalLM",
+        "Zamba2ForSequenceClassification",
+        "Zamba2Model",
+        "Zamba2PreTrainedModel",
+    ]
+
+if TYPE_CHECKING:
+    from .configuration_zamba2 import Zamba2Config
+
+    try:
+        if not is_torch_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .modeling_zamba2 import (
+            Zamba2ForCausalLM,
+            Zamba2ForSequenceClassification,
+            Zamba2Model,
+            Zamba2PreTrainedModel,
+        )
+
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
\ No newline at end of file
diff --git a/src/transformers/models/zamba2/attention.py b/src/transformers/models/zamba2/attention.py
new file mode 100644
index 0000000000..6793cd11af
--- /dev/null
+++ b/src/transformers/models/zamba2/attention.py
@@ -0,0 +1,222 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Union
+import transformer_engine as te
+
+import torch
+import torch.nn as nn
+from rotary import apply_rotary_pos_emb
+from enums import AttnMaskType
+
+class CausalSelfAttention(nn.Module):
+
+    def __init__(self, config, layer_number, attn_mask_type=AttnMaskType.padding, **kwargs):
+        super().__init__()
+        assert config.hidden_size % config.num_mem_heads == 0
+        self.config = config
+        self.linear_qkv = nn.Linear(2 * config.hidden_size, 6 * config.hidden_size, bias=config.add_bias_linear)
+        self.linear_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=config.add_bias_linear)
+        self.n_head = config.num_mem_heads
+        self.n_embd = config.hidden_size * 2
+        # For normal attention without groups, num_query_groups == num_attention_heads,
+        # so these two will be the same
+        self.query_projection_size = self.config.kv_channels * self.config.num_mem_heads
+        self.kv_projection_size = self.config.kv_channels * self.config.num_query_groups
+        world_size = 1
+        self.world_size = world_size
+
+        self.hidden_size_per_attention_head = self.query_projection_size // self.config.num_attention_heads
+        self.num_attention_heads_per_partition = self.config.num_attention_heads
+        self.num_query_groups_per_partition = self.config.num_query_groups
+        self.dpa = te.pytorch.DotProductAttention(num_attention_heads=self.config.num_attention_heads,
+                                                  kv_channels=self.config.kv_channels,
+                                                  attention_dropout=0.0,
+                                                  layer_number=layer_number,
+                                                  attn_mask_type="causal")
+        self.dpa_generation = te.pytorch.DotProductAttention(num_attention_heads=self.config.num_attention_heads,
+                                                             kv_channels=self.config.kv_channels,
+                                                             attention_dropout=0.0, layer_number=layer_number,
+                                                             attn_mask_type="no_mask")
+        # We only apply LoRA to the QKV projections, as those are accessible. Skipping the output
+        # projection should not reduce expressivity: LoRA is also applied to the fc_in of the
+        # MLP, and there is no residual connection or layernorm in between.
+        if self.config.use_shared_block_lora and 0 == 1:
+            self.linear_q_lora_A_list = nn.ParameterList([])
+            self.linear_q_lora_B_list = nn.ParameterList([])
+            self.linear_k_lora_A_list = nn.ParameterList([])
+            self.linear_k_lora_B_list = nn.ParameterList([])
+            self.linear_v_lora_A_list = nn.ParameterList([])
+            self.linear_v_lora_B_list = nn.ParameterList([])
+
+            for i in range(self.num_mem_blocks):
+                # we store all loras in a list
+                linear_q_lora_A = nn.Linear(2 * self.config.hidden_size, self.config.lora_rank, bias=False)
+                linear_q_lora_B = nn.Linear(self.config.lora_rank, 2 * self.query_projection_size, bias=False)
+                self.linear_q_lora_A_list.append(linear_q_lora_A)
+                self.linear_q_lora_B_list.append(linear_q_lora_B)
+                linear_k_lora_A = nn.Linear(self.config.hidden_size, self.config.lora_rank, bias=False)
+                linear_k_lora_B = nn.Linear(self.config.lora_rank, 2 * self.kv_projection_size, bias=False)
+                self.linear_k_lora_A_list.append(linear_k_lora_A)
+                self.linear_k_lora_B_list.append(linear_k_lora_B)
+                linear_v_lora_A = nn.Linear(2 * self.config.hidden_size, self.config.lora_rank, bias=False)
+                linear_v_lora_B = nn.Linear(self.config.lora_rank, 2 * self.kv_projection_size, bias=False)
+                self.linear_v_lora_A_list.append(linear_v_lora_A)
+                self.linear_v_lora_B_list.append(linear_v_lora_B)
+
+    def _allocate_memory(self, inference_max_sequence_length, batch_size, dtype):
+        """Allocate memory to store the kv cache during inference."""
+
+        return torch.empty(
+            inference_max_sequence_length,
+            batch_size,
+            self.num_query_groups_per_partition,
+            self.hidden_size_per_attention_head * 2,
+            dtype=dtype,
+            device=torch.cuda.current_device(),
+        )
+
+    def _adjust_key_value_for_inference(self, inference_params, key, value, rotary_pos_emb, layer_number):
+        """
+        Saves the generated key and value tensors to the end of the buffers in inference_params.
+        Returns the full size keys and values from the provided inference_params, as well as
+        the adjusted rotary_pos_emb.
+
+        Returns a tuple: (key, value, rotary_pos_emb)
+        """
+        if inference_params is None:
+            return key, value, rotary_pos_emb
+
+        # =================================================
+        # Pre-allocate memory for key-values for inference.
+        # =================================================
+        is_first_step = False
+        if layer_number not in inference_params.key_value_memory_dict:
+            inf_max_seq_length = inference_params.max_sequence_length
+            inf_max_batch_size = inference_params.max_batch_size
+            inference_key_memory = self._allocate_memory(
+                inf_max_seq_length, inf_max_batch_size, key.dtype
+            )
+            inference_value_memory = self._allocate_memory(
+                inf_max_seq_length, inf_max_batch_size, value.dtype
+            )
+            inference_params.key_value_memory_dict[layer_number] = (
+                inference_key_memory,
+                inference_value_memory,
+            )
+            is_first_step = True
+        else:
+            # Get the pre-allocated buffers for this layer
+            inference_key_memory, inference_value_memory = inference_params.key_value_memory_dict[
+                layer_number
+            ]
+
+        batch_start = inference_params.batch_size_offset
+        batch_end = batch_start + key.size(1)
+        assert batch_end <= inference_key_memory.size(1)
+        sequence_start = inference_params.sequence_len_offset
+        sequence_end = sequence_start + key.size(0)
+        assert sequence_end <= inference_key_memory.size(0)
+        # Copy key and values.
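+        # The new key/value tokens are written into the pre-allocated buffers at
+        # the current (sequence, batch) offsets; the slices read back below cover
+        # everything up to sequence_end, so attention always sees the full cached
+        # prefix rather than only the tokens computed in this step.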
+ inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = key + + inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = value + key = inference_key_memory[:sequence_end, batch_start:batch_end, ...] + + value = inference_value_memory[:sequence_end, batch_start:batch_end, ...] + + # adjust the key rotary positional embedding + if rotary_pos_emb is not None: + q_pos_emb, k_pos_emb = rotary_pos_emb + # need to cross check this condition during inference + # if not set_inference_key_value_memory: + if not is_first_step: + # In inference, we compute one token at a time. + # Select the correct positional embedding + # (only the last token in the sequence) + q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end] + else: + # In the first forward pass of inference, + # we use the entire provided prefix. + # q_pos_emb here has the rope embeddings of the entire + # prefix + to-be-generated output so + # we slice to just the prefix. + q_pos_emb = q_pos_emb[:sequence_end, :, :, :] + k_pos_emb = k_pos_emb[:sequence_end, :, :, :] + rotary_pos_emb = (q_pos_emb, k_pos_emb) + + return key, value, rotary_pos_emb + + def forward(self, hidden_states, attention_mask, key_value_states=None, inference_params=None, rotary_pos_emb=None, forward_layer_idx = None): + + qkv_out = self.linear_qkv(hidden_states) + + #qkv_out = qkv_out.permute(1,0,2) + # TODO FIX + # self.num_query_groups_per_partition = 16 + # self.num_attention_heads_per_partition = 16 + # self.hidden_size_per_attention_head = 90 + new_tensor_shape = qkv_out.size()[:-1] + ( + self.num_query_groups_per_partition, + ( + (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2) + * self.hidden_size_per_attention_head * 2 + ), + ) + qkv_out = qkv_out.view(*new_tensor_shape) + + (query, key, value) = torch.split( + qkv_out, + [ + ( + self.num_attention_heads_per_partition + // self.num_query_groups_per_partition + * self.hidden_size_per_attention_head * 2 + ), + self.hidden_size_per_attention_head * 2, + self.hidden_size_per_attention_head * 2, + ], + dim=3, + ) + + + if self.config.use_shared_block_lora and 0 == 1: + new_lora_tensor_shape = new_tensor_shape[:-1] + (-1,) + linear_q_lora_A = self.linear_q_lora_A_list[forward_layer_idx] + linear_q_lora_B = self.linear_q_lora_B_list[forward_layer_idx] + q_lora = linear_q_lora_A(hidden_states) + q_lora = linear_q_lora_B(q_lora) + query = query + q_lora.view(new_lora_tensor_shape) + linear_k_lora_A = self.linear_k_lora_A_list[forward_layer_idx] + linear_k_lora_B = self.linear_k_lora_B_list[forward_layer_idx] + k_lora = linear_k_lora_A(hidden_states) + k_lora = linear_k_lora_B(k_lora) + key = key + k_lora.view(new_lora_tensor_shape) + linear_v_lora_A = self.linear_v_lora_A_list[forward_layer_idx] + linear_v_lora_B = self.linear_v_lora_B_list[forward_layer_idx] + v_lora = linear_v_lora_A(hidden_states) + v_lora = linear_v_lora_B(v_lora) + value = value + v_lora.view(new_lora_tensor_shape) + + query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head * 2) + + if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = (rotary_pos_emb,) * 2 + + key, value, rotary_pos_emb = self._adjust_key_value_for_inference( + inference_params, key, value, rotary_pos_emb, forward_layer_idx + ) + + if rotary_pos_emb is not None: + + q_pos_emb, k_pos_emb = rotary_pos_emb + query = apply_rotary_pos_emb(query, q_pos_emb) + key = apply_rotary_pos_emb(key, k_pos_emb) + + + if inference_params is 
None or inference_params.sequence_len_offset == 0:
+            y = self.dpa(query, key, value)
+        else:
+            y = self.dpa_generation(query, key, value)
+
+        y = self.linear_proj(y)
+
+        return y
\ No newline at end of file
diff --git a/src/transformers/models/zamba2/configuration_zamba2.py b/src/transformers/models/zamba2/configuration_zamba2.py
new file mode 100644
index 0000000000..1d06814abb
--- /dev/null
+++ b/src/transformers/models/zamba2/configuration_zamba2.py
@@ -0,0 +1,235 @@
+# coding=utf-8
+# Copyright 2024 Zyphra Technologies and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Zamba2 model configuration"""
+
+import math
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+import torch.nn.functional as F
+
+
+logger = logging.get_logger(__name__)
+
+
+class Zamba2Config(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Zamba2Model`]. It is used to instantiate a
+    Zamba2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the Zamba2 model.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 32000):
+            Vocabulary size of the Zamba2 model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`Zamba2Model`]
+        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
+            Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
+            model has an output word embedding layer.
+        hidden_size (`int`, *optional*, defaults to 2560):
+            Dimension of the hidden representations.
+        ffn_hidden_size (`int`, *optional*, defaults to `None`):
+            Dimension of the MLP representations. If `None`, defaults to 4 * hidden_size.
+        num_hidden_layers (`int`, *optional*, defaults to 54):
+            Number of hidden layers in the model.
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=None`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf).
+        n_mamba_heads (`int`, *optional*, defaults to 2):
+            Number of mamba heads for each mamba layer in the model.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the decoder.
+        hidden_mamba_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the mamba layer.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
+            Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
+            integer value, only the last `num_logits_to_keep` logits will be calculated. Default is 1 because only the
+            logits of the last prompt token are needed for generation. For long sequences, the logits for the entire
+            sequence may use a lot of memory, so setting `num_logits_to_keep=1` will reduce the memory footprint
+            significantly.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            The id of the padding token.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            The id of the "beginning-of-sequence" token.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            The id of the "end-of-sequence" token.
+        unk_token_id (`int`, *optional*, defaults to 0):
+            The id of the "unknown" token.
+        sliding_window (`int`, *optional*):
+            Sliding window attention window size. If not specified, will default to `None`.
+        max_position_embeddings (`int`, *optional*, defaults to 4096):
+            This value doesn't have any real effect. The maximum sequence length that this model is intended to be
+            used with. It can be used with longer sequences, but performance may degrade.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        attn_layer_period (`int`, *optional*, defaults to 6):
+            Once in this many layers, we will have a shared attention layer.
+        attn_layer_offset (`int`, *optional*, defaults to 4):
+            Offset of the shared attention layer.
+        use_mamba_kernels (`bool`, *optional*, defaults to `True`):
+            Flag indicating whether or not to use the fast mamba kernels. These are available only if `mamba-ssm` and
+            `causal-conv1d` are installed, and the mamba modules are running on a CUDA device. Raises ValueError if
+            `True` and kernels are not available.
+        mamba_d_state (`int`, *optional*, defaults to 16):
+            The dimension of the mamba state space latents.
+        mamba_d_conv (`int`, *optional*, defaults to 4):
+            The size of the mamba convolution kernel.
+        mamba_expand (`int`, *optional*, defaults to 2):
+            Expanding factor (relative to hidden_size) used to determine the mamba intermediate size.
+        mamba_dt_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
+            Rank of the mamba discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`.
+        mamba_conv_bias (`bool`, *optional*, defaults to `True`):
+            Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
+        mamba_proj_bias (`bool`, *optional*, defaults to `False`):
+            Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+    """
+
+    model_type = "zamba2"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=32000,
+        tie_word_embeddings=True,
+        hidden_size=2560,
+        state_size=64,
+        conv_dimension=4,
+        expansion_factor=2,
+        use_low_rank_mamba_proj=False,
+        add_bias_linear=False,
+        mamba_headdim=64,
+        ffn_hidden_size=None,
+        gated_linear_unit=True,
+        bias_gelu_fusion=False,
+        lora_rank=128,
+        num_hidden_layers=54,
+        # num_attention_heads=32,
+        # num_key_value_heads=None,
+        activation_func=F.gelu,
+        initializer_range=0.02,
+        rms_norm_eps=1e-5,
+        use_cache=True,
+        num_logits_to_keep=1,
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
+        sliding_window=None,
+        # max_position_embeddings=4096,
+        attention_dropout=0.0,
+        attn_layer_period=6,
+        attn_layer_offset=4,
+        num_mem_heads=32,
+        use_shared_block_lora=True,
+        use_mamba_kernels=True,
+        **kwargs,
+    ):
+        # mamba_d_state=16,
+        # mamba_d_conv=4,
+        # mamba_expand=2,
+        # mamba_dt_rank="auto",
+        # mamba_conv_bias=True,
+        # mamba_proj_bias=False,
+        # rope_theta=10000,
+        self.vocab_size = vocab_size
+        self.tie_word_embeddings = tie_word_embeddings
+        self.hidden_size = hidden_size
+        self.ffn_hidden_size = ffn_hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        # self.num_attention_heads = num_attention_heads
+        self.sliding_window = sliding_window
+        # self.max_position_embeddings = max_position_embeddings
+        self.attention_dropout = attention_dropout
+        self.state_size = state_size
+        self.conv_dimension = conv_dimension
+        self.expansion_factor = expansion_factor
+        self.use_low_rank_mamba_proj = use_low_rank_mamba_proj
+        self.add_bias_linear = add_bias_linear
+        self.mamba_headdim = mamba_headdim
+        self.gated_linear_unit = gated_linear_unit
+        self.use_shared_block_lora = use_shared_block_lora
+        self.lora_rank = lora_rank
+
+        # for backward compatibility
+        # if num_key_value_heads is None:
+        #     num_key_value_heads = num_attention_heads
+        # self.num_key_value_heads = num_key_value_heads
+
+        self.num_mem_heads = num_mem_heads
+        self.kv_channels = 2 * self.hidden_size // self.num_mem_heads
+        self.num_query_groups = 2 * self.hidden_size // self.num_mem_heads
+        # self.n_mamba_heads = n_mamba_heads
+        self.activation_func = activation_func
+        self.bias_gelu_fusion = bias_gelu_fusion
+        # self.hidden_mamba_act = hidden_mamba_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+
+        if ffn_hidden_size is None:
+            self.ffn_hidden_size = 4 * self.hidden_size
+
+        self.use_cache = use_cache
+        self.num_logits_to_keep = num_logits_to_keep
+
+        self.attn_layer_period = attn_layer_period
+        self.attn_layer_offset = attn_layer_offset
+        self.num_mem_heads = num_mem_heads
+
+        self.use_mamba_kernels = use_mamba_kernels
+        # self.mamba_d_state = mamba_d_state
+        # self.mamba_d_conv = mamba_d_conv
+        # self.mamba_expand = mamba_expand
+        # self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank
+        # self.mamba_conv_bias = mamba_conv_bias
+        # self.mamba_proj_bias = mamba_proj_bias
+
+        # self.rope_theta = rope_theta
+
+        self.layers_block_type = self._layers_block_type(num_hidden_layers, attn_layer_period, attn_layer_offset)
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _layers_block_type(self, num_hidden_layers, attn_layer_period, attn_layer_offset): + layers = [ + "mamba", + "mamba", + "attention+mamba", + ] + [ + "attention+mamba" if i % attn_layer_period == attn_layer_offset else "mamba" + for i in range(num_hidden_layers - 3) + ] + return layers \ No newline at end of file diff --git a/src/transformers/models/zamba2/enums.py b/src/transformers/models/zamba2/enums.py new file mode 100644 index 0000000000..aba6980b1e --- /dev/null +++ b/src/transformers/models/zamba2/enums.py @@ -0,0 +1,10 @@ +import enum + +class AttnType(enum.Enum): + self_attn = 1 + cross_attn = 2 + + +class AttnMaskType(enum.Enum): + padding = 1 + causal = 2 diff --git a/src/transformers/models/zamba2/hf_utils.py b/src/transformers/models/zamba2/hf_utils.py new file mode 100644 index 0000000000..a7e80b5bc9 --- /dev/null +++ b/src/transformers/models/zamba2/hf_utils.py @@ -0,0 +1,15 @@ +import json +import torch +import transformers +from transformers.utils import WEIGHTS_NAME, CONFIG_NAME +from transformers.utils.hub import cached_file + + +def load_config_hf(model_name): + resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False) + return json.load(open(resolved_archive_file)) + + +def load_state_dict_hf(model_name, device="cpu"): + resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False) + return torch.load(resolved_archive_file, map_location=device) \ No newline at end of file diff --git a/src/transformers/models/zamba2/mamba2_layer.py b/src/transformers/models/zamba2/mamba2_layer.py new file mode 100644 index 0000000000..39df618278 --- /dev/null +++ b/src/transformers/models/zamba2/mamba2_layer.py @@ -0,0 +1,368 @@ +import math +from typing import Optional, Union +from abc import ABC, abstractmethod +from dataclasses import dataclass +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from einops import rearrange, repeat +from mamba_config import MambaConfig + + +try: + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +except ImportError: + causal_conv1d_fn = None + causal_conv1d_update = None + + +from ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, SELECTIVE_SCAN_CUDA_IMPORT_FAILED + +# if SELECTIVE_SCAN_CUDA_IMPORT_FAILED: +# from ops.selective_scan_interface import selective_scan_ref +# selective_scan_fn = selective_scan_ref +# mamba_inner_fn = None + +try: + from ops.triton.selective_state_update import selective_state_update +except ImportError: + selective_state_update = None + +try: + from ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn +except ImportError: + RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None + + + +try: + from ops.triton.selective_state_update import selective_state_update +except ImportError: + selective_state_update = None + +from ops.triton.layernorm_gated import RMSNorm as RMSNormGated + +from ops.triton.ssd_combined import mamba_chunk_scan_combined +from ops.triton.ssd_combined import mamba_split_conv1d_scan_combined + + +class Mamba2Layer(nn.Module): + def __init__( + self, + config: MambaConfig, + conv_init=None, + d_ssm=None, # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP + A_init_range=(1, 16), + D_has_hdim=False, + rmsnorm=True, + norm_before_gate=False, + dt_min=0.001, + dt_max=0.1, + dt_init_floor=1e-4, + dt_limit=(0.0, float("inf")), + 
conv_bias=True, + # Fused kernel and sharding options + chunk_size=256, + use_mem_eff_path=True, + layer_idx=None, # Absorb kwarg for general module + process_group=None, + sequence_parallel=True, + #device=None, + #dtype=None, + ): + #factory_kwargs = {"device": device, "dtype": dtype} + factory_kwargs = {} + super().__init__() + self.config = config + self.d_model = config.hidden_size + self.d_state = config.state_size + self.d_conv = config.conv_dimension + self.conv_init = conv_init + self.expand = config.expansion_factor + self.d_inner = (self.expand * self.d_model) + + assert self.d_inner == self.expand * self.d_model + self.headdim = config.mamba_headdim + self.d_ssm = self.d_inner if d_ssm is None else d_ssm + self.ngroups = 1 + assert self.d_ssm % self.headdim == 0 + self.nheads = self.d_ssm // self.headdim + self.D_has_hdim = D_has_hdim + self.rmsnorm = rmsnorm + self.norm_before_gate = norm_before_gate + self.dt_limit = dt_limit + self.activation = "silu" + self.chunk_size = chunk_size + self.use_mem_eff_path = use_mem_eff_path + self.layer_idx = layer_idx + + # Order: [z, x, B, C, dt] + d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads + if self.config.use_low_rank_mamba_proj: + self.in_proj = nn.ModuleList([nn.Linear(self.d_model, self.config.mamba_lora_rank), nn.Linear(self.config.mamba_lora_rank, d_in_proj)]) + else: + self.in_proj = nn.ModuleList([nn.Linear(self.d_model, d_in_proj, bias=self.config.add_bias_linear, **factory_kwargs)]) + + conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state + self.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + bias=conv_bias, + kernel_size=self.d_conv, + groups=conv_dim, + padding=self.d_conv - 1, + **factory_kwargs, + ) + if self.conv_init is not None: + nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init) + + self.act = nn.SiLU() + + # Initialize log dt bias + dt = torch.exp( + torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ) + dt = torch.clamp(dt, min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + self.dt_bias = nn.Parameter(inv_dt) + # Just to be explicit. Without this we already don't put wd on dt_bias because of the check + # name.endswith("bias") in param_grouping.py + self.dt_bias._no_weight_decay = True + + assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0] + A = torch.empty(self.nheads, dtype=torch.float32).uniform_(*A_init_range) + A_log = torch.log(A).to(dtype=self.in_proj[0].weight.dtype) + self.A_log = nn.Parameter(A_log) + self.A_log._no_weight_decay = True + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.d_ssm if self.D_has_hdim else self.nheads)) + self.D._no_weight_decay = True + + if self.rmsnorm: + assert RMSNormGated is not None + self.norm = RMSNormGated(self.d_ssm, eps=1e-5, norm_before_gate=self.norm_before_gate, + group_size=self.d_ssm // self.ngroups, **factory_kwargs) + + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=self.config.add_bias_linear, **factory_kwargs) + + + def forward(self, u, from_shared_proj=None, seqlen=None, seq_idx=None, inference_params=None): + """ + u: (batch, seqlen, hidden_dim) if seqlen=None. + If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we + split u during sequence parallel, we split the batch * seqlen dimension + (in case batch is small). 
+ Returns: same shape as u + """ + + seqlen_og = seqlen + if seqlen is None: + batch, seqlen, dim = u.shape + else: + batch_seqlen, dim = u.shape + batch = batch_seqlen // seqlen + + conv_state, ssm_state = None, None + if inference_params is not None: + conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) + if inference_params.sequence_len_offset > 0: + # The states are updated inplace + out, _, _ = self.step(u, conv_state, ssm_state) + return out + + if from_shared_proj is not None and self.config.use_low_rank_mamba_proj: + zxbcdt = self.in_proj[1](self.in_proj[0](u)) # (B, L, d_in_proj) or (B * L, d_in_proj) + zxbcdt += from_shared_proj + else: + zxbcdt = self.in_proj[0](u) # (B, L, d_in_proj) or (B * L, d_in_proj) + if seqlen_og is not None: + zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen) + A = -torch.exp(self.A_log) # (nheads) or (d_inner, d_state) + dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit) + if self.use_mem_eff_path and inference_params is None: + out = mamba_split_conv1d_scan_combined( + zxbcdt, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.dt_bias, + A, + D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D, + chunk_size=self.chunk_size, + seq_idx=seq_idx, + activation=self.activation, + rmsnorm_weight=self.norm.weight if self.rmsnorm else None, + rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6, + outproj_weight=self.out_proj.weight, + outproj_bias=self.out_proj.bias, + headdim=None if self.D_has_hdim else self.headdim, + ngroups=self.ngroups, + norm_before_gate=self.norm_before_gate, + **dt_limit_kwargs, + ) + if seqlen_og is not None: + out = rearrange(out, "b l d -> (b l) d") + else: + d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2 + z0, x0, z, xBC, dt = torch.split( + zxbcdt, + [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads], + dim=-1 + ) + if conv_state is not None: + # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. 
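+                # Worked example (illustrative values, not from the original code):
+                # with d_conv=4 and seqlen=2, xBC_t is (B, D, 2) and the pad width is
+                # 4 - 2 = 2, so two zero frames are prepended and conv_state stays (B, D, 4);
+                # with seqlen=6 the pad width is -2, which truncates to the last 4 positions.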
+ xBC_t = rearrange(xBC, "b l d -> b d l") + conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) + assert self.activation in ["silu", "swish"] + if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: + xBC = self.act( + self.conv1d(xBC.transpose(1, 2)).transpose(1, 2) + ) # (B, L, self.d_ssm + 2 * ngroups * d_state) + else: + xBC = causal_conv1d_fn( + xBC.transpose(1, 2), + rearrange(self.conv1d.weight, "d 1 w -> d w"), + bias=self.conv1d.bias, + activation=self.activation, + ).transpose(1, 2) + x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1) + y = mamba_chunk_scan_combined( + rearrange(x, "b l (h p) -> b l h p", p=self.headdim), + dt, + A, + rearrange(B, "b l (g n) -> b l g n", g=self.ngroups), + rearrange(C, "b l (g n) -> b l g n", g=self.ngroups), + chunk_size=self.chunk_size, + D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D, + z=rearrange(z, "b l (h p) -> b l h p", p=self.headdim) if not self.rmsnorm else None, + dt_bias=self.dt_bias, + dt_softplus=True, + seq_idx=seq_idx, + **dt_limit_kwargs, + return_final_states=ssm_state is not None, + ) + if ssm_state is not None: + y, last_state = y + ssm_state.copy_(last_state) + y = rearrange(y, "b l h p -> b l (h p)") + if self.rmsnorm: + y = self.norm(y, z) + if d_mlp > 0: + y = torch.cat([F.silu(z0) * x0, y], dim=-1) + if seqlen_og is not None: + y = rearrange(y, "b l d -> (b l) d") + out = self.out_proj(y) + + return out + + def step(self, hidden_states, conv_state, ssm_state): + dtype = hidden_states.dtype + assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" + zxbcdt = self.in_proj[0](hidden_states.squeeze(1)) # (B 2D) + d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2 + z0, x0, z, xBC, dt = torch.split( + zxbcdt, + [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads], + dim=-1 + ) + + # Conv step + if causal_conv1d_update is None: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = xBC + xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) + if self.conv1d.bias is not None: + xBC = xBC + self.conv1d.bias + xBC = self.act(xBC).to(dtype=dtype) + else: + xBC = causal_conv1d_update( + xBC, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation, + ) + + x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1) + A = -torch.exp(self.A_log.float()) # (nheads,) + + # SSM step + if selective_state_update is None: + assert self.ngroups == 1, "Only support ngroups=1 for this inference code path" + # Discretize A and B + dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads) + dA = torch.exp(dt * A) # (batch, nheads) + x = rearrange(x, "b (h p) -> b h p", p=self.headdim) + dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x) + ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx) + y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C) + y = y + rearrange(self.D.to(dtype), "h -> h 1") * x + y = rearrange(y, "b h p -> b (h p)") + if not self.rmsnorm: + y = y * self.act(z) # (B D) + else: + A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32) + dt = repeat(dt, "b h -> b h p", p=self.headdim) + dt_bias = repeat(self.dt_bias, "h -> h p", 
p=self.headdim) + D = repeat(self.D, "h -> h p", p=self.headdim) + B = rearrange(B, "b (g n) -> b g n", g=self.ngroups) + C = rearrange(C, "b (g n) -> b g n", g=self.ngroups) + x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim) + if not self.rmsnorm: + z = rearrange(z, "b (h p) -> b h p", p=self.headdim) + y = selective_state_update( + ssm_state, x_reshaped, dt, A, B, C, D, z=z if not self.rmsnorm else None, + dt_bias=dt_bias, dt_softplus=True + ) + y = rearrange(y, "b h p -> b (h p)") + if self.rmsnorm: + y = self.norm(y, z) + if d_mlp > 0: + y = torch.cat([F.silu(z0) * x0, y], dim=-1) + out = self.out_proj(y) + return out.unsqueeze(1), conv_state, ssm_state + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + device = self.out_proj.weight.device + conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + conv_state = torch.zeros( + batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=conv_dtype + ) + ssm_dtype = self.in_proj[0].weight.dtype if dtype is None else dtype + ssm_state = torch.zeros( + batch_size, self.nheads, self.headdim, self.d_state, device=device, dtype=ssm_dtype + ) + return conv_state, ssm_state + + def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): + assert self.layer_idx is not None + if self.layer_idx not in inference_params.key_value_memory_dict_mamba: + batch_shape = (batch_size,) + conv_state = torch.zeros( + batch_size, + self.conv1d.weight.shape[0], + self.d_conv, + device=self.conv1d.weight.device, + dtype=self.conv1d.weight.dtype, + ) + ssm_state = torch.zeros( + batch_size, + self.nheads, + self.headdim, + self.d_state, + device=self.in_proj[0].weight.device, + dtype=self.in_proj[0].weight.dtype, + ) + inference_params.key_value_memory_dict_mamba[self.layer_idx] = (conv_state, ssm_state) + else: + conv_state, ssm_state = inference_params.key_value_memory_dict_mamba[self.layer_idx] + # TODO: What if batch size changes between generation, and we reuse the same states? 
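+            # A possible guard for the TODO above (a sketch, not part of the original code):
+            # assert conv_state.shape[0] == batch_size, "cached conv_state has a stale batch size"
+            # assert ssm_state.shape[0] == batch_size, "cached ssm_state has a stale batch size"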
+ if initialize_states: + conv_state.zero_() + ssm_state.zero_() + return conv_state, ssm_state \ No newline at end of file diff --git a/src/transformers/models/zamba2/mamba_block.py b/src/transformers/models/zamba2/mamba_block.py new file mode 100644 index 0000000000..32b4226d21 --- /dev/null +++ b/src/transformers/models/zamba2/mamba_block.py @@ -0,0 +1,703 @@ +import math +from typing import Optional, Union +import re +from contextlib import nullcontext +from abc import ABC, abstractmethod +from dataclasses import dataclass +import functools +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + + +from einops import rearrange, repeat + +try: + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +except ImportError: + causal_conv1d_fn, causal_conv1d_update = None, None + +try: + from ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn +except ImportError: + selective_scan_fn, mamba_inner_fn = None, None + +try: + from ops.triton.selective_state_update import selective_state_update +except ImportError: + selective_state_update = None + +try: + from ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn +except ImportError: + RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None + +from mamba_layer import MambaLayer +from mamba2_layer import Mamba2Layer +from mamba_config import MambaConfig +from mlp import MLP +from attention import CausalSelfAttention +from switch_mlp import SwitchMLP +from rotary import RotaryEmbedding + + +class MambaBlock(nn.Module): + def __init__( + self, config, mixer_cls, moe_cls=None, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False + ): + + super().__init__() + self.config = config + self.residual_in_fp32 = residual_in_fp32 + self.fused_add_norm = fused_add_norm + self.mixer = mixer_cls(config) + # TODO -- this is a bit of a hack but might work for now + if config.use_module_layernorm and not config.rms_norm: + self.norm = norm_cls + else: + self.norm = norm_cls(config.hidden_size) + if self.fused_add_norm: + assert RMSNorm is not None, "RMSNorm import fails" + assert isinstance( + self.norm, (nn.LayerNorm, RMSNorm) + ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" + assert config.num_mem_heads == 0, 'args.num_mem_heads > 0 only supports fused_add_norm=False' + if moe_cls is not None: + self.moe = moe_cls(config) + else: + self.moe = None + + def forward( + self, hidden_states: Tensor, from_shared_proj: Optional[Tensor] = None, from_tf: Optional[Tensor] = None, residual: Optional[Tensor] = None, inference_params=None, attention_mask=None + ): + + if not self.fused_add_norm: + + residual = (hidden_states + residual) if residual is not None else hidden_states + if from_tf is not None: + hidden_states = self.norm((residual + from_tf).to(dtype=self.norm.weight.dtype)) + else: + hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) + + + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + + fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn + hidden_states, residual = fused_add_norm_fn( + hidden_states, + self.norm.weight, + self.norm.bias, + residual=residual, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + eps=self.norm.eps, + ) + + hidden_states = self.mixer(hidden_states, from_shared_proj=from_shared_proj, inference_params=inference_params) + + return hidden_states , residual + + def allocate_inference_cache(self, batch_size, 
max_seqlen, dtype=None, **kwargs): + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + +class AttentionBlock(nn.Module): + def __init__( + self, config, mixer_cls, moe_cls=None, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False + ): + super().__init__() + self.config = config + self.residual_in_fp32 = residual_in_fp32 + self.fused_add_norm = fused_add_norm + self.mixer = mixer_cls(config) + # TODO -- this is a bit of a hack but might work for now + if config.use_module_layernorm and not config.rms_norm: + self.norm = norm_cls + else: + self.norm = norm_cls(config.hidden_size) + if self.fused_add_norm: + assert RMSNorm is not None, "RMSNorm import fails" + assert isinstance( + self.norm, (nn.LayerNorm, RMSNorm) + ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" + assert config.num_mem_heads == 0, 'args.num_mem_heads > 0 only supports fused_add_norm=False' + if moe_cls is not None: + self.moe = moe_cls(config) + else: + self.moe = None + + self.rotary_pos_emb = RotaryEmbedding( + config.kv_channels, rotary_percent=1.0, seq_len_interpolation_factor=None + ) + + def forward( + self, hidden_states: Tensor, from_tf: Optional[Tensor] = None, residual: Optional[Tensor] = None, inference_params=None, attention_mask=None + ): + + if not self.fused_add_norm: + residual = (hidden_states + residual) if residual is not None else hidden_states + if from_tf is not None: + hidden_states = self.norm((residual + from_tf).to(dtype=self.norm.weight.dtype)) + else: + hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn + hidden_states, residual = fused_add_norm_fn( + hidden_states, + self.norm.weight, + self.norm.bias, + residual=residual, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + eps=self.norm.eps, + ) + + hidden_states = hidden_states.transpose(0,1).contiguous() + rotary_seq_len = hidden_states.shape[0] + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) + hidden_states = self.mixer(hidden_states, rotary_pos_emb=rotary_pos_emb, attention_mask=attention_mask, inference_params=inference_params) + hidden_states = hidden_states.transpose(0,1) + return hidden_states , residual + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + +class Memory_AttentionBlock(nn.Module): + def __init__( + self, config, mixer_cls, moe_cls=None, norm_cls=nn.LayerNorm, residual_in_fp32=False, fused_add_norm=False + ): + super().__init__() + self.config = config + self.residual_in_fp32 = residual_in_fp32 + self.mixer = mixer_cls(config) + # TODO -- this is a bit of a hack but might work for now + assert config.rms_norm, 'Memory_AttentionBlock only supports RMSNorm' + self.norm = norm_cls(2 * config.hidden_size) + self.fused_add_norm = fused_add_norm + + if moe_cls is not None: + self.moe = moe_cls(config) + else: + self.moe = None + + + + def forward( + self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None, attention_mask=None, rotary_pos_emb=None, forward_layer_idx = None + ): + + if not self.fused_add_norm: + residual = (hidden_states + residual) if residual is not None else hidden_states + hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = 
residual.to(torch.float32) + else: + fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn + hidden_states, residual = fused_add_norm_fn( + hidden_states, + self.norm.weight, + self.norm.bias, + residual=residual, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + eps=self.norm.eps, + ) + # We don't include the residual in this block. + # The residual is taken care of in vBlock. + # This approach is more convenient as input and output dimensions are different. + + hidden_states = hidden_states.transpose(0,1).contiguous() + + hidden_states = self.mixer(hidden_states, rotary_pos_emb=rotary_pos_emb, attention_mask=attention_mask, inference_params=inference_params, forward_layer_idx = forward_layer_idx) + + hidden_states = hidden_states.transpose(0,1) + + return hidden_states + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + + +class MambaBlockParallelMoe(nn.Module): + def __init__( + self, config, mixer_cls, moe_cls=None, norm_cls=nn.LayerNorm, norm_moe=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False + ): + """ + Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" + + This Block has a slightly different structure compared to a regular + prenorm Transformer block. + The standard block is: LN -> MHA/MLP -> Add. + [Ref: https://arxiv.org/abs/2002.04745] + Here we have: Add -> LN -> Mixer, returning both + the hidden_states (output of the mixer) and the residual. + This is purely for performance reasons, as we can fuse add and LayerNorm. + The residual needs to be provided (except for the very first block). + """ + super().__init__() + self.config = config + self.residual_in_fp32 = residual_in_fp32 + self.fused_add_norm = fused_add_norm + self.mixer = mixer_cls(config) + # TODO -- this is a bit of a hack but might work for now + if config.use_module_layernorm and not config.rms_norm: + self.norm = norm_cls + self.norm_moe = norm_moe + else: + self.norm = norm_cls(config.hidden_size) + self.norm_moe = norm_moe(config.hidden_size) + if self.fused_add_norm: + assert RMSNorm is not None, "RMSNorm import fails" + assert isinstance( + self.norm, (nn.LayerNorm, RMSNorm) + ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" + assert isinstance( + self.norm_moe, (nn.LayerNorm, RMSNorm) + ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" + if moe_cls is not None: + self.moe = moe_cls(config) + else: + self.moe = None + + def forward( + self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None, attention_mask=None + ): + if not self.fused_add_norm: + residual = (hidden_states + residual) if residual is not None else hidden_states + hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) + hidden_states_moe = self.norm_moe(residual.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn + hidden_states, residual = fused_add_norm_fn( + hidden_states, + self.norm.weight, + self.norm.bias, + residual=residual, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + eps=self.norm.eps, + ) + hidden_states_moe, _ = fused_add_norm_fn( + hidden_states, + self.norm_moe.weight, + self.norm_moe.bias, + residual=residual, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + 
eps=self.norm_moe.eps, + ) + + hidden_states = self.mixer(hidden_states, inference_params=inference_params) + + hidden_states_moe = self.moe(hidden_states_moe) + hidden_states += hidden_states_moe + return hidden_states , residual + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + + +class MoEBlock(nn.Module): + def __init__( + self, config, mixer_cls, moe_cls=None, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False, attention_mask=None, layer_idx=None + ): + super().__init__() + self.config = config + self.residual_in_fp32 = residual_in_fp32 + self.fused_add_norm = fused_add_norm + self.mixer = mixer_cls(config, layer_idx=layer_idx) + # TODO -- this is a bit of a hack but might work for now + if config.use_module_layernorm and not config.rms_norm: + self.norm = norm_cls + else: + self.norm = norm_cls(config.hidden_size) + if self.fused_add_norm: + assert RMSNorm is not None, "RMSNorm import fails" + assert isinstance( + self.norm, (nn.LayerNorm, RMSNorm) + ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" + if moe_cls is not None: + self.moe = moe_cls(config) + else: + self.moe = None + + def forward( + self, hidden_states: Tensor, residual: Optional[Tensor] = None, from_tf: Optional[Tensor] = None, inference_params=None, attention_mask=None, forward_layer_idx = None + ): + if not self.fused_add_norm: + residual = (hidden_states + residual) if residual is not None else hidden_states + + hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn + hidden_states, residual = fused_add_norm_fn( + hidden_states, + self.norm.weight, + self.norm.bias, + residual=residual, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + eps=self.norm.eps, + ) + + hidden_states = self.mixer(hidden_states, forward_layer_idx = forward_layer_idx) + return hidden_states , residual + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + + + +class vBlock(nn.Module): + def __init__( + self, config, sa_cls, mlp_cls=None, norm_cls=nn.LayerNorm, residual_in_fp32=False, layer_idx=None + ): + super().__init__() + self.use_mem_mlp = config.use_mem_mlp + self.sa = Memory_AttentionBlock(config, mixer_cls=sa_cls, norm_cls=norm_cls, residual_in_fp32=config.residual_in_fp32) + if config.use_mem_mlp: + self.mlp = MoEBlock(config, mixer_cls=mlp_cls, norm_cls=norm_cls, residual_in_fp32=config.residual_in_fp32, layer_idx=-1) + + def forward(self, hidden_states, residual=None, x_orig=None, inference_params=None, attention_mask=None, rotary_pos_emb=None, forward_layer_idx = None): + x = hidden_states + residual if residual is not None else hidden_states + x_ = torch.concatenate([x, x_orig], dim=-1).type(hidden_states.dtype) + x = self.sa(x_, inference_params=inference_params, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb, forward_layer_idx = forward_layer_idx) + + if self.use_mem_mlp: + x, residual = self.mlp(x, forward_layer_idx = forward_layer_idx) + return x + + + +def count_mem_blocks_in_config(config): + num_gs = 0 + for val in config.layer_mapping: + if val == 'g': + num_gs +=1 + return num_gs + +# TODO -- figure out if we need to do layernorm via the 
spec or not +def create_block(config, layer_idx): + factory_kwargs = {} + + + if layer_idx == -1: + num_gs = count_mem_blocks_in_config(config) + norm_cls = partial(RMSNorm, eps=config.layernorm_epsilon, dtype=torch.float32) + sa_cls = partial(CausalSelfAttention, **factory_kwargs, layer_number=-1, num_mem_blocks=num_gs) + mlp_cls = partial(MLP, layer_idx=layer_idx, **factory_kwargs, num_mem_blocks = num_gs) + block = vBlock( + config, + sa_cls=sa_cls, + mlp_cls=mlp_cls, + norm_cls=norm_cls, + residual_in_fp32=config.residual_in_fp32, + layer_idx=layer_idx + ) + else: + norm_cls = partial(nn.LayerNorm if not config.rms_norm else RMSNorm, eps=config.layernorm_epsilon) + + if (not config.layer_mapping) or config.layer_mapping[layer_idx-1][0] == 'r' or config.layer_mapping[layer_idx-1][0] == 'g': + if (not config.layer_mapping) or len(config.layer_mapping[layer_idx-1]) == 1: + if 'm' in config.layer_mapping: + mixer_cls = partial(Mamba2Layer, layer_idx=layer_idx, **factory_kwargs) + else: + mixer_cls = partial(MambaLayer, layer_idx=layer_idx, **factory_kwargs) + block = MambaBlock( + config, + mixer_cls=mixer_cls, + norm_cls=norm_cls, + fused_add_norm=config.fused_add_norm, + residual_in_fp32=config.residual_in_fp32, + ) + else: + if config.layer_mapping[layer_idx-1][1] == '1': + norm_moe = partial(nn.LayerNorm if not config.rms_norm else RMSNorm, eps=config.layernorm_epsilon) + mixer_cls = partial(MambaLayer,layer_idx=layer_idx, **factory_kwargs) + moe_cls = partial(MLP,layer_idx=layer_idx, **factory_kwargs) + block = MambaBlockParallelMoe( + config, + mixer_cls=mixer_cls, + moe_cls=moe_cls, + norm_cls=norm_cls, + norm_moe=norm_moe, + fused_add_norm=config.fused_add_norm, + residual_in_fp32=config.residual_in_fp32, + ) + elif config.layer_mapping[layer_idx-1][0] == 'a': + mixer_cls = partial(CausalSelfAttention, layer_number=layer_idx, **factory_kwargs) + block = AttentionBlock( + config, + mixer_cls=mixer_cls, + norm_cls=norm_cls, + fused_add_norm=config.fused_add_norm, + residual_in_fp32=config.residual_in_fp32, + ) + elif config.layer_mapping[layer_idx-1][0] == 'm': + + mixer_cls = partial(Mamba2Layer, layer_idx=layer_idx, **factory_kwargs) + block = MambaBlock( + config, + mixer_cls=mixer_cls, + norm_cls=norm_cls, + fused_add_norm=config.fused_add_norm, + residual_in_fp32=config.residual_in_fp32, + ) + else: + mixer_cls = partial(SwitchMLP, layer_idx=layer_idx, **factory_kwargs) + block = MoEBlock( + config, + mixer_cls=mixer_cls, + norm_cls=norm_cls, + fused_add_norm=config.fused_add_norm, + residual_in_fp32=config.residual_in_fp32, + ) + block.layer_idx = layer_idx + return block + +class MambaDecoder(nn.Module): + """Class wrapping a decoder stack of mamba blocks.""" + + def __init__( + self, + config: MambaConfig, + post_layer_norm=True, + pre_process=True, + post_process=True, + ): + super().__init__() + + self.config: MambaConfig = config + + self.post_layer_norm = post_layer_norm + self.pre_process = pre_process + self.post_process = post_process + + # required for pipeline parallel schedules + self.input_tensor = None + + self.checkpoint_core_block = self.config.recompute_granularity == 'selective' + + self.num_layers_per_pipeline_rank = ( + self.config.num_layers + ) + + self.layer_mapping = config.layer_mapping + + self.use_mem_rope = config.use_mem_rope + + + self._build_layers() + + def _build_layers(self): + num_layers_to_build = self.num_layers_per_pipeline_rank + # build the actual mamba layers + self.layers = torch.nn.ModuleList([create_block(self.config, i + 1) for i in 
range(num_layers_to_build)])
+
+        if self.config.num_mem_heads > 0:
+            blocks = []
+            for _ in range(self.config.num_mem_blocks):
+                blocks.append(create_block(self.config, layer_idx=-1))
+            self.blocks = torch.nn.ModuleList(blocks)
+
+            # Replaced nn.Linear below with ColumnParallelLinear
+            # self.block_map = torch.nn.ModuleList([nn.Linear(self.config.hidden_size, self.config.hidden_size) if i%2 == 1 else nn.Identity() for i in range(self.config.num_layers)])
+            self.block_map = torch.nn.ModuleList([
+                nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=self.config.add_bias_linear) if (i % 2 == 1 if (self.layer_mapping is None) else self.layer_mapping[i] == 'g') else nn.Identity() for i in range(self.config.num_layers)])
+            if self.use_mem_rope:
+                self.rotary_pos_emb = RotaryEmbedding(
+                    2 * self.config.hidden_size // self.config.num_mem_heads, rotary_percent=1.0, seq_len_interpolation_factor=None
+                )
+
+        if self.config.use_low_rank_mamba_proj:
+            blocks = []
+            d_inner = self.config.expansion_factor * self.config.hidden_size
+            nheads = d_inner // self.config.mamba_headdim
+            d_in_proj = 2 * d_inner + 2 * self.config.mamba_ngroups * self.config.state_size + nheads
+            for _ in range(self.config.num_shared_mamba_proj):
+                blocks.append(nn.Linear(self.config.hidden_size, d_in_proj, bias=self.config.add_bias_linear))
+            self.in_projs = torch.nn.ModuleList(blocks)
+
+        if self.post_process and self.post_layer_norm:
+            self.final_layernorm = RMSNorm(self.config.hidden_size, eps=self.config.layernorm_epsilon, dtype=torch.float32)
+
+    def _get_layer(self, layer_number):
+        return self.layers[layer_number]
+
+    def _checkpointed_forward(self, hidden_states, residual, inference_params):
+        """Forward method with activation checkpointing."""
+
+        def custom(start, end):
+            def custom_forward(*args, **kwargs):
+                x, residual, *args = args
+                for index in range(start, end):
+                    layer = self._get_layer(index)
+                    x, residual = layer(x, residual, *args, **kwargs)
+                return x, residual
+
+            return custom_forward
+
+        if self.config.recompute_method == 'uniform':
+            # Uniformly divide the total number of Transformer layers and checkpoint
+            # the input activation of each divided chunk.
+            # This further reduces memory usage by keeping fewer checkpoints.
+            l = 0
+            while l < self.num_layers_per_pipeline_rank:
+                hidden_states, residual = tensor_parallel.checkpoint(
+                    custom(l, l + self.config.recompute_num_layers),
+                    self.config.distribute_saved_activations,
+                    hidden_states,
+                    residual,
+                    inference_params,
+                )
+
+                l += self.config.recompute_num_layers
+
+        elif self.config.recompute_method == 'block':
+            # Checkpoint the input activation of only a set number of individual
+            # Transformer layers and skip the rest.
+            # This makes full use of device memory by removing redundant re-computation.
+            for l in range(self.num_layers_per_pipeline_rank):
+                if l < self.config.recompute_num_layers:
+                    hidden_states, residual = tensor_parallel.checkpoint(
+                        custom(l, l + 1),
+                        self.config.distribute_saved_activations,
+                        hidden_states,
+                        residual,
+                        inference_params,
+                    )
+                else:
+                    hidden_states = custom(l, l + 1)(hidden_states, residual, inference_params)
+        else:
+            raise ValueError("Invalid activation recompute method.")
+
+        return hidden_states, residual
+
+    def set_input_tensor(self, input_tensor):
+        """Set input tensor to be used instead of forward()'s input.
+
+        When doing pipeline parallelism the input from the previous
+        stage comes from communication, not from the input, so the
+        model's forward_step_func won't have it.
This function is thus used by internal code to bypass the input
+        provided by the forward_step_func."""
+        self.input_tensor = input_tensor
+
+    def forward(self, hidden_states, residual=None, inference_params=None, attention_mask=None):
+        # hidden_states (float): [s, b, h]
+        # attention_mask (bool): [1, 1, s, s]
+
+        if not self.pre_process:
+            # See set_input_tensor()
+            hidden_states = self.input_tensor
+
+        # Viewless tensor.
+        # - We only need to create a viewless tensor in the case of micro batch
+        #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
+        #   above creates a view tensor, and '.contiguous()' is a pass-through.
+        #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
+        #   the need to make it viewless.
+        #
+        #   However, we don't explicitly check mbs == 1 here because
+        #   make_viewless_tensor() has negligible overhead when its input
+        #   is already viewless.
+        #
+        # - For the 'else' case above, calling make_viewless_tensor() here is
+        #   likely redundant, since p2p_communication.py (likely originator)
+        #   already creates viewless tensors. That said, make_viewless_tensor()
+        #   is called here to be future-proof and corner-case-proof.
+
+        rng_context = nullcontext()
+        fp8_context = nullcontext()
+
+        with rng_context, fp8_context:
+            residual = None
+            x_orig = torch.clone(hidden_states)
+            from_tf = None
+            block_count = 0
+            rotary_pos_emb = None
+            if self.use_mem_rope:
+                if inference_params is not None and inference_params.sequence_len_offset > 0:
+                    rotary_seq_len = inference_params.max_sequence_length
+                else:
+                    rotary_seq_len = hidden_states.shape[1]
+                rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
+            for i, layer in enumerate(self.layers):
+                if self.config.num_mem_heads > 0:
+                    if (i % 2 == 1 if (self.layer_mapping is None) else self.layer_mapping[i] == 'g'):
+                        from_tf = self.block_map[i](
+                            self.blocks[block_count % self.config.num_mem_blocks](
+                                hidden_states, residual, x_orig, inference_params=inference_params, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb, forward_layer_idx=block_count
+                            )
+                        )
+                        block_count += 1
+                    else:
+                        from_tf = None
+                from_shared_proj = None
+                if self.config.use_low_rank_mamba_proj:
+                    from_shared_proj = self.in_projs[i % self.config.num_shared_mamba_proj](hidden_states)
+                hidden_states, residual = layer(
+                    hidden_states=hidden_states,
+                    from_shared_proj=from_shared_proj,
+                    from_tf=from_tf,
+                    residual=residual,
+                    inference_params=inference_params,
+                    attention_mask=attention_mask,
+                )
+
+        # Final layer norm.
+        if self.post_process and self.post_layer_norm:
+            if not self.config.fused_add_norm:
+                residual = (hidden_states + residual) if residual is not None else hidden_states
+                hidden_states = self.final_layernorm(residual.to(dtype=self.final_layernorm.weight.dtype))
+            else:
+                # Set prenorm=False here since we don't need the residual
+                fused_add_norm_fn = rms_norm_fn if isinstance(self.final_layernorm, RMSNorm) else layer_norm_fn
+                hidden_states = fused_add_norm_fn(
+                    hidden_states,
+                    self.final_layernorm.weight,
+                    self.final_layernorm.bias,
+                    eps=self.final_layernorm.eps,
+                    residual=residual,
+                    prenorm=False,
+                    residual_in_fp32=self.config.residual_in_fp32,
+                )
+
+        return hidden_states
+
+    # def sharded_state_dict(self, prefix=''):
+
+    #     sharded_state_dict = {}
+
+    #     layer_prefix = f'{prefix}layers.'
+    #     for layer in self.layers:
+    #         sharded_state_dict.update(layer.sharded_state_dict(prefix=layer_prefix))
+
+    #     if self.post_process and self.post_layer_norm:
+    #         state_dict = self.state_dict(keep_vars=True)
+
+    #         tensor = state_dict['final_layernorm.weight']
+    #         layer_name = f'{prefix}final_layernorm.weight'
+    #         sharded_state_dict[layer_name] = make_sharded_tensor_for_checkpoint(tensor, layer_name)
+
+    #         # RMSNorm doesn't have bias.
+    #         if 'final_layernorm.bias' in state_dict.keys():
+    #             tensor = state_dict['final_layernorm.bias']
+    #             layer_name = f'{prefix}final_layernorm.bias'
+    #             sharded_state_dict[layer_name] = make_sharded_tensor_for_checkpoint(
+    #                 tensor, layer_name
+    #             )
+
+    #     return sharded_state_dict
\ No newline at end of file
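The dispatch in `create_block` above can be hard to follow. The loop below is an editor's illustration only (the mapping string is hypothetical, not from any shipped config) of which block each `layer_mapping` entry produces; note that 'g' layers are themselves built as Mamba blocks and additionally receive the shared attention output through `block_map` in `MambaDecoder.forward`:

    # Editor's sketch, mirroring the branches of create_block above.
    layer_mapping = ["m", "m", "g", "r1", "a", "e"]  # hypothetical
    for layer_idx, code in enumerate(layer_mapping, start=1):
        if code[0] in ("r", "g"):
            if len(code) == 1:
                kind = "MambaBlock"             # Mamba2Layer if 'm' in mapping, else MambaLayer
            elif code[1] == "1":
                kind = "MambaBlockParallelMoe"  # mamba mixer plus a parallel MLP branch
        elif code[0] == "a":
            kind = "AttentionBlock"
        elif code[0] == "m":
            kind = "MambaBlock"                 # Mamba2Layer mixer
        else:
            kind = "MoEBlock"                   # SwitchMLP mixer
        print(layer_idx, code, "->", kind)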
diff --git a/src/transformers/models/zamba2/mamba_config.py b/src/transformers/models/zamba2/mamba_config.py
new file mode 100644
index 0000000000..c00c834967
--- /dev/null
+++ b/src/transformers/models/zamba2/mamba_config.py
@@ -0,0 +1,136 @@
+from dataclasses import dataclass, field
+from typing import Callable
+from typing import List
+import torch
+import torch.nn.functional as F
+from utils import init_method_normal, scaled_init_method_normal
+
+@dataclass
+class MambaConfig():
+
+    # model architecture
+    base_model_type: str = "mamba"
+    num_layers: int = 0
+    hidden_size: int = 0
+    state_size: int = 0
+    mamba_headdim: int = 64
+    mamba_ngroups: int = 1
+    expansion_factor: int = 2
+    conv_dimension: int = 0
+    conv_bias: bool = True
+    bias: bool = True
+    use_fast_path: bool = True
+    dt_rank: str = "auto"
+    dt_min: float = 0.001
+    dt_max: float = 0.1
+    dt_init: str = "random"
+    dt_scale: float = 1.0
+    dt_init_floor: float = 1e-4
+    use_module_layernorm: bool = True  # this is to initialize the layernorm via the layerspec vs directly as in the mamba code
+
+    rms_norm: bool = False
+    fused_add_norm: bool = False
+    residual_in_fp32: bool = False
+    hidden_dropout: float = 0.0
+    ffn_hidden_size: int = None
+    gated_linear_unit: bool = False
+    kv_channels: int = None
+    kv_mem_channels: int = None
+    num_attention_heads: int = 0
+    num_query_groups: int = None
+    num_mem_query_groups: int = None
+    attention_dropout: float = 0.1
+    num_mem_heads: int = 0
+    use_mem_mlp: bool = False
+    window_size: int = None
+    gateconv_expansion_factor: int = 2
+    layer_mapping: List[str] = field(default_factory=lambda: [""])
+    vocab_size: int = 0
+    device: str = "cuda"
+    use_mem_rope: bool = False
+    use_shared_block_lora: bool = False
+    lora_rank: int = 16
+    num_mem_blocks: int = 1
+    use_low_rank_mamba_proj: bool = False
+    num_shared_mamba_proj: int = 1
+    mamba_lora_rank: int = 1
+
+    fp32_residual_connection: bool = False
+    layernorm_epsilon: float = 1e-5
+    layernorm_zero_centered_gamma: bool = False
+    add_bias_linear: bool = False
+    activation_func: Callable = F.gelu
+    num_moe_experts: int = None
+
+    # initialization
+    init_method: Callable = None
+    output_layer_init_method: Callable = None
+    init_method_std: float = 0.02
+
+    # mixed-precision
+    apply_query_key_layer_scaling: bool = True
+    attention_softmax_in_fp32: bool = True
+
+    bias_gelu_fusion: bool = False
+    masked_softmax_fusion: bool = False
+    persist_layer_norm: bool = False
+    bias_dropout_fusion: bool = False
+
+    # activation recomputation
+    recompute_granularity: str = None
+    recompute_method: str = None
+    recompute_num_layers: int = None
+    distribute_saved_activations: bool = None
+
+    # # fp8 related
+    # fp8: str = None
+    # fp8_margin: int = 0
+    # fp8_interval: int = 1
+    # fp8_amax_history_len: int = 1
+    # fp8_amax_compute_algo: str = "most_recent"
+    # fp8_wgrad: bool = True
+
+    # normalization: str = "LayerNorm"  # alt value supported by TE: "RMSNorm"
+    def __post_init__(self):
+
+        if self.apply_query_key_layer_scaling:
+            self.attention_softmax_in_fp32 = True
+
+        if self.ffn_hidden_size is None:
+            self.ffn_hidden_size = 4 * self.hidden_size
+
+        if self.kv_channels is None and self.num_attention_heads is not None:
+            self.kv_channels = self.hidden_size // self.num_attention_heads
+
+        if self.kv_mem_channels is None and self.num_mem_heads > 0:
+            self.kv_mem_channels = self.hidden_size // self.num_mem_heads
+
+        if self.num_query_groups is None:
+            self.num_query_groups = self.num_attention_heads
+
+        if self.num_mem_query_groups is None:
+            self.num_mem_query_groups = self.num_mem_heads
+
+        if self.bias_gelu_fusion:
+            if not self.add_bias_linear:
+                raise ValueError(
+                    "When bias_gelu_fusion is True, add_bias_linear must also be True."
+                )
+
+            if self.activation_func != F.gelu:
+                raise ValueError('When bias_gelu_fusion is True, activation_func must be F.gelu.')
+
+        if self.init_method is None:
+            self.init_method = init_method_normal(self.init_method_std)
+
+        if self.output_layer_init_method is None:
+            self.output_layer_init_method = scaled_init_method_normal(
+                self.init_method_std, self.num_layers
+            )
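A quick illustration of the derived defaults that `__post_init__` fills in (a minimal sketch; the field values are hypothetical):

    from mamba_config import MambaConfig

    cfg = MambaConfig(num_layers=2, hidden_size=64, state_size=16, num_attention_heads=8)
    assert cfg.ffn_hidden_size == 4 * 64   # filled in by __post_init__
    assert cfg.kv_channels == 64 // 8      # hidden_size // num_attention_heads
    assert cfg.num_query_groups == 8       # defaults to num_attention_heads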
diff --git a/src/transformers/models/zamba2/mamba_layer.py b/src/transformers/models/zamba2/mamba_layer.py
new file mode 100644
index 0000000000..bf3279b49a
--- /dev/null
+++ b/src/transformers/models/zamba2/mamba_layer.py
@@ -0,0 +1,308 @@
+import math
+from typing import Optional, Union
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+from einops import rearrange, repeat
+from mamba_config import MambaConfig
+
+try:
+    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
+except ImportError:
+    causal_conv1d_fn = None
+    causal_conv1d_update = None
+
+
+from ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, SELECTIVE_SCAN_CUDA_IMPORT_FAILED
+
+# if SELECTIVE_SCAN_CUDA_IMPORT_FAILED:
+#     from ops.selective_scan_interface import selective_scan_ref
+#     selective_scan_fn = selective_scan_ref
+#     mamba_inner_fn = None
+
+try:
+    from ops.triton.selective_state_update import selective_state_update
+except ImportError:
+    selective_state_update = None
+
+try:
+    from ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
+except ImportError:
+    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
+
+
+class MambaLayer(nn.Module):
+    def __init__(
+        self,
+        config: MambaConfig,
+        layer_idx=0
+    ):
+        factory_kwargs = {}
+        super().__init__()
+        self.config = config
+        self.d_model = config.hidden_size
+        self.d_state = config.state_size
+        self.d_conv = config.conv_dimension
+        self.expand = config.expansion_factor
+        self.d_inner = int(self.expand * self.d_model)
+        self.dt_rank = math.ceil(self.d_model / 16) if config.dt_rank == "auto" else config.dt_rank
+        self.use_fast_path = config.use_fast_path if mamba_inner_fn is not None else False
+        self.layer_idx = layer_idx
+        self.device = config.device
+
+        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=config.add_bias_linear, **factory_kwargs)
+
+        self.conv1d = nn.Conv1d(
+            in_channels=self.d_inner,
+            out_channels=self.d_inner,
+            bias=config.conv_bias,
+            kernel_size=self.d_conv,
+            groups=self.d_inner,
+            padding=self.d_conv - 1,
+            **factory_kwargs,
+        )
+
+        self.activation = "silu"
+        self.act = nn.SiLU()
+
+        self.x_proj = nn.Linear(
+            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
+        )
+
+        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
+
+        dt = torch.exp(
+            torch.rand(self.d_inner, **factory_kwargs) * (math.log(config.dt_max) - math.log(config.dt_min))
+            + math.log(config.dt_min)
+        ).clamp(min=config.dt_init_floor)
+        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
+        inv_dt = dt + torch.log(-torch.expm1(-dt))
+        with torch.no_grad():
+            self.dt_proj.bias.copy_(inv_dt)
+        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
+        self.dt_proj.bias._no_reinit = True
+
+        # S4D real initialization
+        A = repeat(
+            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=self.device),
+            "n -> d n",
+            d=self.d_inner
+        ).contiguous()
+        A_log = torch.log(A)  # Keep A_log in fp32
+        self.A_log = nn.Parameter(A_log)
+        self.A_log._no_weight_decay = True
+
+        # D "skip" parameter
+        self.D = nn.Parameter(torch.ones(self.d_inner, device=self.device))  # Keep in fp32
+        self.D._no_weight_decay = True
+
+        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=config.add_bias_linear, **factory_kwargs)
+
+    def forward(self, hidden_states, from_shared_proj=None, inference_params=None):
+        batch, seqlen, dim = hidden_states.shape
+        conv_state, ssm_state = None, None
+        if inference_params is not None:
+            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
+            if inference_params.sequence_len_offset > 0:
+                # The states are updated inplace
+                out, _, _ = self.step(hidden_states, conv_state, ssm_state)
+                return out
+
+        # We do matmul and transpose BLH -> HBL at the same time
+        xz = rearrange(
+            self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
+            "d (b l) -> b d l",
+            l=seqlen,
+        )
+
+        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
+        # In the backward pass we write dx and dz next to each other to avoid torch.cat
+        if self.use_fast_path and inference_params is None:  # Doesn't support outputting the states
+            out = mamba_inner_fn(
+                xz,
+                self.conv1d.weight,
+                self.conv1d.bias,
+                self.x_proj.weight,
+                self.dt_proj.weight,
+                self.out_proj.weight,
+                self.out_proj.bias,
+                A,
+                None,  # input-dependent B
+                None,  # input-dependent C
+                self.D.float(),
+                delta_bias=self.dt_proj.bias.float(),
+                delta_softplus=True,
+            )
+        else:
+            x, z = xz.chunk(2, dim=1)
+            # Compute short convolution
+            if conv_state is not None:
+                conv_state.copy_(x[:, :, -self.d_conv :])  # Update state (B D W)
+            if causal_conv1d_fn is None:
+                x = self.act(self.conv1d(x)[..., :seqlen])
+            else:
+                assert self.activation in ["silu", "swish"]
+                x = causal_conv1d_fn(
+                    x,
+                    rearrange(self.conv1d.weight, "d 1 w -> d w"),
+                    self.conv1d.bias,
+                    self.activation,
+                )
+
+            # We're careful here about the layout, to avoid extra transposes.
+            # We want dt to have d as the slowest moving dimension
+            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
+            x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
+            dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
+            dt = self.dt_proj.weight @ dt.t()
+            dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
+            B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
+            C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
+            assert self.activation in ["silu", "swish"]
+            y = selective_scan_fn(
+                x,
+                dt,
+                A,
+                B,
+                C,
+                self.D.float(),
+                z=z,
+                delta_bias=self.dt_proj.bias.float(),
+                delta_softplus=True,
+                return_last_state=ssm_state is not None,
+            )
+            if ssm_state is not None:
+                y, last_state = y
+                ssm_state.copy_(last_state)
+            y = rearrange(y, "b d l -> b l d")
+            out = self.out_proj(y)
+        return out
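Both paths in the forward above realize the same selective-scan recurrence. The function below is an editor's pure-PyTorch reference sketch of what `selective_scan_fn` computes, not the fused kernel; it assumes `dt` has already had the bias added and softplus applied:

    import torch
    import torch.nn.functional as F

    def selective_scan_reference(x, dt, A, B, C, D, z=None):
        # x, dt: (b, d, l); A: (d, n); B, C: (b, n, l); D: (d,)
        b, d, l = x.shape
        h = x.new_zeros(b, d, A.shape[1])  # SSM state
        ys = []
        for t in range(l):
            dA = torch.exp(dt[:, :, t, None] * A)                           # discretize A
            dBx = dt[:, :, t, None] * B[:, None, :, t] * x[:, :, t, None]   # discretized B times input
            h = h * dA + dBx                                                # state update
            ys.append((h * C[:, None, :, t]).sum(-1) + D * x[:, :, t])      # readout plus skip
        y = torch.stack(ys, dim=-1)                                         # (b, d, l)
        return y * F.silu(z) if z is not None else y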
+    def step(self, hidden_states, conv_state, ssm_state):
+        dtype = hidden_states.dtype
+        assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
+        xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
+        x, z = xz.chunk(2, dim=-1)  # (B D)
+
+        # Conv step
+        if causal_conv1d_update is None:
+            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
+            conv_state[:, :, -1] = x
+            x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
+            if self.conv1d.bias is not None:
+                x = x + self.conv1d.bias
+            x = self.act(x).to(dtype=dtype)
+        else:
+            x = causal_conv1d_update(
+                x,
+                conv_state,
+                rearrange(self.conv1d.weight, "d 1 w -> d w"),
+                self.conv1d.bias,
+                self.activation,
+            )
+
+        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
+        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
+        # Don't add dt_bias here
+        dt = F.linear(dt, self.dt_proj.weight)  # (B d_inner)
+        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
+
+        # SSM step
+        if selective_state_update is None:
+            # Discretize A and B
+            dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
+            dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
+            dB = torch.einsum("bd,bn->bdn", dt, B)
+            ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
+            y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
+            y = y + self.D.to(dtype) * x
+            y = y * self.act(z)  # (B D)
+        else:
+            y = selective_state_update(
+                ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
+            )
+
+        out = self.out_proj(y)
+        return out.unsqueeze(1), conv_state, ssm_state
+
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        device = self.out_proj.weight.device
+        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
+        conv_state = torch.zeros(
+            batch_size, self.d_model * self.expand, self.d_conv,
device=device, dtype=conv_dtype + ) + ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype + # ssm_dtype = torch.float32 + ssm_state = torch.zeros( + batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype + ) + return conv_state, ssm_state + + def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): + assert self.layer_idx is not None + if self.layer_idx not in inference_params.key_value_memory_dict: + batch_shape = (batch_size,) + conv_state = torch.zeros( + batch_size, + self.d_model * self.expand, + self.d_conv, + device=self.conv1d.weight.device, + dtype=self.conv1d.weight.dtype, + ) + ssm_state = torch.zeros( + batch_size, + self.d_model * self.expand, + self.d_state, + device=self.dt_proj.weight.device, + dtype=self.dt_proj.weight.dtype, + # dtype=torch.float32, + ) + inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) + else: + conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] + # TODO: What if batch size changes between generation, and we reuse the same states? + if initialize_states: + conv_state.zero_() + ssm_state.zero_() + return conv_state, ssm_state diff --git a/src/transformers/models/zamba2/mamba_model.py b/src/transformers/models/zamba2/mamba_model.py new file mode 100644 index 0000000000..f7a825d97a --- /dev/null +++ b/src/transformers/models/zamba2/mamba_model.py @@ -0,0 +1,176 @@ +import logging +from typing import Literal, Optional, Union +import functools +from functools import partial +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +import math +import os +from mamba_block import MambaBlock, MambaDecoder +from mamba_config import MambaConfig +from hf_utils import * +import os, json +from transformers.utils import WEIGHTS_NAME, CONFIG_NAME +from transformers.utils.hub import cached_file + +def _init_weights( + module, + n_layer, + initializer_range=0.02, + rescale_prenorm_residual=True, + n_residuals_per_layer=1, +): + if isinstance(module, nn.Linear): + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=initializer_range) + + if rescale_prenorm_residual: + for name, p in module.named_parameters(): + if name in ["out_proj.weight", "fc2.weight"]: + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(n_residuals_per_layer * n_layer) + + +class MambaModel(nn.Module): + def __init__( + self, + config: MambaConfig, + max_sequence_length: int, + pre_process: bool = True, + post_process: bool = True, + fp16_lm_cross_entropy: bool = False, + parallel_output: bool = True, + share_embeddings_and_output_weights: bool = True, + initializer_cfg = None, + ) -> None: + super().__init__() + + self.config: MambaConfig = config + self.max_sequence_length = max_sequence_length + self.pre_process = pre_process + self.post_process = post_process + self.fp16_lm_cross_entropy = fp16_lm_cross_entropy + self.parallel_output = parallel_output + self.share_embeddings_and_output_weights = share_embeddings_and_output_weights + + if self.pre_process: + self.embedding = nn.Embedding(self.config.vocab_size, self.config.hidden_size) + + self.decoder = MambaDecoder( + config = self.config, + pre_process = self.pre_process, + post_process = self.post_process, + ) + #if post_process: + # self.output_layer = nn.Linear(self.config.hidden_size, 
self.config.vocab_size, bias = self.config.add_bias_linear)
+        #     if self.share_embeddings_and_output_weights and (self.pre_process or self.post_process):
+        #         self.initialize_last_stage_with_word_embeddings()
+
+        self.apply(
+            partial(
+                _init_weights,
+                n_layer=self.config.num_layers,
+                **(initializer_cfg if initializer_cfg is not None else {}),
+            )
+        )
+
+    def initialize_last_stage_with_word_embeddings(self):
+        with torch.no_grad():
+            self.output_layer.weight = self.embedding.weight
+
+    def forward(
+        self,
+        input_ids,
+        position_ids=None,
+        decoder_input: Tensor = None,
+        labels: Tensor = None,
+        inference_params=None,
+    ) -> Tensor:
+
+        if decoder_input is not None:
+            pass
+        elif self.pre_process:
+            decoder_input = self.embedding(input_ids)
+            decoder_input = decoder_input.permute(1, 0, 2)
+        else:
+            decoder_input = None
+
+        hidden_states = self.decoder(
+            hidden_states=decoder_input,
+            residual=None,
+            inference_params=inference_params,
+        )
+
+        if not self.post_process:
+            return hidden_states
+
+        # logits = self.output_layer(hidden_states)
+        logits = hidden_states @ self.embedding.weight.T
+        return logits.contiguous()
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name=None, checkpoint_name=None, config_name=None, **kwargs):
+        if pretrained_model_name is not None:
+            json_config = load_config_hf(pretrained_model_name)
+            state_dict = load_state_dict_hf(pretrained_model_name)
+        elif checkpoint_name is not None and config_name is not None:
+            with open(config_name, 'r') as f:
+                jsonstr = f.read()
+            json_config = json.loads(jsonstr)
+            state_dict = torch.load(checkpoint_name, map_location="cpu")
+        else:
+            raise ValueError("Must provide either a valid HF identifier or path to pytorch checkpoint")
+
+        # process the state dict
+        state_dict["model"]["embedding.weight"] = state_dict["model"]["embedding.word_embeddings.weight"]
+        del state_dict["model"]["embedding.position_embeddings.weight"]
+        del state_dict["model"]["embedding.word_embeddings.weight"]
+        for key in list(state_dict["model"].keys()):
+            if "extra_state" in key:
+                del state_dict["model"][key]
+        state_dict["model"]["output_layer.weight"] = state_dict["model"]["embedding.weight"]
+
+        config = MambaConfig(
+            num_layers=json_config["num_layers"],
+            hidden_size=json_config["hidden_size"],
+            state_size=json_config["state_size"],
+            conv_dimension=json_config["conv_dimension"],
+            expansion_factor=json_config["expansion_factor"],
+            rms_norm=json_config["rms_norm"],
+            use_mem_mlp=json_config["use_mem_mlp"],
+            num_mem_heads=json_config["num_mem_heads"],
+            num_attention_heads=json_config["num_attention_heads"],
+            vocab_size=json_config["padded_vocab_size"],
+            layer_mapping=json_config["layer_mapping"],
+            add_bias_linear=json_config["add_bias_linear"],
+            gated_linear_unit=True,
+        )
+
+        model = MambaModel(config=config, max_sequence_length=4096)
+        model.load_state_dict(state_dict["model"])
+        return model
+
+    def save_pretrained(self, save_directory):
+        # Ensure save_directory exists
+        os.makedirs(save_directory, exist_ok=True)
+
+        # Save the model's state_dict
+        model_path = os.path.join(save_directory, 'pytorch_model.bin')
+        torch.save(self.state_dict(), model_path)
+
+        # Save the configuration of the model
+        config_path = os.path.join(save_directory, 'config.json')
+        with open(config_path, 'w') as f:
+            # default=str handles non-JSON-serializable fields such as activation_func
+            json.dump(self.config.__dict__, f, default=str)
\ No newline at end of file
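A minimal usage sketch for the load/save helpers above (the paths are hypothetical; `from_pretrained` expects a Megatron-style checkpoint whose weights are nested under a "model" key, as the state-dict processing above assumes):

    from mamba_model import MambaModel

    model = MambaModel.from_pretrained(
        checkpoint_name="/path/to/checkpoint.pt",   # hypothetical paths
        config_name="/path/to/config.json",
    )
    model.save_pretrained("./zamba2-export")        # writes pytorch_model.bin and config.json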
diff --git a/src/transformers/models/zamba2/mlp.py b/src/transformers/models/zamba2/mlp.py
new file mode 100644
index 0000000000..505a0d054a
--- /dev/null
+++ b/src/transformers/models/zamba2/mlp.py
@@ -0,0 +1,102 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+
+from dataclasses import dataclass
+from typing import Union
+
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from utils import bias_gelu_impl
+from mamba_config import MambaConfig
+
+class MLP(nn.Module):
+
+    def __init__(self, config: MambaConfig, is_expert: bool = False, layer_idx=None, num_mem_blocks=None):
+        super().__init__()
+
+        self.num_mem_blocks = num_mem_blocks
+
+        self.config: MambaConfig = config
+        self.layer = layer_idx
+        ffn_hidden_size_1 = self.config.ffn_hidden_size
+        ffn_hidden_size_2 = self.config.ffn_hidden_size
+        # If this is a gated linear unit we double the output width, see https://arxiv.org/pdf/2002.05202.pdf
+        if self.config.gated_linear_unit:
+            ffn_hidden_size_1 *= 2
+
+        if self.layer == -1:
+            ffn_hidden_size_1 = 8 * self.config.hidden_size
+
+        self.linear_fc1 = nn.Linear(self.config.hidden_size, ffn_hidden_size_1, bias=self.config.add_bias_linear)
+        if self.config.gated_linear_unit or self.layer == -1:
+
+            def glu(x):
+                x = torch.chunk(x, 2, dim=-1)
+                return self.config.activation_func(x[0]) * x[1]
+
+            self.activation_func = glu
+        else:
+            self.activation_func = self.config.activation_func
+
+        self.linear_fc2 = nn.Linear(ffn_hidden_size_2, self.config.hidden_size, bias=self.config.add_bias_linear)
+
+        # Only apply LoRA to fc1 here; fc2 already has a trainable projection.
+        if self.config.use_shared_block_lora:
+            self.linear_fc1_lora_A_list = nn.ParameterList([])
+            self.linear_fc1_lora_B_list = nn.ParameterList([])
+            for i in range(self.num_mem_blocks):
+                linear_fc1_lora_A = nn.Linear(self.config.hidden_size, self.config.lora_rank, bias=False)
+                linear_fc1_lora_B = nn.Linear(self.config.lora_rank, ffn_hidden_size_1, bias=False)
+                self.linear_fc1_lora_A_list.append(linear_fc1_lora_A)
+                self.linear_fc1_lora_B_list.append(linear_fc1_lora_B)
+
+    def forward(self, hidden_states, inference_params=None, forward_layer_idx=None):
+        # [s, b, 4 * h/p]
+        if self.config.use_shared_block_lora:
+            linear_fc1_lora_A = self.linear_fc1_lora_A_list[forward_layer_idx]
+            linear_fc1_lora_B = self.linear_fc1_lora_B_list[forward_layer_idx]
+            lora_output = linear_fc1_lora_A(hidden_states)
+            lora_output = linear_fc1_lora_B(lora_output)
+            intermediate_parallel = self.linear_fc1(hidden_states)
+            intermediate_parallel = intermediate_parallel + lora_output
+        else:
+            intermediate_parallel = self.linear_fc1(hidden_states)
+
+        if self.config.bias_gelu_fusion:
+            assert self.config.add_bias_linear is True
+            assert self.activation_func == F.gelu
+            intermediate_parallel = bias_gelu_impl(intermediate_parallel)
+        else:
+            intermediate_parallel = self.activation_func(intermediate_parallel)
+        # [s, b, h]
+        output = self.linear_fc2(intermediate_parallel)
+
+        return output
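To make the gated path above concrete: with `gated_linear_unit=True`, fc1 produces twice the FFN width and the `glu` closure splits it into a value half and a gate half. A small shape sketch (editor's illustration with hypothetical sizes):

    import torch
    import torch.nn.functional as F

    hidden, ffn = 8, 16                       # hypothetical sizes
    x = torch.randn(2, 3, hidden)             # (seq, batch, hidden)
    fc1 = torch.nn.Linear(hidden, 2 * ffn)    # doubled width for the gate
    a, b = torch.chunk(fc1(x), 2, dim=-1)     # value half, gate half
    y = F.gelu(a) * b                         # what the glu closure computes with F.gelu
    assert y.shape == (2, 3, ffn)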
+    def sharded_state_dict(self, prefix='', sharded_key_prefix=None, sharded_offsets=()):
+        sharded_key_prefix = prefix if sharded_key_prefix is None else sharded_key_prefix
+        sharded_state_dict = {}
+        for name, module in self._modules.items():
+            sub_sd = module.sharded_state_dict(
+                prefix=f'{prefix}{name}.',
+                sharded_key_prefix=f'{sharded_key_prefix}{name}.',
+                sharded_offsets=sharded_offsets,
+            )
+            sharded_state_dict.update(sub_sd)
+        return sharded_state_dict
\ No newline at end of file
diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py
new file mode 100644
index 0000000000..9bb0e1a4d6
--- /dev/null
+++ b/src/transformers/models/zamba2/modeling_zamba2.py
@@ -0,0 +1,1747 @@
+# coding=utf-8
+# Copyright 2024 Zyphra Technologies and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Zamba model.""" + +import inspect +import math +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, +) +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from ...utils.import_utils import ( + is_causal_conv1d_available, + is_flash_attn_2_available, + is_mamba_ssm_available, +) +from .configuration_zamba2 import Zamba2Config +from .mamba2_layer import Mamba2Layer +from .attention import CausalSelfAttention +from .mlp import MLP + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) + + +if is_mamba_ssm_available(): + from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn + from mamba_ssm.ops.triton.selective_state_update import selective_state_update +else: + selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_update, causal_conv1d_fn = None, None + +is_fast_path_available = all( + (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) +) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Zamba2Config" + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Zamba +class Zamba2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + ZambaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def count_mem_blocks_in_config(config):
+    num_gs = 0
+    for val in config.layers_block_type:
+        if val == 'attention+mamba':
+            num_gs += 1
+    return num_gs
+
+
+class HybridMambaAttentionDynamicCache(DynamicCache):
+    """
+    A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
+    (which has a constant shape regardless of seq_len).
+
+    This cache has two sets of lists of tensors: `key_cache` and `value_cache` for the attention cache and
+    `conv_states` and `ssm_states` for the mamba cache. Each of these lists has `num_layers` tensors. Since every
+    layer contains a mamba component, `conv_states` (shape `(batch_size, d_inner, d_conv)`) and `ssm_states` (shape
+    `(batch_size, n_mamba_heads, d_inner // n_mamba_heads, d_state)`) are allocated for all layers, while `key_cache`
+    and `value_cache` start out as empty tensors of shape `(batch_size, 0)` and are only filled for the layers that
+    also run attention (`"attention+mamba"`), where they grow to `(batch_size, num_heads, seq_len, head_dim)`.
+    """
+
+    def __init__(self, config, batch_size, dtype=torch.float16, device=None):
+        self.dtype = dtype
+        self.layers_block_type = config.layers_block_type
+        self.has_previous_state = False  # only used by mamba
+        intermediate_size = config.mamba_expand * config.hidden_size
+        ssm_state_size = config.mamba_d_state
+        conv_kernel_size = config.mamba_d_conv
+        self.n_mamba_heads = config.n_mamba_heads
+        self.conv_states = []
+        self.ssm_states = []
+        self.transformer_layers = []
+        for i in range(config.num_hidden_layers):
+            self.conv_states += [
+                torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
+            ]
+            self.ssm_states += [
+                torch.zeros(
+                    batch_size,
+                    self.n_mamba_heads,
+                    intermediate_size // self.n_mamba_heads,
+                    ssm_state_size,
+                    device=device,
+                    dtype=dtype,
+                )
+            ]
+            if self.layers_block_type[i] == "attention+mamba":
+                self.transformer_layers.append(i)
+
+        self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
+        self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
+
+    # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.update
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Update the cache
+        if self.key_cache[layer_idx].shape[-1] == 0:
+            self.key_cache[layer_idx] = key_states
+            self.value_cache[layer_idx] = value_states
+        else:
+            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
+            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)
+
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
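A quick sketch of the resulting cache layout (editor's illustration; the config object here is a hypothetical stand-in carrying only the fields the cache constructor reads, and the snippet assumes it runs where `HybridMambaAttentionDynamicCache` is importable):

    from types import SimpleNamespace

    cfg = SimpleNamespace(
        layers_block_type=["mamba", "attention+mamba", "mamba"],  # hypothetical
        mamba_expand=2, hidden_size=8, mamba_d_state=16, mamba_d_conv=4,
        n_mamba_heads=2, num_hidden_layers=3,
    )
    cache = HybridMambaAttentionDynamicCache(cfg, batch_size=1)
    print(cache.conv_states[0].shape)  # torch.Size([1, 16, 4]): (batch, d_inner, d_conv)
    print(cache.ssm_states[0].shape)   # torch.Size([1, 2, 8, 16]): (batch, heads, d_inner/heads, d_state)
    print(cache.transformer_layers)    # [1]: only this layer grows key/value caches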
"""Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.get_seq_length + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.to_legacy_cache + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + @classmethod + # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.from_legacy_cache + def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + +# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention: +# The input dimension here is twice the hidden_size, and head_dim = 2 * hidden_size // num_heads. +# The extra factor of 2 comes from the input being the concatenation of original_hidden_states with the output of the previous (mamba) layer +# (see fig. 2 in https://arxiv.org/pdf/2405.16712). +# Additionally, replaced +# attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with +# attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2) +class Zamba2Attention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: Zamba2Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = 2 * self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + # self.max_position_embeddings = config.max_position_embeddings + # self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != 2 * self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        layer_idx: int,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        if past_key_value is not None:
+            key_states, value_states = past_key_value.update(key_states, value_states, layer_idx)
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim / 2)
+
+        if attention_mask is not None:  # no matter the length, we just slice it
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+            attn_weights = attn_weights + causal_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, 2 * self.hidden_size)
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
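The comments above note the doubled input width. A quick numeric check of the adjusted softmax scale (editor's sketch with hypothetical sizes, not values from the Zamba2 config): with head_dim = 2 * hidden_size / num_heads, dividing scores by sqrt(head_dim / 2) recovers exactly the scale a standard attention block with head_dim = hidden_size / num_heads would use.

    import math

    hidden_size, num_heads = 2048, 32              # hypothetical
    head_dim = 2 * hidden_size // num_heads        # 128: doubled by the concatenated input
    standard_head_dim = hidden_size // num_heads   # 64: what a plain block would use
    assert math.sqrt(head_dim / 2) == math.sqrt(standard_head_dim)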
+ """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + layer_idx: int, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        # Reshape to the expected shape for Flash Attention
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+        softmax_scale = 1 / (query_states.shape[-1] / 2) ** 0.5
+
+        attn_output = self._flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=dropout_rate,
+            softmax_scale=softmax_scale,
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, 2 * self.hidden_size).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+    def _flash_attention_forward(
+        self,
+        query_states,
+        key_states,
+        value_states,
+        attention_mask,
+        query_length,
+        dropout=0.0,
+        softmax_scale=None,
+        use_sliding_windows=False,
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
+        it first unpads the input, then computes the attention scores and pads the final attention scores.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`float`, *optional*):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+            use_sliding_windows (`bool`, *optional*):
+                Whether to activate sliding window attention.
+        """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + if not use_sliding_windows: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + if not use_sliding_windows: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + return attn_output + + # Copied from transformers.models.mixtral.modeling_mixtral.MixtralFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape + + # On the first iteration we need to properly re-create the padding mask + # by slicing it on the proper place + if kv_seq_len != attention_mask.shape[-1]: + attention_mask_num_tokens = attention_mask.shape[-1] + attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q,
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )
+
+
+# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
+# added scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of torch.nn.functional.scaled_dot_product_attention
+class Zamba2SdpaAttention(Zamba2Attention):
+    """
+    Zamba attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `ZambaAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
+    the SDPA API.
+    """
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        layer_idx: int,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if output_attentions:
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "ZambaModel is using ZambaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states=hidden_states,
+                layer_idx=layer_idx,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        if past_key_value is not None:
+            key_states, value_states = past_key_value.update(key_states, value_states, layer_idx)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        causal_mask = attention_mask
+        if attention_mask is not None:
+            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+        # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + softmax_scale = 1 / (query_states.shape[-1] / 2) ** 0.5 + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal=self.is_causal and attention_mask is None and q_len > 1, + scale=softmax_scale, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, 2 * self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +ZAMBA2_ATTENTION_CLASSES = { + "eager": Zamba2Attention, + "flash_attention_2": Zamba2FlashAttention2, + "sdpa": Zamba2SdpaAttention, +} + + +class Zamba2MambaMixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. + A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + """ + + def __init__(self, config: Zamba2Config, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.ssm_state_size = config.mamba_d_state + self.conv_kernel_size = config.mamba_d_conv + self.intermediate_size = config.mamba_expand * config.hidden_size + self.time_step_rank = config.mamba_dt_rank + self.n_mamba_heads = config.n_mamba_heads + self.use_conv_bias = config.mamba_conv_bias + self.use_bias = config.mamba_proj_bias + self.conv1d = nn.Conv1d( + in_channels=self.intermediate_size, + out_channels=self.intermediate_size, + bias=self.use_conv_bias, + kernel_size=self.conv_kernel_size, + groups=self.intermediate_size, + padding=self.conv_kernel_size - 1, + ) + + self.activation = config.hidden_mamba_act + self.act = ACT2FN[config.hidden_mamba_act] + + self.use_fast_kernels = config.use_mamba_kernels + + # projection of the input hidden states + self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=self.use_bias) + # selective projection used to make dt, B and C input dependent + self.x_proj_weight = nn.Parameter( + ( + torch.rand( + self.n_mamba_heads, + self.time_step_rank + self.ssm_state_size * 2, + self.intermediate_size // self.n_mamba_heads, + ) + - 0.5 + ) + * 2 + / self.intermediate_size**0.5 + ) + # time step projection (discretization) + self.dt_proj_weight = nn.Parameter( + (torch.rand(self.n_mamba_heads, self.intermediate_size // self.n_mamba_heads, self.time_step_rank) - 0.5) + * 2 + / self.time_step_rank**0.5 + ) # (h d dt_rank) + self.dt_proj_bias = nn.Parameter( + torch.zeros(self.n_mamba_heads, self.intermediate_size // self.n_mamba_heads) + ) # (h d) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. 
Keeps the memory bounded
+        A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
+        A = A.expand(self.intermediate_size, -1).contiguous()
+        self.A_log = nn.Parameter(
+            torch.log(A).reshape(self.n_mamba_heads, self.intermediate_size // self.n_mamba_heads, -1)
+        )
+        self.D = nn.Parameter(torch.ones(self.n_mamba_heads, self.intermediate_size // self.n_mamba_heads))
+        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
+
+        if not is_fast_path_available:
+            logger.warning_once(
+                "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
+                " is None. To install, follow https://github.com/state-spaces/mamba/#installation and"
+                " https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config"
+            )
+
+    def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: HybridMambaAttentionDynamicCache = None):
+        batch_size, seq_len, _ = hidden_states.shape
+        use_precomputed_states = cache_params is not None and cache_params.has_previous_state and seq_len == 1
+
+        # 1. Gated linear projection
+        projected_states = self.in_proj(hidden_states).transpose(1, 2)
+
+        hidden_states, gate = projected_states.view(batch_size, -1, 2, seq_len).chunk(2, dim=2)
+        hidden_states = hidden_states.squeeze(2).contiguous()
+        gate = gate.squeeze(2)
+        gate = gate.reshape(batch_size, self.n_mamba_heads, -1, seq_len).transpose(0, 1)  # (h b d l)
+
+        # 2. Convolution sequence transformation
+        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
+        if use_precomputed_states:
+            hidden_states = causal_conv1d_update(
+                hidden_states.squeeze(-1),
+                cache_params.conv_states[self.layer_idx],
+                conv_weights,
+                self.conv1d.bias,
+                self.activation,
+            )
+            hidden_states = hidden_states.unsqueeze(-1)
+        else:
+            if cache_params is not None:
+                conv_states = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0))
+                cache_params.conv_states[self.layer_idx].copy_(conv_states)
+            hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation)
+
+        # 3. State Space Model sequence transformation
+        # 3.a.
input varying initialization of time_step, B and C + + hidden_states = hidden_states.reshape( + -1, self.n_mamba_heads, self.intermediate_size // self.n_mamba_heads, seq_len + ).transpose(0, 1) # (h b d l) + ssm_parameters = (self.x_proj_weight[:, None, :, :] @ hidden_states).transpose(-1, -2) # (h b l d) + + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) # (h b l d) + + discrete_time_step = self.dt_proj_weight[:, None] @ time_step.transpose(-1, -2) # (h b d l) + + A = -torch.exp(self.A_log.float()) + + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = self.dt_proj_bias.float() if self.dt_proj_bias is not None else None + scan_outputs = torch.empty((batch_size, 0, seq_len), device=hidden_states.device, dtype=hidden_states.dtype) + + if use_precomputed_states: + for n in range(self.n_mamba_heads): + scan_outputs_ = selective_state_update( + cache_params.ssm_states[self.layer_idx][:, n], + hidden_states[n, ..., 0], + discrete_time_step[n, ..., 0], + A[n], + B[n, :, 0], + C[n, :, 0], + self.D[n], + gate[n, ..., 0], + time_proj_bias[n], + dt_softplus=True, + ).unsqueeze(-1) + scan_outputs = torch.cat((scan_outputs, scan_outputs_), dim=1) + + else: + ssm_state = torch.empty( + (batch_size, 0, self.intermediate_size // self.n_mamba_heads, self.ssm_state_size), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + for n in range(self.n_mamba_heads): + scan_outputs_, ssm_state_ = selective_scan_fn( + hidden_states[n], + discrete_time_step[n], + A[n], + B[n].transpose(1, 2), + C[n].transpose(1, 2), + self.D[n].float(), + gate[n], + time_proj_bias[n], + delta_softplus=True, + return_last_state=True, + ) + scan_outputs = torch.cat((scan_outputs, scan_outputs_), dim=1).contiguous() + ssm_state = torch.cat((ssm_state, ssm_state_.unsqueeze(1)), dim=1) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) + return contextualized_states + + def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache = None): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + # 1. Gated linear projection + projected_states = self.in_proj(input_states).transpose(1, 2) # (b d l) + + hidden_states, gate = projected_states.view(batch_size, -1, 2, seq_len).chunk(2, dim=2) + hidden_states = hidden_states.squeeze(2).contiguous() + gate = gate.squeeze(2) + gate = gate.reshape(batch_size, self.n_mamba_heads, -1, seq_len).transpose(0, 1) # (h b d l) + + use_cache = isinstance(cache_params, HybridMambaAttentionDynamicCache) + # 2. 
Convolution sequence transformation + if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size: + if self.training: + # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass + ssm_state = cache_params.ssm_states[self.layer_idx].clone() + else: + ssm_state = cache_params.ssm_states[self.layer_idx] + + ssm_state = ssm_state.to(hidden_states.device) + + if ( + cache_params.has_previous_state + and seq_len == 1 + and cache_params.conv_states[self.layer_idx].shape[0] == batch_size + ): + conv_state = cache_params.conv_states[self.layer_idx] # (b d conv_size) + conv_state = torch.roll(conv_state, shifts=-1, dims=-1) + conv_state[:, :, -1] = hidden_states[:, :, 0] + cache_params.conv_states[self.layer_idx] = conv_state + hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + if self.use_conv_bias: + hidden_states += self.conv1d.bias + hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # (b d 1) : decoding + else: + conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) + cache_params.conv_states[self.layer_idx] = conv_state + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # (b d l) + else: + ssm_state = torch.zeros( + (batch_size, self.n_mamba_heads, self.intermediate_size // self.n_mamba_heads, self.ssm_state_size), + device=hidden_states.device, + dtype=dtype, + ) # (b h d l) + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # (b d l) + + # 3. State Space Model sequence transformation + # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2] + hidden_states = hidden_states.reshape( + -1, self.n_mamba_heads, self.intermediate_size // self.n_mamba_heads, seq_len + ).transpose(0, 1) # (h b d l) + ssm_parameters = (self.x_proj_weight[:, None, :, :] @ hidden_states).transpose(-1, -2) # (h b l d) + + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) # (h b l d) + + discrete_time_step = (self.dt_proj_weight[:, None] @ time_step.transpose(-1, -2)) + self.dt_proj_bias[ + :, None, :, None + ] # (h b d l) + + discrete_time_step = nn.functional.softplus(discrete_time_step) + + # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM) + A = -torch.exp(self.A_log.float()) + discrete_A = torch.exp(A[:, None, :, None, :] * discrete_time_step[:, :, :, :, None]) # (h b d l state) + discrete_B = discrete_time_step[:, :, :, :, None] * B[:, :, None, :, :].float() # (h b d l state) + deltaB_u = discrete_B * hidden_states[:, :, :, :, None].float() # (h b d l state) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + scan_outputs = [] + for i in range(seq_len): + ssm_state = discrete_A[:, :, :, i, :].transpose(0, 1) * ssm_state + deltaB_u[:, :, :, i, :].transpose( + 0, 1 + ) # (h b d state) + scan_output = torch.matmul(ssm_state.transpose(0, 1).to(dtype), C[:, :, i, :].unsqueeze(-1)) # (h b d 1) + scan_outputs.append(scan_output[:, :, :, 0]) + scan_output = torch.stack(scan_outputs, dim=-1) # (h b d l) + scan_output = scan_output + (hidden_states * self.D[:, None, :, None]) # (h b d l) + scan_output = scan_output * self.act(gate) + + if use_cache: + cache_params.ssm_states[self.layer_idx] = ssm_state + + # 4. 
Final linear projection
+        contextualized_states = self.out_proj(
+            scan_output.transpose(0, 1).reshape(batch_size, -1, seq_len).transpose(1, 2)
+        )  # (b l d)
+        return contextualized_states
+
+    def forward(self, hidden_states, cache_params: HybridMambaAttentionDynamicCache = None):
+        if self.use_fast_kernels:
+            if not is_fast_path_available or "cuda" not in self.x_proj_weight.device.type:
+                raise ValueError(
+                    "Fast Mamba kernels are not available. Make sure they are installed and that "
+                    "the mamba module is on a CUDA device. Please run 'pip install causal-conv1d>=1.2.0' "
+                    "and 'pip install mamba-ssm', or set use_fast_kernels=False in the model's config."
+                )
+            return self.cuda_kernels_forward(hidden_states, cache_params)
+        return self.slow_forward(hidden_states, cache_params)
+
+
+class Zamba2MLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x):
+        return self.down_proj(self.act_fn(self.up_proj(x)) * self.gate_proj(x))
+
+
+class Zamba2AttentionDecoderLayer(nn.Module):
+    def __init__(self, config: Zamba2Config, layer_idx: Optional[int] = None):
+        super().__init__()
+        assert config._attn_implementation != "flash_attention_2", (
+            "Flash attention 2 is currently not supported in the HuggingFace implementation of Zamba2."
+        )
+        num_gs = count_mem_blocks_in_config(config)
+        self.self_attn = CausalSelfAttention(config, layer_number=-1)
+        self.feed_forward = MLP(config, layer_idx=-1, num_mem_blocks=num_gs)
+        self.input_layernorm = Zamba2RMSNorm(2 * config.hidden_size, eps=config.rms_norm_eps)
+        self.pre_ff_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    # The argument original_hidden_states is concatenated with hidden_states (which is the output of the previous (mamba) layer).
+    # The concatenated tensor is then used as input of the pre-attention RMSNorm (see fig. 2 in https://arxiv.org/pdf/2405.16712).
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        original_hidden_states: torch.Tensor,
+        layer_idx: int,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): output of previous Mamba layer of shape `(batch, seq_len, embed_dim)`
+            original_hidden_states (`torch.FloatTensor`): word embedding output of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, sequence_length)` where padding elements are indicated by 0.
+            past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+ use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + """ + hidden_states = torch.concatenate([hidden_states, original_hidden_states], dim=-1) + hidden_states = self.input_layernorm(hidden_states) + hidden_states = hidden_states.transpose(0,1).contiguous() + hidden_states = self.self_attn(hidden_states, attention_mask=attention_mask, inference_params=past_key_value, forward_layer_idx = layer_idx) + hidden_states = hidden_states.transpose(0,1) + + self_attn_weights, present_key_value = None, None + # feed-forward (MLP) + hidden_states = self.pre_ff_layernorm(hidden_states) + hidden_states = self.feed_forward(hidden_states, forward_layer_idx=layer_idx) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class Zamba2MambaDecoderLayer(nn.Module): + def __init__(self, config: Zamba2Config, layer_idx: int): + super().__init__() + factory_kwargs = {} + self.mamba = Mamba2Layer(config=config, layer_idx=layer_idx, **factory_kwargs) + # self.mamba = Zamba2MambaMixer(config=config, layer_idx=layer_idx) + self.input_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.layer_idx = layer_idx + + def forward( + self, + hidden_states: torch.Tensor, + transformer_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + """ + + residual = hidden_states + + # `transformer_hidden_states` is the output from shared transformer + linear layer (see fig. 2 in https://arxiv.org/pdf/2405.16712). + # `transformer_hidden_states` is then added to the input to the mamba layer below (as described in eq. (6) of https://arxiv.org/pdf/2405.16712). 
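+        # Schematically (eq. (6) of the paper): mamba_in = LayerNorm(hidden + Linear(attn_out)),
+        # with the residual connection applied after the mamba mixer further below.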
+ hidden_states = ( + hidden_states + transformer_hidden_states if transformer_hidden_states is not None else hidden_states + ) + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.mamba( + u=hidden_states, + inference_params=past_key_value, + ) + + self_attn_weights = None + + # residual connection after mamba + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (past_key_value,) + + return outputs + + +ZAMBA2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`ZambaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Zamba Model outputting raw hidden-states without any specific head on top.", + ZAMBA2_START_DOCSTRING, +) +class Zamba2PreTrainedModel(PreTrainedModel): + config_class = Zamba2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2MambaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +ZAMBA2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. 
+ + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the + self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`. + Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and + `(batch_size, d_inner, d_state)` respectively. + See the `HybridMambaAttentionDynamicCache` class for more details. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Zamba Model outputting raw hidden-states without any specific head on top.", + ZAMBA2_START_DOCSTRING, +) +class Zamba2Model(Zamba2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`ZambaDecoderLayer`] + + Args: + config: ZambaConfig + """ + + def __init__(self, config: Zamba2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.blocks = torch.nn.ModuleList([Zamba2AttentionDecoderLayer(config),Zamba2AttentionDecoderLayer(config)]) + mamba_layers = [] + linear_layers = [] + self.layers_block_type = config.layers_block_type + for i in range(config.num_hidden_layers): + if config.layers_block_type[i] == "mamba": + mamba_layers.append(Zamba2MambaDecoderLayer(config, layer_idx=i)) + elif config.layers_block_type[i] == "attention+mamba": + linear_layers.append(nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)) + mamba_layers.append(Zamba2MambaDecoderLayer(config, layer_idx=i)) + self.mamba_layers = nn.ModuleList(mamba_layers) + self.linear_layers = nn.ModuleList(linear_layers) + + self._attn_implementation = config._attn_implementation + self.final_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(ZAMBA2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + original_hidden_states = torch.clone(inputs_embeds) + # original_hidden_states: word embedding output that will be concatenated with hidden activations to form the input of the shared transformer layer + + if use_cache and past_key_values is None: + logger.warning_once( + "Zamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was " + "provided, so no cache will be returned." 
+ ) + + if cache_position is None: + cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + mamba_layers = iter(self.mamba_layers) + linear_layers = iter(self.linear_layers) + block_count = 0 + + for layer_idx, layer_type in enumerate(self.layers_block_type): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if layer_type == "attention+mamba": + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + self.blocks[block_count % 2].__call__, + hidden_states, + original_hidden_states, + block_count, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + block_count, + ) + else: + layer_outputs = self.blocks[block_count % 2]( + hidden_states, + original_hidden_states=original_hidden_states, + layer_idx=block_count, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + block_count += 1 + transformer_hidden_states = layer_outputs[0] + if output_attentions: + if layer_outputs[1] is not None: + all_self_attns += (layer_outputs[1],) + if self.gradient_checkpointing and self.training: + transformer_hidden_states = self._gradient_checkpointing_func( + next(linear_layers).__call__, + transformer_hidden_states, + ) + else: + transformer_hidden_states = next(linear_layers)( + transformer_hidden_states, + ) + else: + transformer_hidden_states = None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + next(mamba_layers).__call__, + hidden_states, + transformer_hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = next(mamba_layers)( + hidden_states, + transformer_hidden_states=transformer_hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = layer_outputs[0] + + hidden_states = self.final_layernorm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if past_key_values and not past_key_values.has_previous_state: + past_key_values.has_previous_state = True + + next_cache = None if not use_cache else past_key_values + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + # Copied from transformers.models.jamba.modeling_jamba.JambaModel._update_causal_mask + def _update_causal_mask(self, attention_mask, input_tensor, cache_position): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = 
input_tensor.shape[1] + target_length = cache_position[-1] + 1 + + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +class Zamba2ForCausalLM(Zamba2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: Zamba2Config): + super().__init__(config) + + self.model = Zamba2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(ZAMBA2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: Optional[Union[int, None]] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int` or `None`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all + `input_ids`. 
Only last token logits are needed for generation, and calculating them only for that token + can save memory, which becomes pretty significant for long sequences. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, ZambaForCausalLM + + >>> model = ZambaForCausalLM.from_pretrained("zyphra/Zamba-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("zyphra/Zamba-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if num_logits_to_keep is None: + logits = self.lm_head(hidden_states) + else: + logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :]) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + **kwargs, + ): + empty_past_kv = past_key_values is None + + # Omit tokens covered by past_key_values + if not empty_past_kv: + past_length = cache_position[0] if cache_position is not None else attention_mask.shape[1] + max_cache_length = self.config.sliding_window + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. 
+ elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and past_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + else: + past_key_values = HybridMambaAttentionDynamicCache( + self.config, input_ids.shape[0], dtype=self.dtype, device=self.device + ) + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "num_logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + } + ) + return model_inputs + + +@add_start_docstrings( + """ + The Zamba Model with a sequence classification head on top (linear layer). + + [`ZambaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + ZAMBA2_START_DOCSTRING, +) +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralForSequenceClassification with Mixtral->Zamba, MIXTRAL->ZAMBA +class Zamba2ForSequenceClassification(Zamba2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Zamba2Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(ZAMBA2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == 
"single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) \ No newline at end of file diff --git a/src/transformers/models/zamba2/ops/__init__.py b/src/transformers/models/zamba2/ops/__init__.py new file mode 100755 index 0000000000..e69de29bb2 diff --git a/src/transformers/models/zamba2/ops/selective_scan_interface.py b/src/transformers/models/zamba2/ops/selective_scan_interface.py new file mode 100755 index 0000000000..df96b8da45 --- /dev/null +++ b/src/transformers/models/zamba2/ops/selective_scan_interface.py @@ -0,0 +1,361 @@ +import torch +import torch.nn.functional as F +from torch.cuda.amp import custom_bwd, custom_fwd + +from einops import rearrange, repeat + +global SELECTIVE_SCAN_CUDA_IMPORT_FAILED +SELECTIVE_SCAN_CUDA_IMPORT_FAILED = False + +try: + from causal_conv1d import causal_conv1d_fn + import causal_conv1d_cuda +except ImportError: + causal_conv1d_fn = None + causal_conv1d_cuda = None + SELECTIVE_SCAN_CUDA_IMPORT_FAILED = True + print("SELECTIVE SCAN CUDA IMPORT FAILED: ", e) + +import selective_scan_cuda + + +class SelectiveScanFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False): + if u.stride(-1) != 1: + u = u.contiguous() + if delta.stride(-1) != 1: + delta = delta.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + if B.dim() == 3: + B = rearrange(B, "b dstate l -> b 1 dstate l") + ctx.squeeze_B = True + if C.dim() == 3: + C = rearrange(C, "b dstate l -> b 1 dstate l") + ctx.squeeze_C = True + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) + ctx.delta_softplus = delta_softplus + ctx.has_z = z is not None + last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) + if not ctx.has_z: + ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) + return out if not return_last_state else (out, last_state) + else: + ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) + out_z = rest[0] + return out_z if not return_last_state else (out_z, last_state) + + @staticmethod + def backward(ctx, dout, *args): + if not ctx.has_z: + u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors + z = None + out = None + else: + u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors + if dout.stride(-1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the + # backward of selective_scan_cuda with the backward of chunk). + # Here we just pass in None and dz will be allocated in the C++ code. 
+ du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, + False # option to recompute out_z, not used here + ) + dz = rest[0] if ctx.has_z else None + dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB + dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC + return (du, ddelta, dA, dB, dC, + dD if D is not None else None, + dz, + ddelta_bias if delta_bias is not None else None, + None, + None) + + +def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False): + """if return_last_state is True, returns (out, last_state) + last_state has shape (batch, dim, dstate). Note that the gradient of the last state is + not considered in the backward pass. + """ + return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) + + +def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False): + """ + u: r(B D L) + delta: r(B D L) + A: c(D N) or r(D N) + B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + D: r(D) + z: r(B D L) + delta_bias: r(D), fp32 + + out: r(B D L) + last_state (optional): r(B D dstate) or c(B D dstate) + """ + dtype_in = u.dtype + u = u.float() + delta = delta.float() + if delta_bias is not None: + delta = delta + delta_bias[..., None].float() + if delta_softplus: + delta = F.softplus(delta) + batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] + is_variable_B = B.dim() >= 3 + is_variable_C = C.dim() >= 3 + if A.is_complex(): + if is_variable_B: + B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) + if is_variable_C: + C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... 
L two", two=2)) + else: + B = B.float() + C = C.float() + x = A.new_zeros((batch, dim, dstate)) + ys = [] + deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + if not is_variable_B: + deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) + else: + if B.dim() == 3: + deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) + else: + B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) + deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) + if is_variable_C and C.dim() == 4: + C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) + last_state = None + for i in range(u.shape[2]): + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if not is_variable_C: + y = torch.einsum('bdn,dn->bd', x, C) + else: + if C.dim() == 3: + y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) + else: + y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) + if i == u.shape[2] - 1: + last_state = x + if y.is_complex(): + y = y.real * 2 + ys.append(y) + y = torch.stack(ys, dim=2) # (batch dim L) + out = y if D is None else y + u * rearrange(D, "d -> d 1") + if z is not None: + out = out * F.silu(z) + out = out.to(dtype=dtype_in) + return out if not return_last_state else (out, last_state) + + +class MambaInnerFn(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, + out_proj_weight, out_proj_bias, + A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, + C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1): + """ + xz: (batch, dim, seqlen) + """ + assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." + assert checkpoint_lvl in [0, 1] + print("IN SELECTIVE SCNA: ", xz) + L = xz.shape[-1] + delta_rank = delta_proj_weight.shape[1] + d_state = A.shape[-1] * (1 if not A.is_complex() else 2) + if torch.is_autocast_enabled(): + x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) + delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) + out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) + out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype()) + if out_proj_bias is not None else None) + if xz.stride(-1) != 1: + xz = xz.contiguous() + conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") + x, z = xz.chunk(2, dim=1) + conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd( + x, conv1d_weight, conv1d_bias, None, None, None, True + ) + # We're being very careful here about the layout, to avoid extra transposes. + # We want delta to have d as the slowest moving dimension + # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
+ x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d) + delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) + ctx.is_variable_B = B is None + ctx.is_variable_C = C is None + ctx.B_proj_bias_is_None = B_proj_bias is None + ctx.C_proj_bias_is_None = C_proj_bias is None + if B is None: # variable B + B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate) + if B_proj_bias is not None: + B = B + B_proj_bias.to(dtype=B.dtype) + if not A.is_complex(): + # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() + B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous() + else: + B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() + else: + if B.stride(-1) != 1: + B = B.contiguous() + if C is None: # variable C + C = x_dbl[:, -d_state:] # (bl dstate) + if C_proj_bias is not None: + C = C + C_proj_bias.to(dtype=C.dtype) + if not A.is_complex(): + # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() + C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous() + else: + C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() + else: + if C.stride(-1) != 1: + C = C.contiguous() + if D is not None: + D = D.contiguous() + out, scan_intermediates, out_z = selective_scan_cuda.fwd( + conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus + ) + ctx.delta_softplus = delta_softplus + ctx.out_proj_bias_is_None = out_proj_bias is None + ctx.checkpoint_lvl = checkpoint_lvl + if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass + conv1d_out, delta = None, None + ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, + delta_proj_weight, out_proj_weight, conv1d_out, delta, + A, B, C, D, delta_bias, scan_intermediates, out) + return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) + + @staticmethod + @custom_bwd + def backward(ctx, dout): + # dout: (batch, seqlen, dim) + assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." + (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, + conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors + L = xz.shape[-1] + delta_rank = delta_proj_weight.shape[1] + d_state = A.shape[-1] * (1 if not A.is_complex() else 2) + x, z = xz.chunk(2, dim=1) + if dout.stride(-1) != 1: + dout = dout.contiguous() + if ctx.checkpoint_lvl == 1: + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd( + x, conv1d_weight, conv1d_bias, None, None, None, True + ) + delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), + "d (b l) -> b d l", l = L) + # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the + # backward of selective_scan_cuda with the backward of chunk). 
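+        # dx and dz are carved out of one pre-allocated buffer shaped like xz, so the fused backward
+        # kernels below can write both gradient halves in place without extra copies.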
+ dxz = torch.empty_like(xz) # (batch, dim, seqlen) + dx, dz = dxz.chunk(2, dim=1) + dout = rearrange(dout, "b l e -> e (b l)") + dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L) + dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd( + conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz, + ctx.delta_softplus, + True # option to recompute out_z + ) + dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")) + dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None + dD = dD if D is not None else None + dx_dbl = torch.empty_like(x_dbl) + dB_proj_bias = None + if ctx.is_variable_B: + if not A.is_complex(): + dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous() + else: + dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() + dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None + dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d) + dB = None + dC_proj_bias = None + if ctx.is_variable_C: + if not A.is_complex(): + dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous() + else: + dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() + dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None + dx_dbl[:, -d_state:] = dC # (bl d) + dC = None + ddelta = rearrange(ddelta, "b d l -> d (b l)") + ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank]) + dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight) + dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)") + dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")) + dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out) + dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]) + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd( + x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True + ) + dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None + dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w") + return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight, + dout_proj_weight, dout_proj_bias, + dA, dB, dC, dD, + ddelta_bias if delta_bias is not None else None, + dB_proj_bias, dC_proj_bias, None) + + +def mamba_inner_fn( + xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, + out_proj_weight, out_proj_bias, + A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, + C_proj_bias=None, delta_softplus=True +): + return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, + out_proj_weight, out_proj_bias, + A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus) + + +def mamba_inner_ref( + xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, + out_proj_weight, out_proj_bias, + A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, + C_proj_bias=None, delta_softplus=True +): + assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d." 
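+    # Reference (non-fused) path mirroring MambaInnerFn: causal conv -> input-dependent
+    # delta/B/C projections -> selective scan -> output projection.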
+ L = xz.shape[-1]
+ delta_rank = delta_proj_weight.shape[1]
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
+ x, z = xz.chunk(2, dim=1)
+ x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")
+ # We're being very careful here about the layout, to avoid extra transposes.
+ # We want delta to have d as the slowest moving dimension
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
+ x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
+ delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
+ delta = rearrange(delta, "d (b l) -> b d l", l=L)
+ if B is None: # variable B
+ B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate)
+ if B_proj_bias is not None:
+ B = B + B_proj_bias.to(dtype=B.dtype)
+ if not A.is_complex():
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
+ else:
+ B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
+ if C is None: # variable C
+ C = x_dbl[:, -d_state:] # (bl dstate)
+ if C_proj_bias is not None:
+ C = C + C_proj_bias.to(dtype=C.dtype)
+ if not A.is_complex():
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
+ else:
+ C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
+ y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=delta_softplus)
+ return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
diff --git a/src/transformers/models/zamba2/ops/triton/__init__.py b/src/transformers/models/zamba2/ops/triton/__init__.py
new file mode 100755
index 0000000000..e69de29bb2
diff --git a/src/transformers/models/zamba2/ops/triton/k_activations.py b/src/transformers/models/zamba2/ops/triton/k_activations.py
new file mode 100755
index 0000000000..9ded3b6654
--- /dev/null
+++ b/src/transformers/models/zamba2/ops/triton/k_activations.py
@@ -0,0 +1,153 @@
+# Copyright (c) 2024, Tri Dao, Albert Gu.
+
+import torch
+
+import triton
+import triton.language as tl
+
+
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_N': 32}),
+ triton.Config({'BLOCK_N': 64}),
+ triton.Config({'BLOCK_N': 128}),
+ triton.Config({'BLOCK_N': 256}),
+ triton.Config({'BLOCK_N': 512}),
+ triton.Config({'BLOCK_N': 1024}),
+ ],
+ key=['ncols'],
+)
+@triton.jit
+def _swiglu_fwd_kernel(
+ X,
+ Y,
+ OUT,
+ stride_x_row, # how much to increase the pointer when moving by 1 row
+ stride_y_row,
+ stride_out_row,
+ ncols,
+ BLOCK_N: tl.constexpr,
+):
+ # Map the program id to the row of X and Y it should compute.
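+ # Each program instance handles one (row, BLOCK_N-wide column block) tile.
+ # SwiGLU forward: out = silu(x) * y = x * sigmoid(x) * y, computed in fp32.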
+ row = tl.program_id(0) + start_col = tl.program_id(1) * BLOCK_N + X += row * stride_x_row + Y += row * stride_y_row + OUT += row * stride_out_row + cols = start_col + tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32) + y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32) + out = x * tl.sigmoid(x) * y + tl.store(OUT + cols, out, mask=cols < ncols) + + +def _swiglu_fwd(xy, out=None): + if xy.stride(-1) != 1: + xy = xy.contiguous() + batch_shape = xy.shape[:-1] + xy = xy.reshape(-1, xy.shape[-1]) + x, y = xy.chunk(2, dim=-1) + if out is None: + out = torch.empty_like(x) + else: + out = out.reshape(-1, out.shape[-1]) + assert out.shape == x.shape + assert out.stride(-1) == 1 + M, N = x.shape + grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N'])) + with torch.cuda.device(x.device.index): + _swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N) + return out.reshape(*batch_shape, out.shape[-1]) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_N': 32}), + triton.Config({'BLOCK_N': 64}), + triton.Config({'BLOCK_N': 128}), + triton.Config({'BLOCK_N': 256}), + triton.Config({'BLOCK_N': 512}), + triton.Config({'BLOCK_N': 1024}), + ], + key=['ncols'], +) +@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["OUT"] is not None}) +@triton.jit +def _swiglu_bwd_kernel( + X, + Y, + DOUT, + OUT, + DX, + DY, + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_dout_row, + stride_out_row, + stride_dx_row, + stride_dy_row, + ncols, + BLOCK_N: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + start_col = tl.program_id(1) * BLOCK_N + X += row * stride_x_row + Y += row * stride_y_row + DOUT += row * stride_dout_row + if RECOMPUTE_OUTPUT: + OUT += row * stride_out_row + DX += row * stride_dx_row + DY += row * stride_dy_row + cols = start_col + tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32) + y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32) + dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32) + x_sigmoid = tl.sigmoid(x) + dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout + dy = x * x_sigmoid * dout + tl.store(DX + cols, dx, mask=cols < ncols) + tl.store(DY + cols, dy, mask=cols < ncols) + if RECOMPUTE_OUTPUT: + out = x * x_sigmoid * y + tl.store(OUT + cols, out, mask=cols < ncols) + + +def _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None): + if xy.stride(-1) != 1: + xy = xy.contiguous() + if dout.stride(-1) != 1: + dout = dout.contiguous() + batch_shape = xy.shape[:-1] + xy = xy.reshape(-1, xy.shape[-1]) + x, y = xy.chunk(2, dim=-1) + dout = dout.reshape(-1, dout.shape[-1]) + assert dout.shape == x.shape + if dxy is None: + dxy = torch.empty_like(xy) + else: + dxy = dxy.reshape(-1, dxy.shape[-1]) + assert dxy.shape == xy.shape + dx, dy = dxy.chunk(2, dim=-1) + assert dx.stride(-1) == 1 + assert dy.stride(-1) == 1 + if recompute_output: + if out is None: + out = torch.empty_like(x) + else: + out = out.reshape(-1, out.shape[-1]) + assert out.shape == x.shape + assert out.stride(-1) == 1 + M, N = x.shape + grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N'])) + with torch.cuda.device(x.device.index): + _swiglu_bwd_kernel[grid](x, y, dout, out if recompute_output else None, dx, dy, + x.stride(0), y.stride(0), dout.stride(0), + out.stride(0) if recompute_output else 0, + dx.stride(0), 
dy.stride(0), + N) + if not recompute_output: + return dxy.reshape(*batch_shape, dxy.shape[-1]) + else: + return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1]) diff --git a/src/transformers/models/zamba2/ops/triton/layer_norm.py b/src/transformers/models/zamba2/ops/triton/layer_norm.py new file mode 100755 index 0000000000..6fcf50e178 --- /dev/null +++ b/src/transformers/models/zamba2/ops/triton/layer_norm.py @@ -0,0 +1,1086 @@ +# Copyright (c) 2024, Tri Dao. +# Implement dropout + residual + layer_norm / rms_norm. + +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. +# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. + +import math + +import torch +import torch.nn.functional as F +from torch.cuda.amp import custom_fwd, custom_bwd + +import triton +import triton.language as tl + + +def layer_norm_ref( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + dropout_mask=None, + dropout_mask1=None, + upcast=False, +): + dtype = x.dtype + if upcast: + x = x.float() + weight = weight.float() + bias = bias.float() if bias is not None else None + residual = residual.float() if residual is not None else residual + x1 = x1.float() if x1 is not None else None + weight1 = weight1.float() if weight1 is not None else None + bias1 = bias1.float() if bias1 is not None else None + if x1 is not None: + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + if rowscale is not None: + x = x * rowscale[..., None] + if dropout_p > 0.0: + if dropout_mask is not None: + x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) + else: + x = F.dropout(x, p=dropout_p) + if x1 is not None: + if dropout_mask1 is not None: + x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) + else: + x1 = F.dropout(x1, p=dropout_p) + if x1 is not None: + x = x + x1 + if residual is not None: + x = (x + residual).to(x.dtype) + out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( + dtype + ) + if weight1 is None: + return out if not prenorm else (out, x) + else: + out1 = F.layer_norm( + x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps + ).to(dtype) + return (out, out1) if not prenorm else (out, out1, x) + + +def rms_norm_ref( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + dropout_mask=None, + dropout_mask1=None, + upcast=False, +): + dtype = x.dtype + if upcast: + x = x.float() + weight = weight.float() + bias = bias.float() if bias is not None else None + residual = residual.float() if residual is not None else residual + x1 = x1.float() if x1 is not None else None + weight1 = weight1.float() if weight1 is not None else None + bias1 = bias1.float() if bias1 is not None else None + if x1 is not None: + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + if rowscale is not None: + x = x * rowscale[..., None] + if dropout_p > 0.0: + if dropout_mask is not None: + x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) + else: + x = F.dropout(x, p=dropout_p) + if x1 is not None: + if dropout_mask1 is not 
None: + x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) + else: + x1 = F.dropout(x1, p=dropout_p) + if x1 is not None: + x = x + x1 + if residual is not None: + x = (x + residual).to(x.dtype) + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype) + if weight1 is None: + return out if not prenorm else (out, x) + else: + out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to( + dtype + ) + return (out, out1) if not prenorm else (out, out1, x) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) +@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) +@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) +@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + RESIDUAL, # pointer to the residual + X1, + W1, + B1, + Y1, + RESIDUAL_OUT, # pointer to the residual + ROWSCALE, + SEEDS, # Dropout seeds for each row + DROPOUT_MASK, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_res_row, + stride_res_out_row, + stride_x1_row, + stride_y1_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + dropout_p, # Dropout probability + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_DROPOUT: tl.constexpr, + STORE_DROPOUT_MASK: tl.constexpr, + HAS_ROWSCALE: tl.constexpr, + HAS_X1: tl.constexpr, + HAS_W1: tl.constexpr, + HAS_B1: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
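+ # One program per row. The fused pipeline, in order: optional rowscale, optional
+ # dropout (applied to x and x1 independently), x += x1, x += residual, then
+ # LayerNorm / RMSNorm with up to two (weight, bias) sets writing Y and Y1.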
+ row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_y_row + if HAS_RESIDUAL: + RESIDUAL += row * stride_res_row + if STORE_RESIDUAL_OUT: + RESIDUAL_OUT += row * stride_res_out_row + if HAS_X1: + X1 += row * stride_x1_row + if HAS_W1: + Y1 += row * stride_y1_row + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + row).to(tl.float32) + x *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N) + if HAS_X1: + x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + M + row).to(tl.float32) + x1 *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = ( + tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + ) + x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N) + x += x1 + if HAS_RESIDUAL: + residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) + x += residual + if STORE_RESIDUAL_OUT: + tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + # Write output + tl.store(Y + cols, y, mask=mask) + if HAS_W1: + w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) + if HAS_B1: + b1 = tl.load(B1 + cols, mask=mask).to(tl.float32) + y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1 + tl.store(Y1 + cols, y1, mask=mask) + + +def _layer_norm_fwd( + x, + weight, + bias, + eps, + residual=None, + x1=None, + weight1=None, + bias1=None, + dropout_p=0.0, + rowscale=None, + out_dtype=None, + residual_dtype=None, + is_rms_norm=False, + return_dropout_mask=False, +): + if residual is not None: + residual_dtype = residual.dtype + M, N = x.shape + assert x.stride(-1) == 1 + if residual is not None: + assert residual.stride(-1) == 1 + assert residual.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + if x1 is not None: + assert x1.shape == x.shape + assert rowscale is None + assert x1.stride(-1) == 1 + if weight1 is not None: + assert weight1.shape == (N,) + assert weight1.stride(-1) == 1 + if bias1 is not None: + assert bias1.shape == (N,) + assert bias1.stride(-1) == 1 + if rowscale is not None: + assert rowscale.is_contiguous() + assert rowscale.shape == (M,) + # allocate output + y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) + assert y.stride(-1) == 1 + if weight1 is not None: + y1 = 
torch.empty_like(y) + assert y1.stride(-1) == 1 + else: + y1 = None + if ( + residual is not None + or (residual_dtype is not None and residual_dtype != x.dtype) + or dropout_p > 0.0 + or rowscale is not None + or x1 is not None + ): + residual_out = torch.empty( + M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype + ) + assert residual_out.stride(-1) == 1 + else: + residual_out = None + mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + if dropout_p > 0.0: + seeds = torch.randint( + 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64 + ) + else: + seeds = None + if return_dropout_mask and dropout_p > 0.0: + dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool) + else: + dropout_mask = None + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + with torch.cuda.device(x.device.index): + _layer_norm_fwd_1pass_kernel[(M,)]( + x, + y, + weight, + bias, + residual, + x1, + weight1, + bias1, + y1, + residual_out, + rowscale, + seeds, + dropout_mask, + mean, + rstd, + x.stride(0), + y.stride(0), + residual.stride(0) if residual is not None else 0, + residual_out.stride(0) if residual_out is not None else 0, + x1.stride(0) if x1 is not None else 0, + y1.stride(0) if y1 is not None else 0, + M, + N, + eps, + dropout_p, + is_rms_norm, + BLOCK_N, + residual is not None, + residual_out is not None, + bias is not None, + dropout_p > 0.0, + dropout_mask is not None, + rowscale is not None, + ) + # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 + if dropout_mask is not None and x1 is not None: + dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0) + else: + dropout_mask1 = None + return ( + y, + y1, + mean, + rstd, + residual_out if residual_out is not None else x, + seeds, + dropout_mask, + dropout_mask1, + ) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) +# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) +@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None}) +@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None}) +@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None}) +@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None}) +@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) +@triton.jit +def _layer_norm_bwd_kernel( + X, # pointer to the input + W, # pointer to the weights + B, # pointer to the biases + Y, # pointer to the output to be recomputed + DY, # pointer to the output gradient + DX, # pointer to the input gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + DRESIDUAL, + W1, + DY1, + DX1, + DW1, + 
DB1, + DRESIDUAL_IN, + ROWSCALE, + SEEDS, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_dy_row, + stride_dx_row, + stride_dres_row, + stride_dy1_row, + stride_dx1_row, + stride_dres_in_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + dropout_p, + rows_per_program, + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_DRESIDUAL: tl.constexpr, + STORE_DRESIDUAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_DROPOUT: tl.constexpr, + HAS_ROWSCALE: tl.constexpr, + HAS_DY1: tl.constexpr, + HAS_DX1: tl.constexpr, + HAS_B1: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + # Do not early exit if row_start >= M, because we need to write DW and DB + cols = tl.arange(0, BLOCK_N) + mask = cols < N + X += row_start * stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += row_start * stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += row_start * stride_dres_in_row + DY += row_start * stride_dy_row + DX += row_start * stride_dx_row + if HAS_DY1: + DY1 += row_start * stride_dy1_row + if HAS_DX1: + DX1 += row_start * stride_dx1_row + if RECOMPUTE_OUTPUT: + Y += row_start * stride_y_row + w = tl.load(W + cols, mask=mask).to(tl.float32) + if RECOMPUTE_OUTPUT and HAS_BIAS: + b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) + if HAS_DY1: + w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_BIAS: + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_DY1: + dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_B1: + db1 = tl.zeros((BLOCK_N,), dtype=tl.float32) + row_end = min((row_block_id + 1) * rows_per_program, M) + for row in range(row_start, row_end): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if HAS_DY1: + dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + xhat = tl.where(mask, xhat, 0.0) + if RECOMPUTE_OUTPUT: + y = xhat * w + b if HAS_BIAS else xhat * w + tl.store(Y + cols, y, mask=mask) + wdy = w * dy + dw += dy * xhat + if HAS_BIAS: + db += dy + if HAS_DY1: + wdy += w1 * dy1 + dw1 += dy1 * xhat + if HAS_B1: + db1 += dy1 + if not IS_RMS_NORM: + c1 = tl.sum(xhat * wdy, axis=0) / N + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + else: + c1 = tl.sum(xhat * wdy, axis=0) / N + dx = (wdy - xhat * c1) * rstd + if HAS_DRESIDUAL: + dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) + dx += dres + # Write dx + if STORE_DRESIDUAL: + tl.store(DRESIDUAL_IN + cols, dx, mask=mask) + if HAS_DX1: + if HAS_DROPOUT: + keep_mask = ( + tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + ) + dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) + else: + dx1 = dx + tl.store(DX1 + cols, dx1, mask=mask) + if HAS_DROPOUT: + keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + row).to(tl.float32) + dx *= rowscale + tl.store(DX + cols, dx, mask=mask) + + X += 
stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += stride_dres_in_row + if RECOMPUTE_OUTPUT: + Y += stride_y_row + DY += stride_dy_row + DX += stride_dx_row + if HAS_DY1: + DY1 += stride_dy1_row + if HAS_DX1: + DX1 += stride_dx1_row + tl.store(DW + row_block_id * N + cols, dw, mask=mask) + if HAS_BIAS: + tl.store(DB + row_block_id * N + cols, db, mask=mask) + if HAS_DY1: + tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask) + if HAS_B1: + tl.store(DB1 + row_block_id * N + cols, db1, mask=mask) + + +def _layer_norm_bwd( + dy, + x, + weight, + bias, + eps, + mean, + rstd, + dresidual=None, + dy1=None, + weight1=None, + bias1=None, + seeds=None, + dropout_p=0.0, + rowscale=None, + has_residual=False, + has_x1=False, + is_rms_norm=False, + x_dtype=None, + recompute_output=False, +): + M, N = x.shape + assert x.stride(-1) == 1 + assert dy.stride(-1) == 1 + assert dy.shape == (M, N) + if dresidual is not None: + assert dresidual.stride(-1) == 1 + assert dresidual.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + if dy1 is not None: + assert weight1 is not None + assert dy1.shape == dy.shape + assert dy1.stride(-1) == 1 + if weight1 is not None: + assert weight1.shape == (N,) + assert weight1.stride(-1) == 1 + if bias1 is not None: + assert bias1.shape == (N,) + assert bias1.stride(-1) == 1 + if seeds is not None: + assert seeds.is_contiguous() + assert seeds.shape == (M if not has_x1 else M * 2,) + if rowscale is not None: + assert rowscale.is_contiguous() + assert rowscale.shape == (M,) + # allocate output + dx = ( + torch.empty_like(x) + if x_dtype is None + else torch.empty(M, N, dtype=x_dtype, device=x.device) + ) + dresidual_in = ( + torch.empty_like(x) + if has_residual + and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1) + else None + ) + dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None + y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None + if recompute_output: + assert weight1 is None, "recompute_output is not supported with parallel LayerNorm" + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) + _db = ( + torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) + if bias is not None + else None + ) + _dw1 = torch.empty_like(_dw) if weight1 is not None else None + _db1 = torch.empty_like(_db) if bias1 is not None else None + rows_per_program = math.ceil(M / sm_count) + grid = (sm_count,) + with torch.cuda.device(x.device.index): + _layer_norm_bwd_kernel[grid]( + x, + weight, + bias, + y, + dy, + dx, + _dw, + _db, + dresidual, + weight1, + dy1, + dx1, + _dw1, + _db1, + dresidual_in, + rowscale, + seeds, + mean, + rstd, + x.stride(0), + 0 if not recompute_output else y.stride(0), + dy.stride(0), + dx.stride(0), + dresidual.stride(0) if dresidual is not None else 0, + dy1.stride(0) if dy1 is not None else 0, + dx1.stride(0) if dx1 is not None else 0, + dresidual_in.stride(0) if dresidual_in is not None else 0, + M, + N, + eps, + dropout_p, + rows_per_program, + is_rms_norm, + 
BLOCK_N, + dresidual is not None, + dresidual_in is not None, + bias is not None, + dropout_p > 0.0, + ) + dw = _dw.sum(0).to(weight.dtype) + db = _db.sum(0).to(bias.dtype) if bias is not None else None + dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None + db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None + # Don't need to compute dresidual_in separately in this case + if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None: + dresidual_in = dx + if has_x1 and dropout_p == 0.0: + dx1 = dx + return ( + (dx, dw, db, dresidual_in, dx1, dw1, db1) + if not recompute_output + else (dx, dw, db, dresidual_in, dx1, dw1, db1, y) + ) + + +class LayerNormFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + return_dropout_mask=False, + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + if residual.stride(-1) != 1: + residual = residual.contiguous() + if x1 is not None: + assert x1.shape == x_shape_og + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + x1 = x1.reshape(-1, x1.shape[-1]) + if x1.stride(-1) != 1: + x1 = x1.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + if weight1 is not None: + weight1 = weight1.contiguous() + if bias1 is not None: + bias1 = bias1.contiguous() + if rowscale is not None: + rowscale = rowscale.reshape(-1).contiguous() + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd( + x, + weight, + bias, + eps, + residual, + x1, + weight1, + bias1, + dropout_p=dropout_p, + rowscale=rowscale, + residual_dtype=residual_dtype, + is_rms_norm=is_rms_norm, + return_dropout_mask=return_dropout_mask, + ) + ctx.save_for_backward( + residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd + ) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.dropout_p = dropout_p + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.has_x1 = x1 is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + y = y.reshape(x_shape_og) + y1 = y1.reshape(x_shape_og) if y1 is not None else None + residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None + dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None + dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None + if not return_dropout_mask: + if weight1 is None: + return y if not prenorm else (y, residual_out) + else: + return (y, y1) if not prenorm else (y, y1, residual_out) + else: + if weight1 is None: + return ( + (y, dropout_mask, dropout_mask1) + if not prenorm + else (y, residual_out, dropout_mask, dropout_mask1) + ) + else: + return ( + (y, y1, dropout_mask, dropout_mask1) + if not prenorm + else (y, y1, residual_out, dropout_mask, dropout_mask1) + ) + + @staticmethod + def backward(ctx, dy, *args): + x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors + dy = dy.reshape(-1, dy.shape[-1]) + if 
dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + if weight1 is not None: + dy1, args = args[0], args[1:] + dy1 = dy1.reshape(-1, dy1.shape[-1]) + if dy1.stride(-1) != 1: + dy1 = dy1.contiguous() + assert dy1.shape == x.shape + else: + dy1 = None + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + if dresidual.stride(-1) != 1: + dresidual = dresidual.contiguous() + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd( + dy, + x, + weight, + bias, + ctx.eps, + mean, + rstd, + dresidual, + dy1, + weight1, + bias1, + seeds, + ctx.dropout_p, + rowscale, + ctx.has_residual, + ctx.has_x1, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + ) + return ( + dx.reshape(ctx.x_shape_og), + dw, + db, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + dx1.reshape(ctx.x_shape_og) if dx1 is not None else None, + dw1, + db1, + None, + None, + None, + None, + None, + None, + None, + ) + + +def layer_norm_fn( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + return_dropout_mask=False, +): + return LayerNormFn.apply( + x, + weight, + bias, + residual, + x1, + weight1, + bias1, + eps, + dropout_p, + rowscale, + prenorm, + residual_in_fp32, + is_rms_norm, + return_dropout_mask, + ) + + +def rms_norm_fn( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + return_dropout_mask=False, +): + return LayerNormFn.apply( + x, + weight, + bias, + residual, + x1, + weight1, + bias1, + eps, + dropout_p, + rowscale, + prenorm, + residual_in_fp32, + True, + return_dropout_mask, + ) + + +class RMSNorm(torch.nn.Module): + + def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if dropout_p > 0.0: + self.drop = torch.nn.Dropout(dropout_p) + else: + self.drop = None + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) + + def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): + return rms_norm_fn( + x, + self.weight, + self.bias, + residual=residual, + eps=self.eps, + dropout_p=self.drop.p if self.drop is not None and self.training else 0.0, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + ) + + +class LayerNormLinearFn(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward( + ctx, + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + if residual.stride(-1) != 1: + residual = residual.contiguous() + norm_weight = norm_weight.contiguous() + if norm_bias is not None: + norm_bias = norm_bias.contiguous() + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + y, _, mean, rstd, 
residual_out, *rest = _layer_norm_fwd( + x, + norm_weight, + norm_bias, + eps, + residual, + out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(), + residual_dtype=residual_dtype, + is_rms_norm=is_rms_norm, + ) + y = y.reshape(x_shape_og) + dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype + linear_weight = linear_weight.to(dtype) + linear_bias = linear_bias.to(dtype) if linear_bias is not None else None + out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) + # We don't store y, will be recomputed in the backward pass to save memory + ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + ctx.linear_bias_is_none = linear_bias is None + return out if not prenorm else (out, residual_out.reshape(x_shape_og)) + + @staticmethod + @custom_bwd + def backward(ctx, dout, *args): + x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors + dout = dout.reshape(-1, dout.shape[-1]) + dy = F.linear(dout, linear_weight.t()) + dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + if dresidual.stride(-1) != 1: + dresidual = dresidual.contiguous() + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd( + dy, + x, + norm_weight, + norm_bias, + ctx.eps, + mean, + rstd, + dresidual=dresidual, + has_residual=ctx.has_residual, + is_rms_norm=ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + recompute_output=True, + ) + dlinear_weight = torch.einsum("bo,bi->oi", dout, y) + return ( + dx.reshape(ctx.x_shape_og), + dnorm_weight, + dnorm_bias, + dlinear_weight, + dlinear_bias, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + ) + + +def layer_norm_linear_fn( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, +): + return LayerNormLinearFn.apply( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual, + eps, + prenorm, + residual_in_fp32, + is_rms_norm, + ) diff --git a/src/transformers/models/zamba2/ops/triton/layernorm.py b/src/transformers/models/zamba2/ops/triton/layernorm.py new file mode 100755 index 0000000000..989633eb17 --- /dev/null +++ b/src/transformers/models/zamba2/ops/triton/layernorm.py @@ -0,0 +1,636 @@ +# Copyright (c) 2023, Tri Dao. +# Implement residual + layer_norm / rms_norm. + +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. +# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. 
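+# Note: this older variant differs from ops/triton/layer_norm.py above in that it
+# has no dropout, rowscale, or parallel (x1 / weight1 / bias1) branches.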
+ +import math + +import torch +import torch.nn.functional as F +from torch.cuda.amp import custom_fwd, custom_bwd + +import triton +import triton.language as tl + + +def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): + dtype = x.dtype + if upcast: + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + residual = residual.float() if residual is not None else residual + if residual is not None: + x = (x + residual).to(x.dtype) + out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( + dtype + ) + return out if not prenorm else (out, x) + + +def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): + dtype = x.dtype + if upcast: + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + residual = residual.float() if residual is not None else residual + if residual is not None: + x = (x + residual).to(x.dtype) + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight) + out = out.to(dtype) + return out if not prenorm else (out, x) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + RESIDUAL, # pointer to the residual + RESIDUAL_OUT, # pointer to the residual + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_res_row, + stride_res_out_row, + N, # number of columns in X + eps, # epsilon to avoid division by zero + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_BIAS: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
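+ # One program per row. LayerNorm: y = (x - mean) * rstd * w (+ b), with
+ # rstd = 1 / sqrt(var + eps); RMSNorm skips the mean subtraction.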
+ row = tl.program_id(0)
+ X += row * stride_x_row
+ Y += row * stride_y_row
+ if HAS_RESIDUAL:
+ RESIDUAL += row * stride_res_row
+ if STORE_RESIDUAL_OUT:
+ RESIDUAL_OUT += row * stride_res_out_row
+ # Compute mean and variance
+ cols = tl.arange(0, BLOCK_N)
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+ if HAS_RESIDUAL:
+ residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
+ x += residual
+ if STORE_RESIDUAL_OUT:
+ tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
+ if not IS_RMS_NORM:
+ mean = tl.sum(x, axis=0) / N
+ tl.store(Mean + row, mean)
+ xbar = tl.where(cols < N, x - mean, 0.0)
+ var = tl.sum(xbar * xbar, axis=0) / N
+ else:
+ xbar = tl.where(cols < N, x, 0.0)
+ var = tl.sum(xbar * xbar, axis=0) / N
+ rstd = 1 / tl.sqrt(var + eps)
+ tl.store(Rstd + row, rstd)
+ # Normalize and apply linear transformation
+ mask = cols < N
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
+ if HAS_BIAS:
+ b = tl.load(B + cols, mask=mask).to(tl.float32)
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
+ y = x_hat * w + b if HAS_BIAS else x_hat * w
+ # Write output
+ tl.store(Y + cols, y, mask=mask)
+
+
+def _layer_norm_fwd(
+ x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
+):
+ if residual is not None:
+ residual_dtype = residual.dtype
+ M, N = x.shape
+ assert x.stride(-1) == 1
+ if residual is not None:
+ assert residual.stride(-1) == 1
+ assert residual.shape == (M, N)
+ assert weight.shape == (N,)
+ assert weight.stride(-1) == 1
+ if bias is not None:
+ assert bias.stride(-1) == 1
+ assert bias.shape == (N,)
+ # allocate output
+ y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
+ assert y.stride(-1) == 1
+ if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
+ residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
+ assert residual_out.stride(-1) == 1
+ else:
+ residual_out = None
+ mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
+ rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
+ # Less than 64KB per feature: enqueue fused kernel
+ MAX_FUSED_SIZE = 65536 // x.element_size()
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+ if N > BLOCK_N:
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+ # heuristics for number of warps
+ with torch.cuda.device(x.device.index):
+ _layer_norm_fwd_1pass_kernel[(M,)](
+ x,
+ y,
+ weight,
+ bias,
+ residual,
+ residual_out,
+ mean,
+ rstd,
+ x.stride(0),
+ y.stride(0),
+ residual.stride(0) if residual is not None else 0,
+ residual_out.stride(0) if residual_out is not None else 0,
+ N,
+ eps,
+ is_rms_norm,
+ BLOCK_N,
+ residual is not None,
+ residual_out is not None,
+ bias is not None,
+ )
+ # residual_out is None if residual is None and residual_dtype == input_dtype
+ return y, mean, rstd, residual_out if residual_out is not None else x
+
+
+@triton.autotune(
+ configs=[
+ triton.Config({}, num_warps=1),
+ triton.Config({}, num_warps=2),
+ triton.Config({}, num_warps=4),
+ triton.Config({}, num_warps=8),
+ triton.Config({}, num_warps=16),
+ triton.Config({}, num_warps=32),
+ ],
+ key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
+)
+# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
+# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
+# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not
None}) +@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) +@triton.jit +def _layer_norm_bwd_kernel( + X, # pointer to the input + W, # pointer to the weights + B, # pointer to the biases + Y, # pointer to the output to be recomputed + DY, # pointer to the output gradient + DX, # pointer to the input gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + DRESIDUAL, + DRESIDUAL_IN, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_dy_row, + stride_dx_row, + stride_dres_row, + stride_dres_in_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + rows_per_program, + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_DRESIDUAL: tl.constexpr, + STORE_DRESIDUAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + cols = tl.arange(0, BLOCK_N) + mask = cols < N + X += row_start * stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += row_start * stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += row_start * stride_dres_in_row + DY += row_start * stride_dy_row + DX += row_start * stride_dx_row + if RECOMPUTE_OUTPUT: + Y += row_start * stride_y_row + w = tl.load(W + cols, mask=mask).to(tl.float32) + if RECOMPUTE_OUTPUT and HAS_BIAS: + b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_BIAS: + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + row_end = min((row_block_id + 1) * rows_per_program, M) + for row in range(row_start, row_end): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + xhat = tl.where(mask, xhat, 0.0) + if RECOMPUTE_OUTPUT: + y = xhat * w + b if HAS_BIAS else xhat * w + tl.store(Y + cols, y, mask=mask) + wdy = w * dy + dw += dy * xhat + if HAS_BIAS: + db += dy + if not IS_RMS_NORM: + c1 = tl.sum(xhat * wdy, axis=0) / N + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + else: + c1 = tl.sum(xhat * wdy, axis=0) / N + dx = (wdy - xhat * c1) * rstd + if HAS_DRESIDUAL: + dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) + dx += dres + # Write dx + if STORE_DRESIDUAL: + tl.store(DRESIDUAL_IN + cols, dx, mask=mask) + tl.store(DX + cols, dx, mask=mask) + + X += stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += stride_dres_in_row + if RECOMPUTE_OUTPUT: + Y += stride_y_row + DY += stride_dy_row + DX += stride_dx_row + tl.store(DW + row_block_id * N + cols, dw, mask=mask) + if HAS_BIAS: + tl.store(DB + row_block_id * N + cols, db, mask=mask) + + +def _layer_norm_bwd( + dy, + x, + weight, + bias, + eps, + mean, + rstd, + dresidual=None, + has_residual=False, + is_rms_norm=False, + x_dtype=None, + recompute_output=False, +): + M, N = x.shape + assert x.stride(-1) == 1 + assert dy.stride(-1) == 1 + assert dy.shape == (M, N) + if dresidual is not None: + assert dresidual.stride(-1) == 1 + assert dresidual.shape == (M, N) + assert weight.shape == (N,) + assert 
weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + dx = ( + torch.empty_like(x) + if x_dtype is None + else torch.empty(M, N, dtype=x_dtype, device=x.device) + ) + dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None + y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) + _db = ( + torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) + if bias is not None + else None + ) + rows_per_program = math.ceil(M / sm_count) + grid = (sm_count,) + with torch.cuda.device(x.device.index): + _layer_norm_bwd_kernel[grid]( + x, + weight, + bias, + y, + dy, + dx, + _dw, + _db, + dresidual, + dresidual_in, + mean, + rstd, + x.stride(0), + 0 if not recompute_output else y.stride(0), + dy.stride(0), + dx.stride(0), + dresidual.stride(0) if dresidual is not None else 0, + dresidual_in.stride(0) if dresidual_in is not None else 0, + M, + N, + eps, + rows_per_program, + is_rms_norm, + BLOCK_N, + dresidual is not None, + dresidual_in is not None, + bias is not None, + ) + dw = _dw.sum(0).to(weight.dtype) + db = _db.sum(0).to(bias.dtype) if bias is not None else None + # Don't need to compute dresidual_in separately in this case + if has_residual and dx.dtype == x.dtype: + dresidual_in = dx + return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y) + + +class LayerNormFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + if residual.stride(-1) != 1: + residual = residual.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + y, mean, rstd, residual_out = _layer_norm_fwd( + x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm + ) + ctx.save_for_backward(residual_out, weight, bias, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + y = y.reshape(x_shape_og) + return y if not prenorm else (y, residual_out.reshape(x_shape_og)) + + @staticmethod + def backward(ctx, dy, *args): + x, weight, bias, mean, rstd = ctx.saved_tensors + dy = dy.reshape(-1, dy.shape[-1]) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + if dresidual.stride(-1) != 1: + dresidual = dresidual.contiguous() + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dw, db, dresidual_in = _layer_norm_bwd( + dy, + x, + 
weight, + bias, + ctx.eps, + mean, + rstd, + dresidual, + ctx.has_residual, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + ) + return ( + dx.reshape(ctx.x_shape_og), + dw, + db, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + ) + + +def layer_norm_fn( + x, + weight, + bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, +): + return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm) + + +def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6, is_rms_norm=True): + return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm) + + +class RMSNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) + + def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): + return rms_norm_fn( + x, + self.weight, + self.bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + is_rms_norm=True, + ) + + +class LayerNormLinearFn(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward( + ctx, + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + if residual.stride(-1) != 1: + residual = residual.contiguous() + norm_weight = norm_weight.contiguous() + if norm_bias is not None: + norm_bias = norm_bias.contiguous() + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + y, mean, rstd, residual_out = _layer_norm_fwd( + x, + norm_weight, + norm_bias, + eps, + residual, + out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(), + residual_dtype=residual_dtype, + is_rms_norm=is_rms_norm, + ) + y = y.reshape(x_shape_og) + dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype + linear_weight = linear_weight.to(dtype) + linear_bias = linear_bias.to(dtype) if linear_bias is not None else None + out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) + # We don't store y, will be recomputed in the backward pass to save memory + ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + ctx.linear_bias_is_none = linear_bias is None + return out if not prenorm else (out, residual_out.reshape(x_shape_og)) + + @staticmethod + @custom_bwd + def backward(ctx, dout, *args): + x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors + dout = dout.reshape(-1, dout.shape[-1]) + dy = F.linear(dout, linear_weight.t()) + dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) + if dy.stride(-1) != 1: 
+ dy = dy.contiguous() + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + if dresidual.stride(-1) != 1: + dresidual = dresidual.contiguous() + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd( + dy, + x, + norm_weight, + norm_bias, + ctx.eps, + mean, + rstd, + dresidual, + ctx.has_residual, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + recompute_output=True, + ) + dlinear_weight = torch.einsum("bo,bi->oi", dout, y) + return ( + dx.reshape(ctx.x_shape_og), + dnorm_weight, + dnorm_bias, + dlinear_weight, + dlinear_bias, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + ) + + +def layer_norm_linear_fn( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, +): + return LayerNormLinearFn.apply( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual, + eps, + prenorm, + residual_in_fp32, + is_rms_norm, + ) diff --git a/src/transformers/models/zamba2/ops/triton/layernorm_gated.py b/src/transformers/models/zamba2/ops/triton/layernorm_gated.py new file mode 100755 index 0000000000..de4b2f4815 --- /dev/null +++ b/src/transformers/models/zamba2/ops/triton/layernorm_gated.py @@ -0,0 +1,437 @@ +# Copyright (c) 2024, Tri Dao. +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. +# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. + +import math + +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + +from einops import rearrange + + +def rms_norm_ref(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True): + dtype = x.dtype + N = x.shape[-1] + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + z = z.float() if z is not None else z + if z is not None and not norm_before_gate: + x = x * F.silu(z) + if group_size is None: + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight) + else: + x_group = rearrange(x, "... (g d) -> ... g d", d=group_size) + rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps) + out = rearrange(x_group * rstd, "... g d -> ... 
(g d)") * weight + if bias is not None: + out = out + bias + if z is not None and norm_before_gate: + out *= F.silu(z) + return out.to(dtype) + + +@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Z, # pointer to the other branch + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_z_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_N: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_Z: tl.constexpr, + NORM_BEFORE_GATE: tl.constexpr, + IS_RMS_NORM: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + group = tl.program_id(1) + X += row * stride_x_row + group * N + Y += row * stride_y_row + group * N + if HAS_Z: + Z += row * stride_z_row + group * N + if not IS_RMS_NORM: + Mean += group * M + Rstd += group * M + W += group * N + if HAS_BIAS: + B += group * N + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + if HAS_Z and not NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=cols < N).to(tl.float32) + x *= z * tl.sigmoid(z) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + if HAS_Z and NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=mask).to(tl.float32) + y *= z * tl.sigmoid(z) + # Write output + tl.store(Y + cols, y, mask=mask) + + +def _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, norm_before_gate=True, is_rms_norm=False): + M, N = x.shape + if group_size is None: + group_size = N + assert N % group_size == 0 + ngroups = N // group_size + assert x.stride(-1) == 1 + if z is not None: + assert z.stride(-1) == 1 + assert z.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + if out is not None: + assert out.shape == x.shape + else: + out = torch.empty_like(x) + assert out.stride(-1) == 1 + mean = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm else None + rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_N // 256, 1), 8) + grid = (M, ngroups) + with torch.cuda.device(x.device.index): + _layer_norm_fwd_1pass_kernel[grid](x, out, weight, bias, z, mean, rstd, + x.stride(0), 
out.stride(0), z.stride(0) if z is not None else 0, + M, group_size, eps, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + IS_RMS_NORM=is_rms_norm, + num_warps=num_warps) + return out, mean, rstd + + + +@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None}) +@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) +@triton.jit +def _layer_norm_bwd_kernel( + X, # pointer to the input + W, # pointer to the weights + B, # pointer to the biases + Z, # pointer to the other branch + Y, # pointer to the output to be recomputed + DY, # pointer to the output gradient + DX, # pointer to the input gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + DZ, # pointer to the other branch + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_z_row, + stride_y_row, + stride_dy_row, + stride_dx_row, + stride_dz_row, + stride_dw_row, + stride_db_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + rows_per_program, + NORM_BEFORE_GATE: tl.constexpr, + IS_RMS_NORM: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_Z: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. + row_block_id = tl.program_id(0) + group = tl.program_id(1) + row_start = row_block_id * rows_per_program + cols = tl.arange(0, BLOCK_N) + mask = cols < N + X += row_start * stride_x_row + group * N + if HAS_Z: + Z += row_start * stride_z_row + group * N + DZ += row_start * stride_dz_row + group * N + DY += row_start * stride_dy_row + group * N + DX += row_start * stride_dx_row + group * N + if RECOMPUTE_OUTPUT: + Y += row_start * stride_y_row + group * N + if not IS_RMS_NORM: + Mean += group * M + Rstd += group * M + W += group * N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if (RECOMPUTE_OUTPUT or HAS_Z) and HAS_BIAS: + B += group * N + b = tl.load(B + cols, mask=mask, other=0.).to(tl.float32) + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_BIAS: + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + row_end = min((row_block_id + 1) * rows_per_program, M) + for row in range(row_start, row_end): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.load(Mean + row) + if HAS_Z and not NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32) + x_og = x + x = x_og * z * tl.sigmoid(z) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + xhat = tl.where(mask, xhat, 0.) 
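+        # Note added for clarity: the gate here is silu(z) = z * sigmoid(z), whose derivative
+        # is sigmoid(z) * (1 + z * (1 - sigmoid(z))); the dz expressions in both gating
+        # branches below are chain-rule applications of this identity.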
+ if HAS_Z and NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32) + z_sigmoid = tl.sigmoid(z) + y = xhat * w + b if HAS_BIAS else xhat * w + if RECOMPUTE_OUTPUT: + tl.store(Y + cols, y * z * z_sigmoid, mask=mask) + dz = dy * y * z_sigmoid * (1 + z * (1 - z_sigmoid)) + tl.store(DZ + cols, dz, mask=mask) + dy *= z * z_sigmoid + else: + if RECOMPUTE_OUTPUT: + y = xhat * w + b if HAS_BIAS else xhat * w + tl.store(Y + cols, y, mask=mask) + wdy = w * dy + c1 = tl.sum(xhat * wdy, axis=0) / N + if not IS_RMS_NORM: + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + else: + dx = (wdy - xhat * c1) * rstd + dw += dy * xhat + if HAS_BIAS: + db += dy + if HAS_Z and not NORM_BEFORE_GATE: + z_sigmoid = tl.sigmoid(z) + dz = dx * x_og * z_sigmoid * (1 + z * (1 - z_sigmoid)) + tl.store(DZ + cols, dz, mask=mask) + dx *= z * z_sigmoid + # Write dx + tl.store(DX + cols, dx, mask=mask) + + X += stride_x_row + if HAS_Z: + Z += stride_z_row + DZ += stride_dz_row + if RECOMPUTE_OUTPUT: + Y += stride_y_row + DY += stride_dy_row + DX += stride_dx_row + tl.store(DW + row_block_id * stride_dw_row + group * N + cols, dw, mask=mask) + if HAS_BIAS: + tl.store(DB + row_block_id * stride_db_row + group * N + cols, db, mask=mask) + + +def _layer_norm_bwd(dy, x, weight, bias, eps, mean, rstd, z=None, group_size=None, + norm_before_gate=True, is_rms_norm=False, recompute_output=False, dz=None, out=None): + M, N = x.shape + if group_size is None: + group_size = N + assert N % group_size == 0 + ngroups = N // group_size + assert x.stride(-1) == 1 + assert dy.stride(-1) == 1 + assert dy.shape == (M, N) + if z is not None: + assert z.stride(-1) == 1 + assert z.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + dx = torch.empty_like(x) + if dz is not None: + assert z is not None + assert dz.shape == z.shape + assert dz.stride(-1) == 1 + else: + dz = torch.empty_like(z) if z is not None else None + if recompute_output: + if out is None: + out = torch.empty_like(x) + assert out.shape == x.shape + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_N // 256, 1), 8) + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + # If group size is small (e.g., 64), we're only using 1 warp. So having just 108 programs + # would limit the occupancy. 
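+    # Illustrative arithmetic (comment added for clarity): with num_warps = 4 the factor
+    # math.ceil(4 / num_warps) is 1, so nrow_groups ~ sm_count / ngroups (about one row-block
+    # per SM when ngroups == 1); with num_warps = 1 the factor is 4, launching 4x as many
+    # row-blocks so each SM can co-schedule several small programs.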
+ nrow_groups = math.ceil(sm_count * math.ceil(4 / num_warps) / ngroups) + _dw = torch.empty((nrow_groups, N), dtype=torch.float32, device=weight.device) + _db = torch.empty((nrow_groups, N), dtype=torch.float32, device=bias.device) if bias is not None else None + rows_per_program = math.ceil(M / nrow_groups) + grid = (nrow_groups, ngroups) + with torch.cuda.device(x.device.index): + _layer_norm_bwd_kernel[grid](x, weight, bias, z, out if recompute_output else None, + dy, dx, _dw, _db, dz, mean, rstd, + x.stride(0), + z.stride(0) if z is not None else 0, + 0 if not recompute_output else out.stride(0), + dy.stride(0), dx.stride(0), + dz.stride(0) if dz is not None else 0, + _dw.stride(0), + _db.stride(0) if _db is not None else 0, + M, group_size, eps, + rows_per_program, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + IS_RMS_NORM=is_rms_norm, + num_warps=num_warps) + dw = _dw.sum(0).to(weight.dtype) + db = _db.sum(0).to(bias.dtype) if bias is not None else None + return (dx, dw, db, dz) if not recompute_output else (dx, dw, db, dz, out) + + +class LayerNormFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, + is_rms_norm=False): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) + """ + + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if z is not None: + assert z.shape == x_shape_og + z = z.reshape(-1, z.shape[-1]) + if z.stride(-1) != 1: + z = z.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + y, mean, rstd = _layer_norm_fwd(x, weight, bias, eps, z=z, group_size=group_size, norm_before_gate=norm_before_gate, is_rms_norm=is_rms_norm) + ctx.save_for_backward(x, weight, bias, mean, rstd, z) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.group_size = group_size + ctx.norm_before_gate = norm_before_gate + ctx.is_rms_norm = is_rms_norm + return y.reshape(x_shape_og) + + @staticmethod + def backward(ctx, dy): + x, weight, bias, mean, rstd, z = ctx.saved_tensors + dy = dy.reshape(-1, dy.shape[-1]) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + dx, dw, db, dz = _layer_norm_bwd(dy, x, weight, bias, ctx.eps, mean, rstd, z, ctx.group_size, + ctx.norm_before_gate, ctx.is_rms_norm) + return dx.reshape(ctx.x_shape_og), dw, db, dz.reshape(ctx.x_shape_og) if dz is not None else None, None, None, None, None + + +def layernorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False): + return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm) + + +def rmsnorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True): + return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, True) + + +class LayerNorm(torch.nn.Module): + + def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None): + """If group_size is not None, we do GroupNorm with each group having group_size elements. + group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). 
+ """ + + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.group_size = group_size + self.norm_before_gate = norm_before_gate + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) + torch.nn.init.zeros_(self.bias) + + def forward(self, x, z=None): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) + """ + return layernorm_fn(x, self.weight, self.bias, z=z, group_size=self.group_size, eps=self.eps, + norm_before_gate=self.norm_before_gate) + + +class RMSNorm(torch.nn.Module): + + def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None): + """If group_size is not None, we do GroupNorm with each group having group_size elements. + group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.register_parameter("bias", None) + self.group_size = group_size + self.norm_before_gate = norm_before_gate + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) + + def forward(self, x, z=None): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) + """ + return rmsnorm_fn(x, self.weight, self.bias, z=z, eps=self.eps, group_size=self.group_size, + norm_before_gate=self.norm_before_gate) diff --git a/src/transformers/models/zamba2/ops/triton/selective_state_update.py b/src/transformers/models/zamba2/ops/triton/selective_state_update.py new file mode 100755 index 0000000000..193552a0f6 --- /dev/null +++ b/src/transformers/models/zamba2/ops/triton/selective_state_update.py @@ -0,0 +1,263 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. 
+ +"""We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this +""" + +import math +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + +from einops import rearrange, repeat + + +@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) +@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) +@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) +@triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) +@triton.jit +def _selective_scan_update_kernel( + # Pointers to matrices + state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr, + # Matrix dimensions + batch, nheads, dim, dstate, nheads_ngroups_ratio, + # Strides + stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate, + stride_x_batch, stride_x_head, stride_x_dim, + stride_dt_batch, stride_dt_head, stride_dt_dim, + stride_dt_bias_head, stride_dt_bias_dim, + stride_A_head, stride_A_dim, stride_A_dstate, + stride_B_batch, stride_B_group, stride_B_dstate, + stride_C_batch, stride_C_group, stride_C_dstate, + stride_D_head, stride_D_dim, + stride_z_batch, stride_z_head, stride_z_dim, + stride_out_batch, stride_out_head, stride_out_dim, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + TIE_HDIM: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + HAS_D: tl.constexpr, + HAS_Z: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head + x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head + if HAS_DT_BIAS: + dt_bias_ptr += pid_h * stride_dt_bias_head + A_ptr += pid_h * stride_A_head + B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group + C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group + if HAS_Z: + z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head + out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) + state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate) + x_ptrs = x_ptr + offs_m * stride_x_dim + dt_ptrs = dt_ptr + offs_m * stride_dt_dim + if HAS_DT_BIAS: + dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim + if HAS_D: + D_ptr += pid_h * stride_D_head + A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate) + B_ptrs = B_ptr + offs_n * stride_B_dstate + C_ptrs = C_ptr + offs_n * stride_C_dstate + if HAS_D: + D_ptrs = D_ptr + offs_m * stride_D_dim + if HAS_Z: + z_ptrs = z_ptr + offs_m * stride_z_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + + state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0) + x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if not TIE_HDIM: + dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if DT_SOFTPLUS: + dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) + A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) + dA = tl.exp(A * dt[:, None]) + else: + dt = 
tl.load(dt_ptr).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptr).to(tl.float32) + if DT_SOFTPLUS: + dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) + A = tl.load(A_ptr).to(tl.float32) + dA = tl.exp(A * dt) # scalar, not a matrix + + B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + if HAS_D: + D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_Z: + z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + + if not TIE_HDIM: + dB = B[None, :] * dt[:, None] + else: + dB = B * dt # vector of size (dstate,) + state = state * dA + dB * x[:, None] + tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) + out = tl.sum(state * C[None, :], axis=1) + if HAS_D: + out += x * D + if HAS_Z: + out *= z * tl.sigmoid(z) + tl.store(out_ptrs, out, mask=offs_m < dim) + + +def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): + """ + Argument: + state: (batch, dim, dstate) or (batch, nheads, dim, dstate) + x: (batch, dim) or (batch, nheads, dim) + dt: (batch, dim) or (batch, nheads, dim) + A: (dim, dstate) or (nheads, dim, dstate) + B: (batch, dstate) or (batch, ngroups, dstate) + C: (batch, dstate) or (batch, ngroups, dstate) + D: (dim,) or (nheads, dim) + z: (batch, dim) or (batch, nheads, dim) + dt_bias: (dim,) or (nheads, dim) + Return: + out: (batch, dim) or (batch, nheads, dim) + """ + has_heads = state.dim() > 3 + if state.dim() == 3: + state = state.unsqueeze(1) + if x.dim() == 2: + x = x.unsqueeze(1) + if dt.dim() == 2: + dt = dt.unsqueeze(1) + if A.dim() == 2: + A = A.unsqueeze(0) + if B.dim() == 2: + B = B.unsqueeze(1) + if C.dim() == 2: + C = C.unsqueeze(1) + if D is not None and D.dim() == 1: + D = D.unsqueeze(0) + if z is not None and z.dim() == 2: + z = z.unsqueeze(1) + if dt_bias is not None and dt_bias.dim() == 1: + dt_bias = dt_bias.unsqueeze(0) + batch, nheads, dim, dstate = state.shape + assert x.shape == (batch, nheads, dim) + assert dt.shape == x.shape + assert A.shape == (nheads, dim, dstate) + ngroups = B.shape[1] + assert nheads % ngroups == 0, "nheads must be divisible by ngroups" + assert B.shape == (batch, ngroups, dstate) + assert C.shape == B.shape + if D is not None: + assert D.shape == (nheads, dim) + if z is not None: + assert z.shape == x.shape + if dt_bias is not None: + assert dt_bias.shape == (nheads, dim) + out = torch.empty_like(x) + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads) + z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)) + # We don't want autotune since it will overwrite the state + # We instead tune by hand. 
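+    # Comment added for clarity: the (BLOCK_SIZE_M, num_warps) ladder below shrinks the
+    # block of channels processed per program as the recurrent state width dstate grows
+    # (and adds warps past dstate 128), keeping each program's register and shared-memory
+    # footprint roughly bounded.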
+    if dstate <= 16:
+        BLOCK_SIZE_M, num_warps = 32, 4
+    elif dstate <= 32:
+        BLOCK_SIZE_M, num_warps = 16, 4
+    elif dstate <= 64:
+        BLOCK_SIZE_M, num_warps = 8, 4
+    elif dstate <= 128:
+        BLOCK_SIZE_M, num_warps = 4, 4
+    else:
+        BLOCK_SIZE_M, num_warps = 4, 8
+    # dt_bias is optional, so only require a broadcast (zero) stride when it is provided.
+    tie_hdim = (A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0
+                and (dt_bias is None or dt_bias.stride(-1) == 0))
+    with torch.cuda.device(x.device.index):
+        _selective_scan_update_kernel[grid](
+            state, x, dt, dt_bias, A, B, C, D, z, out,
+            batch, nheads, dim, dstate, nheads // ngroups,
+            state.stride(0), state.stride(1), state.stride(2), state.stride(3),
+            x.stride(0), x.stride(1), x.stride(2),
+            dt.stride(0), dt.stride(1), dt.stride(2),
+            *((dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else (0, 0)),
+            A.stride(0), A.stride(1), A.stride(2),
+            B.stride(0), B.stride(1), B.stride(2),
+            C.stride(0), C.stride(1), C.stride(2),
+            *((D.stride(0), D.stride(1)) if D is not None else (0, 0)),
+            z_strides[0], z_strides[1], z_strides[2],
+            out.stride(0), out.stride(1), out.stride(2),
+            dt_softplus,
+            tie_hdim,
+            BLOCK_SIZE_M,
+            num_warps=num_warps,
+        )
+    if not has_heads:
+        out = out.squeeze(1)
+    return out
+
+
+def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
+    """
+    Argument:
+        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
+        x: (batch, dim) or (batch, nheads, dim)
+        dt: (batch, dim) or (batch, nheads, dim)
+        A: (dim, dstate) or (nheads, dim, dstate)
+        B: (batch, dstate) or (batch, ngroups, dstate)
+        C: (batch, dstate) or (batch, ngroups, dstate)
+        D: (dim,) or (nheads, dim)
+        z: (batch, dim) or (batch, nheads, dim)
+        dt_bias: (dim,) or (nheads, dim)
+    Return:
+        out: (batch, dim) or (batch, nheads, dim)
+    """
+    has_heads = state.dim() > 3
+    if state.dim() == 3:
+        state = state.unsqueeze(1)
+    if x.dim() == 2:
+        x = x.unsqueeze(1)
+    if dt.dim() == 2:
+        dt = dt.unsqueeze(1)
+    if A.dim() == 2:
+        A = A.unsqueeze(0)
+    if B.dim() == 2:
+        B = B.unsqueeze(1)
+    if C.dim() == 2:
+        C = C.unsqueeze(1)
+    if D is not None and D.dim() == 1:
+        D = D.unsqueeze(0)
+    if z is not None and z.dim() == 2:
+        z = z.unsqueeze(1)
+    if dt_bias is not None and dt_bias.dim() == 1:
+        dt_bias = dt_bias.unsqueeze(0)
+    batch, nheads, dim, dstate = state.shape
+    assert x.shape == (batch, nheads, dim)
+    assert dt.shape == x.shape
+    assert A.shape == (nheads, dim, dstate)
+    ngroups = B.shape[1]
+    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
+    assert B.shape == (batch, ngroups, dstate)
+    assert C.shape == B.shape
+    if D is not None:
+        assert D.shape == (nheads, dim)
+    if z is not None:
+        assert z.shape == x.shape
+    if dt_bias is not None:
+        assert dt_bias.shape == (nheads, dim)
+        dt = dt + dt_bias
+    dt = F.softplus(dt) if dt_softplus else dt
+    dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * A)  # (batch, nheads, dim, dstate)
+    B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups)  # (batch, nheads, dstate)
+    C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups)  # (batch, nheads, dstate)
+    dB = rearrange(dt, "b h d -> b h d 1") * rearrange(B, "b h n -> b h 1 n")  # (batch, nheads, dim, dstate)
+    state.copy_(state * dA + dB * rearrange(x, "b h d -> b h d 1"))  # (batch, nheads, dim, dstate)
+    out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
+    if D is not None:
+        out += (x * D).to(out.dtype)
+    out = (out if z is None else out * F.silu(z)).to(x.dtype)
+    if not has_heads:
+        out = out.squeeze(1)
+    return out
diff --git a/src/transformers/models/zamba2/ops/triton/ssd_bmm.py b/src/transformers/models/zamba2/ops/triton/ssd_bmm.py
new file mode 100755
index 0000000000..48fd4f063e
--- /dev/null
+++ b/src/transformers/models/zamba2/ops/triton/ssd_bmm.py @@ -0,0 +1,262 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. + +"""We want triton==2.1.0 or 2.2.0 for this +""" + +import math +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + +from einops import rearrange, repeat + + +def init_to_zero(names): + return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), + ], + key=['chunk_size', 'K', 'IS_CAUSAL'], +) +@triton.jit +def _bmm_chunk_fwd_kernel( + # Pointers to matrices + a_ptr, b_ptr, out_ptr, seq_idx_ptr, + # Matrix dimensions + seqlen, chunk_size, K, ngroups, + stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak, + stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk, + stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn, + stride_seq_idx_batch, stride_seq_idx_seqlen, + # Meta-parameters + IS_CAUSAL: tl.constexpr, + dot_dtype: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_ch = tl.program_id(axis=2) + pid_c = pid_ch // ngroups + pid_h = pid_ch - pid_c * ngroups + num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + if IS_CAUSAL: + if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: + return + a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0).to(dot_dtype) + b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0).to(dot_dtype) + acc += tl.dot(a, b) + a_ptrs += 
BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + if HAS_SEQ_IDX: + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2) + acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) + out = acc.to(out_ptr.dtype.element_ty) + + out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head + out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn) + tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size)) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2), + ], + key=['chunk_size', 'K'], +) +@triton.jit +def _bmm_chunk_bwd_kernel( + # Pointers to matrices + a_ptr, dout_ptr, db_ptr, res_ptr, + # Matrix dimensions + seqlen, chunk_size, K, ngroups, + stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak, + stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n, + stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k, + stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k, + # Meta-parameters + dot_dtype: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_ch = tl.program_id(axis=2) + pid_c = pid_ch // ngroups + pid_h = pid_ch - pid_c * ngroups + num_pid_n = tl.cdiv(K, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + + a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head + dout_ptr += pid_b * stride_dout_batch + pid_c * stride_dout_chunk + pid_h * stride_dout_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_cs = tl.arange(0, BLOCK_SIZE_CS) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize_n + offs_cs[None, :] * stride_dout_csize_m) + a_ptrs = a_ptr + (offs_cs[:, None] * stride_a_seqlen + offs_n[None, :] * stride_ak) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for cs in range(0, tl.cdiv(chunk_size_limit, BLOCK_SIZE_CS)): + dout = 
tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_cs[None, :] < chunk_size_limit - cs * BLOCK_SIZE_CS), other=0.0).to(dot_dtype) + a = tl.load(a_ptrs, mask=(offs_cs[:, None] < chunk_size_limit - cs * BLOCK_SIZE_CS) & (offs_n[None, :] < K), other=0.0).to(dot_dtype) + acc += tl.dot(dout, a) + dout_ptrs += BLOCK_SIZE_CS * stride_dout_csize_m + a_ptrs += BLOCK_SIZE_CS * stride_a_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + if HAS_RESIDUAL: + res_ptr += pid_b * stride_res_batch + pid_c * chunk_size * stride_res_seqlen + pid_h * stride_res_head + res_ptrs = res_ptr + (offs_m[:, None] * stride_res_seqlen + offs_n[None, :] * stride_res_k) + res = tl.load(res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)).to(tl.float32) + acc += res + db = acc.to(db_ptr.dtype.element_ty) + + db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_h * stride_db_head + db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_k) + tl.store(db_ptrs, db, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)) + + +def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None): + """ + Argument: + a: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + b: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out. + causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are + guaranteed to be correct. + Return: + out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size) + """ + # Check constraints. + has_groups = a.dim() == 4 + if not has_groups: + batch, seqlen, k = a.shape + else: + batch, seqlen, ngroups, k = a.shape + assert b.shape == a.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if a.stride(-1) != 1 and a.stride(1) != 1: + a = a.contiguous() + if b.stride(-1) != 1 and b.stride(1) != 1: + b = b.contiguous() + nchunks = math.ceil(seqlen / chunk_size) + # Allocates output. 
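+    # Comment added for clarity: dot_dtype below picks the input precision for tl.dot --
+    # bf16 if either operand is bf16, else fp16 if either operand is fp16, else fp32;
+    # the accumulator inside the kernel is fp32 regardless.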
+    out_dtype = a.dtype if output_dtype is None else output_dtype
+    out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size),
+                      device=a.device, dtype=out_dtype)
+    dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
+                 (tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32))
+    grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']),
+                         batch, nchunks if not has_groups else nchunks * ngroups)
+    with torch.cuda.device(a.device.index):
+        _bmm_chunk_fwd_kernel[grid](
+            a, b, out, seq_idx,
+            seqlen, chunk_size, k, ngroups if has_groups else 1,
+            a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),
+            b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1),
+            out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1),
+            *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
+            causal,
+            dot_dtype,
+            HAS_SEQ_IDX=seq_idx is not None,
+        )
+    return out
+
+
+def _bmm_chunk_bwd(a, dout, residual=None, out=None):
+    """
+    Computes the gradient with respect to b of the chunked product out_fwd = a @ b^T,
+    i.e. db = dout^T @ a within each chunk, optionally adding `residual` to the result.
+
+    Argument:
+        a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
+        dout: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
+        residual: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
+    Return:
+        out: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
+
+    If there was seq_idx in the fwd pass, then dout[i, j] for seq_idx[i] != seq_idx[j] should already be
+    zeroed out before calling this function.
+    """
+    # Check constraints.
+    has_groups = a.dim() == 4
+    if not has_groups:
+        batch, seqlen, k = a.shape
+    else:
+        batch, seqlen, ngroups, k = a.shape
+    nchunks, chunk_size = dout.shape[1], dout.shape[-1]
+    if a.stride(-1) != 1 and a.stride(-2) != 1:
+        a = a.contiguous()
+    if dout.stride(-1) != 1 and dout.stride(-2) != 1:
+        dout = dout.contiguous()
+    if residual is not None:
+        assert residual.shape == ((batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k))
+        if residual.stride(-1) != 1 and residual.stride(1) != 1:
+            residual = residual.contiguous()
+    # Allocates output.
+ if out is not None: + assert out.shape == a.shape + assert out.stride(-1) == 1 or out.stride(1) == 1 + else: + out = torch.empty_like(a) + dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or dout.dtype == torch.bfloat16 else + (tl.float16 if a.dtype == torch.float16 or dout.dtype == torch.float16 else tl.float32)) + grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(k, META['BLOCK_SIZE_N']), batch, + nchunks if not has_groups else nchunks * ngroups) + residual_strides = ((residual.stride(0), residual.stride(1), 0 if not has_groups else residual.stride(2), + residual.stride(-1)) + if residual is not None else (0, 0, 0, 0)) + with torch.cuda.device(a.device.index): + _bmm_chunk_bwd_kernel[grid]( + a, dout, out, residual, + seqlen, chunk_size, k, ngroups if has_groups else 1, + a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1), + dout.stride(0), dout.stride(1), 0 if not has_groups else dout.stride(2), dout.stride(-2), dout.stride(-1), + out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-1), + residual_strides[0], residual_strides[1], residual_strides[2], residual_strides[3], + dot_dtype, + HAS_RESIDUAL=residual is not None, + ) + return out diff --git a/src/transformers/models/zamba2/ops/triton/ssd_chunk_scan.py b/src/transformers/models/zamba2/ops/triton/ssd_chunk_scan.py new file mode 100755 index 0000000000..fc56031764 --- /dev/null +++ b/src/transformers/models/zamba2/ops/triton/ssd_chunk_scan.py @@ -0,0 +1,1825 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. + +"""We want triton==2.1.0 or 2.2.0 for this +""" + +import math +from packaging import version + +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + +from einops import rearrange, repeat + +from ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd + +TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') + + +def init_to_zero(names): + return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), + ], + key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'], +) +@triton.jit +def _chunk_scan_fwd_kernel( + # Pointers to matrices + cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, prev_states_ptr, D_ptr, + # Matrix dimensions + 
chunk_size, hdim, dstate, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim, + stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, + stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, + stride_D_head, + # Meta-parameters + IS_CAUSAL: tl.constexpr, + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_Z: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, + IS_TRITON_22: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head + prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + if HAS_SEQ_IDX: + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Without the if (pid_c > -1), with Triton 2.1.0, I get + # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed. 
+ # With Triton 2.2.0, this works + if IS_TRITON_22 or pid_c > -1: + # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate) + prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate) + if not HAS_SEQ_IDX: + scale_m = tl.exp(dA_cs_m) + else: + scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + if BLOCK_SIZE_DSTATE <= 128: + C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0) + prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc = tl.dot(C, prev_states) * scale_m[:, None] + else: + for k in range(0, dstate, BLOCK_SIZE_K): + C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate - k), other=0.0) + # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty) + prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc += tl.dot(C, prev_states) + C_ptrs += BLOCK_SIZE_K + prev_states_ptrs += BLOCK_SIZE_K + acc *= scale_m[:, None] + + offs_k = tl.arange(0, BLOCK_SIZE_K) + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) + x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + K_MAX = chunk_size_limit if not IS_CAUSAL else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) + for k in range(0, K_MAX, BLOCK_SIZE_K): + cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k), other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) + # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. + # So we don't need masking wrt seq_idx here. 
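+        # Note added for clarity: this loop realizes the intra-chunk part of the SSD scan,
+        # acc[m] += sum_k (C_m . B_k) * exp(dA_cumsum[m] - dA_cumsum[k]) * dt[k] * x[k]
+        # (restricted to k <= m when IS_CAUSAL); cb holds the precomputed C @ B^T scores
+        # for this chunk, and the exp/dt factors are applied below before tl.dot.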
+ cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :])) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) + cb *= dt_k + if IS_CAUSAL: + mask = offs_m[:, None] >= k + offs_k[None, :] + cb = tl.where(mask, cb, 0.0) + cb = cb.to(x_ptr.dtype.element_ty) + x = tl.load(x_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim), other=0.0) + acc += tl.dot(cb, x) + cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + if HAS_D: + if D_HAS_HDIM: + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim), + mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + acc += x_residual * D + + if HAS_Z: + out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :]) + tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) + + z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head + z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]) + z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32) + acc *= z * tl.sigmoid(z) + + out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim) + tl.store(out_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) + + +@triton.autotune( + configs=[ + # triton.Config({'BLOCK_SIZE_N': 256}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_SIZE_N': 128}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=8), + triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=8), + ], + key=['chunk_size', 'hdim', 'dstate'], +) +@triton.jit +def _chunk_scan_fwd_kernel_wip( + # Pointers to matrices + cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, B_ptr, prev_states_ptr, D_ptr, + # Matrix dimensions + chunk_size, hdim, dstate, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim, + stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, + stride_B_batch, stride_B_seqlen, stride_B_head, stride_B_dstate, + 
stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, + stride_D_head, + # Meta-parameters + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_Z: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + pid_n = tl.program_id(axis=0) + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head + B_ptr += pid_b * stride_B_batch + pid_c * chunk_size * stride_B_seqlen + (pid_h // nheads_ngroups_ratio) * stride_B_head + prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + + offs_m = tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE) + + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate) + B_ptrs = B_ptr + (offs_m[None, :] * stride_B_seqlen + offs_k_dstate[:, None] * stride_B_dstate) + prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_m[None, :] * stride_cb_csize_k) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + dt_ptrs = dt_ptr + offs_m * stride_dt_csize + out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim) + + prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + # if pid_c == 0: + # if pid_b == 0: + # if pid_h == 0: + # tl.device_print("", prev_states) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + # dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + # scale_m = tl.exp(dA_cs_m) + # C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0) + # acc = tl.dot(C, prev_states.to(C_ptr.dtype.element_ty)) * scale_m[:, None] + # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_m[None, :] < chunk_size), other=0.0).to(tl.float32) + # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_m[None, :])) + # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + # cb *= dt_m + # mask = offs_m[:, None] >= offs_m[None, :] + # cb = tl.where(mask, cb, 0.0) + # cb = cb.to(x_ptr.dtype.element_ty) + # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0) + # acc += tl.dot(cb, x) + # if HAS_D: + # if D_HAS_HDIM: + # D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, 
mask=offs_n < hdim, other=0.0).to(tl.float32) + # else: + # D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + # acc += x.to(tl.float32) * D + # tl.store(out_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + + for start_m in range(0, chunk_size_limit, BLOCK_SIZE_M): + start_m = tl.multiple_of(start_m, BLOCK_SIZE_M) + dA_cs_m = tl.load(dA_cumsum_ptr + (start_m + offs_m) * stride_dA_cs_csize, mask=offs_m < chunk_size - start_m, other=0.0).to(tl.float32) + if HAS_SEQ_IDX: + seq_idx_prev = tl.load(seq_idx_ptr + start_m - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) + seq_idx_m = tl.load(seq_idx_ptr + (start_m + offs_m) * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit - start_m, other=-1) + if not HAS_SEQ_IDX: + scale_m = tl.exp(dA_cs_m) + else: + scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_k_dstate[None, :] < dstate), other=0.0) + acc = tl.dot(C, prev_states.to(C_ptr.dtype.element_ty)) * scale_m[:, None] + # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size - start_m) & (offs_m[None, :] < chunk_size - start_m), other=0.0).to(tl.float32) + # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_m[None, :])) + dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size - start_m, other=0.0).to(tl.float32) + # cb *= dt_m + # mask = offs_m[:, None] >= offs_m[None, :] + # cb = tl.where(mask, cb, 0.0) + # cb = cb.to(x_ptr.dtype.element_ty) + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_n[None, :] < hdim), other=0.0) + # acc += tl.dot(cb, x) + + if HAS_D: + if D_HAS_HDIM: + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + acc += x.to(tl.float32) * D + + # if HAS_Z: + # out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + # out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :]) + # tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) + + # z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head + # z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]) + # z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32) + # acc *= z * tl.sigmoid(z) + + tl.store(out_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_n[None, :] < hdim)) + + # TODO: this is not correct, and quite a bit slower + if start_m + BLOCK_SIZE_M < chunk_size_limit: + # B = tl.load(B_ptrs, mask=(offs_m[None, :] < chunk_size_limit - start_m) & (offs_k_dstate[:, None] < dstate), other=0.0).to(tl.float32) + B = tl.load(B_ptrs, mask=(offs_m[None, :] < chunk_size_limit - start_m) & (offs_k_dstate[:, None] < dstate), other=0.0) + dA_cs_last = tl.load(dA_cumsum_ptr + (start_m + BLOCK_SIZE_M) * stride_dA_cs_csize).to(tl.float32) + # TODO: seq_idx + scale = tl.exp((dA_cs_last - dA_cs_m)) * dt_m + # B *= scale + B = B.to(x_ptr.dtype.element_ty) + tmp = tl.dot(B, x) + prev_states += tmp.to(prev_states.dtype) + + C_ptrs += BLOCK_SIZE_M * stride_C_seqlen + B_ptrs += BLOCK_SIZE_M * stride_B_seqlen + cb_ptrs += BLOCK_SIZE_M * stride_cb_csize_m + BLOCK_SIZE_M * stride_cb_csize_k + x_ptrs += BLOCK_SIZE_M * stride_x_seqlen + dt_ptrs += BLOCK_SIZE_M * stride_dt_csize + 
out_ptrs += BLOCK_SIZE_M * stride_out_seqlen + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32}), + triton.Config({'BLOCK_SIZE_M': 64}), + triton.Config({'BLOCK_SIZE_M': 128}), + triton.Config({'BLOCK_SIZE_M': 256}), + ], + key=["chunk_size", "hdim"], +) +@triton.jit +def _chunk_scan_bwd_dz_kernel( + # Pointers to matrices + dout_ptr, out_ptr, z_ptr, x_ptr, D_ptr, outz_ptr, dz_ptr, dout_x_ptr, dD_ptr, ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, hdim, + batch, seqlen, + # Strides + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, + stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim, + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_D_head, + stride_outz_batch, stride_outz_seqlen, stride_outz_head, stride_outz_hdim, + stride_dz_batch, stride_dz_seqlen, stride_dz_head, stride_dz_hdim, + stride_doutx_batch, stride_doutx_seqlen, stride_doutx_head, stride_doutx_hdim, + stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + # Meta-parameters + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_DDACS: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + pid_m = tl.program_id(axis=0) + + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + dout_x_ptr += pid_b * stride_doutx_batch + pid_c * chunk_size * stride_doutx_seqlen + pid_h * stride_doutx_head + out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head + dz_ptr += pid_b * stride_dz_batch + pid_c * chunk_size * stride_dz_seqlen + pid_h * stride_dz_head + if RECOMPUTE_OUTPUT: + outz_ptr += pid_b * stride_outz_batch + pid_c * chunk_size * stride_outz_seqlen + pid_h * stride_outz_head + if HAS_DDACS: + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + if HAS_D: + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_N) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) + dout_x_ptrs = dout_x_ptr + (offs_m[:, None] * stride_doutx_seqlen + offs_n[None, :] * stride_doutx_hdim) + out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim) + z_ptrs = z_ptr + (offs_m[:, None] * stride_z_seqlen + offs_n[None, :] * stride_z_hdim) + dz_ptrs = dz_ptr + (offs_m[:, None] * stride_dz_seqlen + offs_n[None, :] * stride_dz_hdim) + if RECOMPUTE_OUTPUT: + outz_ptrs = outz_ptr + (offs_m[:, None] * stride_outz_seqlen + offs_n[None, :] * stride_outz_hdim) + if HAS_D: + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + if D_HAS_HDIM: + dD_ptrs = dD_ptr + offs_n * stride_dD_hdim + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] 
< chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + out = tl.load(out_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + z = tl.load(z_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + z_sigmoid = tl.sigmoid(z) + if RECOMPUTE_OUTPUT: + outz = out * z * z_sigmoid + tl.store(outz_ptrs, outz, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + dz = dout * out * z_sigmoid * (1 + z * (1 - z_sigmoid)) + tl.store(dz_ptrs, dz, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + dout *= z * z_sigmoid + tl.store(dout_x_ptrs, dout, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + if HAS_D: + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + if D_HAS_HDIM: + dD = tl.sum(dout * x, axis=0) + tl.store(dD_ptrs, dD, mask=offs_n < hdim) + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + else: + dD = tl.sum(dout * x) + tl.store(dD_ptr, dD) + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + out -= x * D + if HAS_DDACS: + ddA_cs = tl.sum(dout * out, axis=1) + tl.store(ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), + ], + key=['hdim', 'dstate', 'chunk_size'], +) +@triton.jit +def _chunk_scan_bwd_dstates_kernel( + # Pointers to matrices + dout_ptr, c_ptr, dprev_states_ptr, dA_cumsum_ptr, seq_idx_ptr, + # Matrix dimensions + hdim, dstate, chunk_size, + batch, seqlen, nchunks, nheads_ngroups_ratio, + # Strides + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_c_batch, stride_c_seqlen, stride_c_head, stride_c_dstate, + stride_dprev_states_batch, stride_dprev_states_chunk, stride_dprev_states_head, stride_dprev_states_hdim, stride_dprev_states_dstate, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + # Meta-parameters + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + c_ptr += pid_b * stride_c_batch + pid_c * chunk_size * stride_c_seqlen + (pid_h // 
nheads_ngroups_ratio) * stride_c_head + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_hdim + offs_k[None, :] * stride_dout_seqlen) + c_ptrs = c_ptr + (offs_n[None, :] * stride_c_dstate + offs_k[:, None] * stride_c_seqlen) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + if HAS_SEQ_IDX: + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) + if not HAS_SEQ_IDX: + scale_k = tl.exp(dA_cs_k) + else: + seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1) + scale_k = tl.where(seq_idx_k == seq_idx_prev, tl.exp(dA_cs_k), 0.0) + dout = (dout * scale_k).to(dout_ptr.dtype.element_ty) + c = tl.load(c_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0) + acc += tl.dot(dout, c) + dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen + c_ptrs += BLOCK_SIZE_K * stride_c_seqlen + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen + out = acc.to(dprev_states_ptr.dtype.element_ty) + + dprev_states_ptr += pid_b * stride_dprev_states_batch + pid_c * stride_dprev_states_chunk + pid_h * stride_dprev_states_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dprev_states_ptrs = dprev_states_ptr + (offs_m[:, None] * stride_dprev_states_hdim + offs_n[None, :] * stride_dprev_states_dstate) + tl.store(dprev_states_ptrs, out, mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, 
pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + ], + key=['chunk_size', 'dstate', 'hdim'], +) +@triton.jit +def _chunk_scan_bwd_dc_kernel( + # Pointers to matrices + dout_ptr, prev_states_ptr, C_ptr, dA_cumsum_ptr, seq_idx_ptr, + dc_ptr, ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, dstate, hdim, + batch, seqlen, nheads, nheads_per_program, ngroups, + # Strides + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_prev_states_batch, stride_prev_states_chunk, stride_prev_states_head, stride_prev_states_hdim, stride_prev_states_dstate, + stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_dc_batch, stride_dc_seqlen, stride_dc_split, stride_dc_group, stride_dc_dstate, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + # Meta-parameters + HAS_DDA_CS: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_sg = tl.program_id(axis=2) + pid_s = pid_sg // ngroups + pid_g = pid_sg - pid_s * ngroups + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dout_head + dc_ptr += pid_b * stride_dc_batch + pid_c * chunk_size * stride_dc_seqlen + pid_g * stride_dc_group + pid_s * stride_dc_split + prev_states_ptr += pid_b * stride_prev_states_batch + pid_c * stride_prev_states_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_prev_states_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head + if HAS_DDA_CS: + C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + pid_g * stride_C_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) + prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_prev_states_dstate + offs_k[:, None] * stride_prev_states_hdim) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize + if HAS_DDA_CS: + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_n[None, :] * stride_C_dstate) + ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + if HAS_DDA_CS: + c = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) + if HAS_SEQ_IDX: + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + 
nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program) + for h in range(nheads_iter): + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + prev_states = tl.load(prev_states_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0) + prev_states = prev_states.to(dout_ptrs.dtype.element_ty) + dc = tl.dot(dout, prev_states) + dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + if not HAS_SEQ_IDX: + scale = tl.exp(dA_cs_m) + else: + scale = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + dc *= scale[:, None] + if HAS_DDA_CS: + ddA_cs = tl.sum(dc * c, axis=1) + tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) + acc += dc + dout_ptrs += stride_dout_head + prev_states_ptrs += stride_prev_states_head + dA_cumsum_ptrs += stride_dA_cs_head + if HAS_DDA_CS: + ddA_cumsum_ptrs += stride_ddA_cs_head + # if HAS_SEQ_IDX: + # seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) + # seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + # acc = tl.where(seq_idx_m[:, None] == seq_idx_prev, acc, 0.0) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dc_ptrs = dc_ptr + (offs_m[:, None] * stride_dc_seqlen + offs_n[None, :] * stride_dc_dstate) + tl.store(dc_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate)) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + ], + key=['chunk_size', 'hdim'], +) +@triton.jit +def _chunk_scan_bwd_dx_kernel( + # Pointers to matrices + x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, D_ptr, + dx_ptr, ddt_ptr, # dD_ptr, + # Matrix dimensions + chunk_size, hdim, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, 
stride_dA_cs_csize, + stride_D_head, + stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim, + stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, + # stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_hdim, stride_dD_csize, + # Meta-parameters + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + # if HAS_D: + # dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) + dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + # Unclear why limiting K_MAX gives wrong results; possibly a Triton bug.
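+    # Editor's sketch (not part of the original source) of what the K-loop below
+    # accumulates, with cb as addressed through the strides it is handed:
+    #     acc[m, :] = sum_{k >= m} cb[m, k] * exp(dA_cs[k] - dA_cs[m]) * dout[k, :]
+    # after which dx[m] = dt[m] * acc[m] (plus D * dout[m] when HAS_D) and
+    # ddt[m] = sum_n acc[m, n] * x[m, n]; this is the adjoint of the causal
+    # intra-chunk scan, hence the k >= m mask in place of the forward pass's k <= m.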
+ # K_MAX = min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) + K_MAX = chunk_size_limit + for k in range(0, K_MAX, BLOCK_SIZE_K): + # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower + cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0) + dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32) + cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None]) + mask = k + offs_k[None, :] >= offs_m[:, None] + cb = tl.where(mask, cb, 0.0) + cb = cb.to(dout_ptr.dtype.element_ty) + acc += tl.dot(cb, dout) + cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k + dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dt_ptrs = dt_ptr + offs_m * stride_dt_csize + dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + dx = acc * dt_m[:, None] + dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head + dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim) + if HAS_D: + dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) + dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + if D_HAS_HDIM: + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + dx += dout_res * D + tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + ddt = tl.sum(acc * x, axis=1) + ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize + tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) + + # if HAS_D: + # dout_new_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize + offs_n[None, :] * stride_dout_hdim) + # dout = tl.load(dout_new_ptrs, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), other=0.0).to(tl.float32) + # dD = tl.sum(x * dout, axis=0) + # tl.store(dD_ptr + offs_n * stride_dD_hdim, dD, mask=offs_n < N) + + +# Disabling HAS_DDA_CS for now since it's much slower +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 16}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 32}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 64}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 128}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 16}, num_stages=4, num_warps=8), + # triton.Config({'BLOCK_SIZE_M': 32}, 
num_stages=4, num_warps=8), + # triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=8), + # triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=8), + ], + key=['chunk_size', 'hdim'], +) +# @triton.heuristics({"BLOCK_SIZE_N": lambda args: max(triton.next_power_of_2(args["chunk_size"]), 16)}) +# @triton.heuristics({"BLOCK_SIZE_N": lambda args: 32}) +@triton.jit +def _chunk_scan_bwd_dcb_kernel( + # Pointers to matrices + x_ptr, dout_ptr, cb_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, + dcb_ptr, ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, hdim, + batch, seqlen, nheads, nheads_per_program, ngroups, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_dcb_batch, stride_dcb_chunk, stride_dcb_split, stride_dcb_group, stride_dcb_csize_m, stride_dcb_csize_n, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize_m, stride_ddA_cs_csize_n, + # Meta-parameters + HAS_DDA_CS: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_sg = tl.program_id(axis=2) + pid_s = pid_sg // ngroups + pid_g = pid_sg - pid_s * ngroups + num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dout_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head + if HAS_DDA_CS: + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + pid_g * stride_cb_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + pid_m * stride_ddA_cs_csize_m + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) + x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) + dt_ptrs = dt_ptr + offs_n * stride_dt_csize + if HAS_DDA_CS: + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) + ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_n * stride_ddA_cs_csize_n + + if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: + dcb_ptr += pid_b * stride_dcb_batch + pid_c * stride_dcb_chunk + pid_g * stride_dcb_group + pid_s * stride_dcb_split + 
dcb_ptrs = dcb_ptr + (offs_m[:, None] * stride_dcb_csize_m + offs_n[None, :] * stride_dcb_csize_n) + tl.store(dcb_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=dcb_ptr.dtype.element_ty), mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size)) + return + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + if HAS_DDA_CS: + cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), other=0.0).to(tl.float32) + nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program) + for h in range(nheads_iter): + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0) + dcb = tl.dot(dout, x) + dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32) + dcb *= dt_n + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size_limit, other=0.0).to(tl.float32) + dcb *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) + if HAS_DDA_CS: + tl.static_assert(not HAS_SEQ_IDX, "HAS_SEQ_IDX not supported with HAS_DDA_CS yet") + ddA_cs = dcb * cb + mask = offs_m[:, None] >= offs_n[None, :] + 1 + ddA_cs = tl.where(mask, ddA_cs, 0.0) + ddA_cs = tl.cumsum(ddA_cs, axis=1) + ddA_cs = tl.where(mask, ddA_cs, 0.0) + ddA_cs = tl.sum(ddA_cs, axis=0) + tl.store(ddA_cumsum_ptrs + stride_ddA_cs_csize_n, ddA_cs, mask=offs_n < chunk_size - 1) + tl.store(ddA_cumsum_ptr, 0.0) + acc += dcb + dout_ptrs += stride_dout_head + x_ptrs += stride_x_head + dt_ptrs += stride_dt_head + dA_cumsum_ptr += stride_dA_cs_head + if HAS_DDA_CS: + ddA_cumsum_ptr += stride_ddA_cs_head + ddA_cumsum_ptrs += stride_ddA_cs_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + if HAS_SEQ_IDX: + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2) + acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) + mask = offs_m[:, None] >= offs_n[None, :] + acc = tl.where(mask, acc, 0.0) + dcb_ptr += pid_b * stride_dcb_batch + pid_c * stride_dcb_chunk + pid_g * stride_dcb_group + pid_s * stride_dcb_split + dcb_ptrs = dcb_ptr + (offs_m[:, None] * stride_dcb_csize_m + offs_n[None, :] * stride_dcb_csize_n) + tl.store(dcb_ptrs, acc, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size)) + + +# Not numerically stable and should not be used. Leaving here for reference. 
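+# Editor's summary (added comment): this reference-only variant reconstructs
+#     ddA_cs[m] = sum_n dout[m, n] * (out[m, n] - x[m, n] * D)   # D term only when HAS_D
+# (minus dt[m] * ddt[m] when SUBTRACT_DDTDT) directly from the forward output,
+# which presumably loses precision to cancellation between large terms; the
+# "stable" kernels further below recompute masked cumulative sums from
+# x, dout, cb and dA_cumsum instead.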
+@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32}), + triton.Config({'BLOCK_SIZE_M': 64}), + triton.Config({'BLOCK_SIZE_M': 128}), + triton.Config({'BLOCK_SIZE_M': 256}), + ], + key=["chunk_size", "hdim"], +) +@triton.jit +def _chunk_scan_bwd_ddAcs_unstable_kernel( + # Pointers to matrices + dout_ptr, out_ptr, dt_ptr, ddt_ptr, x_ptr, D_ptr, + ddA_cumsum_ptr, dD_ptr, + # Matrix dimensions + chunk_size, hdim, + batch, seqlen, + # Strides + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_D_head, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim, + # Meta-parameters + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + SUBTRACT_DDTDT: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + pid_m = tl.program_id(axis=0) + + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + if HAS_D: + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_N) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) + out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim) + if HAS_D: + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + if D_HAS_HDIM: + dD_ptrs = dD_ptr + offs_n * stride_dD_hdim + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + out = tl.load(out_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + if HAS_D: + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + if D_HAS_HDIM: + dD = tl.sum(dout * x, axis=0) + tl.store(dD_ptrs, dD, mask=offs_n < hdim) + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + else: + dD = tl.sum(dout * x) + tl.store(dD_ptr, dD) + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + out -= x * D + ddA_cs = tl.sum(dout * out, axis=1) + if SUBTRACT_DDTDT: + dt = tl.load(dt_ptr + offs_m * stride_dt_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + ddt = tl.load(ddt_ptr + offs_m * stride_ddt_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + ddA_cs -= 
dt * ddt + tl.store(ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size) + + +@triton.autotune( + configs=[ + # triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), + # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), + # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), + # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 16}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 16}, num_stages=4, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 32}, num_stages=4, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=8), + ], + key=['chunk_size', 'hdim'], +) +@triton.jit +def _chunk_scan_bwd_ddAcs_stable_kernel_old( + # Pointers to matrices + x_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, cb_ptr, + ddAcs_ptr, + # Matrix dimensions + chunk_size, hdim, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n, + stride_ddAcs_batch, stride_ddAcs_chunk, stride_ddAcs_head, stride_ddAcs_csize_m, stride_ddAcs_csize_n, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) + x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) + dt_ptrs = dt_ptr + offs_n * stride_dt_csize + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) + + chunk_size_limit = min(chunk_size, 
seqlen - pid_c * chunk_size) + chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M) + # Doing a matmul loop with cumsum later on will cause Triton to crash + # Instead we do just one big matmul + # acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + # for k in range(0, hdim, BLOCK_SIZE_K): + # dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim - k), other=0.0) + # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim - k) & (offs_n[None, :] < chunk_size_limit), other=0.0) + # acc += tl.dot(dout, x) + # dout_ptrs += BLOCK_SIZE_K * stride_dout_hdim + # x_ptrs += BLOCK_SIZE_K * stride_x_hdim + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0) + acc = tl.dot(dout, x) + cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), other=0.0).to(tl.float32) + acc *= cb + dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32) + acc *= dt_n + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size, other=0.0).to(tl.float32) + acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) + mask = offs_m[:, None] >= offs_n[None, :] + 1 + acc = tl.where(mask, acc, 0.0) + acc = tl.cumsum(acc, axis=1) + acc = tl.where(mask, acc, 0.0) + ddA_cs = tl.sum(acc, axis=0) + ddAcs_ptr += pid_b * stride_ddAcs_batch + pid_c * stride_ddAcs_chunk + pid_h * stride_ddAcs_head + pid_m * stride_ddAcs_csize_m + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ddAcs_ptrs = ddAcs_ptr + offs_n * stride_ddAcs_csize_n + tl.store(ddAcs_ptrs + stride_ddAcs_csize_n, ddA_cs, mask=offs_n < chunk_size - 1) + tl.store(ddAcs_ptr, 0.0) + + # offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, 64) + # offs_k = tl.arange(0, BLOCK_SIZE_K) + # dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) + # x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) + # dt_ptrs = dt_ptr + offs_n * stride_dt_csize + # cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) + + # chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + # chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M) + # rowsum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + # dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + # dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + # ddAcs_ptr += pid_b * stride_ddAcs_batch + pid_c * stride_ddAcs_chunk + pid_h * stride_ddAcs_head + pid_m * stride_ddAcs_csize_m + # ddAcs_ptrs = ddAcs_ptr + offs_n * stride_ddAcs_csize_n + # for n in range(0, chunk_size_limit_n, 64): + # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n - n), other=0.0) + # acc = tl.dot(dout, x) + # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - n), other=0.0).to(tl.float32) + # acc *= cb + # dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size - n, other=0.0).to(tl.float32) + # acc *= dt_n + # dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size - n, other=0.0).to(tl.float32) 
+ # acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) + # mask = offs_m[:, None] >= offs_n[None, :] + 1 + n + # acc = tl.where(mask, acc, 0.0) + # acc = tl.cumsum(acc, axis=1) + # acc = tl.where(mask, acc, 0.0) + # ddA_cs = tl.sum(acc, axis=0) + # tl.store(ddAcs_ptrs, ddA_cs, mask=offs_n < chunk_size - 1 - n) + # # tl.store(ddAcs_ptr, 0.0) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), + ], + key=['chunk_size', 'hdim'], +) +@triton.jit +def _chunk_scan_bwd_ddAcs_stable_kernel( + # Pointers to matrices + x_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, cb_ptr, + ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, hdim, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize_m, stride_ddA_cs_csize_n, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + pid_m = tl.program_id(axis=0) + + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + pid_m * stride_ddA_cs_csize_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) + x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) + dt_ptrs = dt_ptr + offs_n * stride_dt_csize + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) + ddAcs_ptrs = ddA_cumsum_ptr + offs_n * stride_ddA_cs_csize_n + tl.store(ddA_cumsum_ptr, 0.0) + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + rowsum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + dout = 
tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + # Actually hi is (pid_m + 1) * BLOCK_SIZE_M - 1 but subtracting 1 makes it slower + lo, hi = 0, (pid_m + 1) * BLOCK_SIZE_M + # lo, hi = 0, chunk_size + for start_n in range(lo, hi, BLOCK_SIZE_N): + start_n = tl.multiple_of(start_n, BLOCK_SIZE_N) + # Doing a matmul loop with cumsum later on will cause Triton to crash + # Instead we do just one big matmul + # acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + # for k in range(0, hdim, BLOCK_SIZE_K): + # dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim - k), other=0.0) + # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim - k) & (offs_n[None, :] < chunk_size_limit), other=0.0) + # acc += tl.dot(dout, x) + # dout_ptrs += BLOCK_SIZE_K * stride_dout_hdim + # x_ptrs += BLOCK_SIZE_K * stride_x_hdim + # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0) + x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit - start_n), other=0.0) + acc = tl.dot(dout, x) + dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32) + acc *= dt_n + # If there's seq_idx, we already zero'ed out cb[i, j] for seq_idx[i] != seq_idx[j] + cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - start_n), other=0.0).to(tl.float32) + acc *= cb + dA_cs_n = tl.load(dA_cumsum_ptr + start_n + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32) + acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) + mask = offs_m[:, None] >= start_n + offs_n[None, :] + 1 + acc = tl.where(mask, acc, 0.0) + rowsum_new = rowsum + tl.sum(acc, axis=1) + acc = rowsum[:, None] + tl.cumsum(acc, axis=1) + rowsum = rowsum_new + acc = tl.where(mask, acc, 0.0) + ddA_cs = tl.sum(acc, axis=0) + tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, ddA_cs, mask=offs_n < chunk_size - start_n - 1) + x_ptrs += BLOCK_SIZE_N * stride_x_seqlen + dt_ptrs += BLOCK_SIZE_N * stride_dt_csize + cb_ptrs += BLOCK_SIZE_N * stride_cb_csize_n + ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n + + # Need to zero out the rest, since we'll be summing the rows together + for start_n in range(hi, chunk_size, BLOCK_SIZE_N): + tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32), mask=offs_n < chunk_size - start_n - 1) + ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + ], + key=['chunk_size', 'dstate', 'hdim'], +) +@triton.jit +def 
_chunk_scan_bwd_ddAcs_prev_kernel( + # Pointers to matrices + dout_ptr, prev_states_ptr, C_ptr, dA_cumsum_ptr, seq_idx_ptr, + ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, dstate, hdim, + batch, seqlen, nchunks, nheads_ngroups_ratio, + # Strides + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_prev_states_batch, stride_prev_states_chunk, stride_prev_states_head, stride_prev_states_hdim, stride_prev_states_dstate, + stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + # Meta-parameters + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + prev_states_ptr += pid_b * stride_prev_states_batch + pid_c * stride_prev_states_chunk + pid_h * stride_prev_states_head + C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) + prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_prev_states_dstate + offs_k[:, None] * stride_prev_states_hdim) + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_n[None, :] * stride_C_dstate) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + prev_states = tl.load(prev_states_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0) + prev_states = prev_states.to(dout_ptrs.dtype.element_ty) + acc = tl.dot(dout, prev_states) + c = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) + ddA_cs = tl.sum(acc * c, axis=1) + dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + if not HAS_SEQ_IDX: + scale = tl.exp(dA_cs_m) + if HAS_SEQ_IDX: + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + scale = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + ddA_cs *= scale + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize + tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) + + +def 
_chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=None): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = C.shape + assert nheads % ngroups == 0 + assert C.shape == (batch, seqlen, ngroups, dstate) + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert states.shape == (batch, nchunks, nheads, headdim, dstate) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + # Allocates output. + out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) + if z is not None: + out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) + assert out_x.stride() == out.stride() + else: + out_x = None + grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3)) + if z is not None else (0, 0, 0, 0)) + _chunk_scan_fwd_kernel[grid]( + cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, states, D, + chunk_size, headdim, dstate, + batch, seqlen, nheads // ngroups, + cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4), + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + z_strides[0], z_strides[1], z_strides[2], z_strides[3], + out.stride(0), out.stride(1), out.stride(2), out.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + C.stride(0), C.stride(1), C.stride(2), C.stride(3), + states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4), + D.stride(0) if D is not None else 0, + True, + D is not None, + D.dim() == 2 if D is not None else True, + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + HAS_Z=z is not None, + HAS_SEQ_IDX=seq_idx is not None, + IS_TRITON_22=TRITON_22, + ) + return out, out_x + + +def _chunk_scan_fwd_wip(cb, x, dt, dA_cumsum, C, B, states, D=None, z=None, seq_idx=None): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = C.shape + assert nheads % ngroups == 0 + assert C.shape == (batch, seqlen, ngroups, dstate) + assert B.shape == C.shape + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert states.shape == (batch, nchunks, nheads, headdim, dstate) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + # Allocates output. 
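+    # Editor's note: out_x appears to hold the pre-gating copy of the scan output
+    # (consumed by the z backward) and is only materialized when z is given; the
+    # stride assert below matters because the kernel addresses out_x through
+    # out's strides, which are the only output strides it receives.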
+ out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) + if z is not None: + out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) + assert out_x.stride() == out.stride() + else: + out_x = None + grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads) + z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3)) + if z is not None else (0, 0, 0, 0)) + _chunk_scan_fwd_kernel_wip[grid]( + cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, B, states, D, + chunk_size, headdim, dstate, + batch, seqlen, nheads // ngroups, + cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4), + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + z_strides[0], z_strides[1], z_strides[2], z_strides[3], + out.stride(0), out.stride(1), out.stride(2), out.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + C.stride(0), C.stride(1), C.stride(2), C.stride(3), + B.stride(0), B.stride(1), B.stride(2), B.stride(3), + states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4), + D.stride(0) if D is not None else 0, + D is not None, + D.dim() == 2 if D is not None else True, + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + BLOCK_SIZE_M=128, + HAS_Z=z is not None, + HAS_SEQ_IDX=seq_idx is not None, + ) + return out, out_x + + +def _chunk_scan_bwd_dz(x, z, out, dout, chunk_size, has_ddAcs=True, D=None, dz=None, recompute_output=False): + batch, seqlen, nheads, headdim = x.shape + assert z.shape == x.shape + assert out.shape == x.shape + assert dout.shape == out.shape + nchunks = math.ceil(seqlen / chunk_size) + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + assert D.stride(-1) == 1 + if has_ddAcs: + ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32) + if D is not None: + BLOCK_SIZE_min = 32 + dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads, + headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32) + else: + dD = None + if dz is not None: + assert dz.shape == z.shape + else: + dz = torch.empty_like(z) + if recompute_output: + outz = torch.empty_like(x) + dout_x = torch.empty_like(dout) + dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4)) + if D is not None else (0, 0, 0, 0, 0)) + grid_dz = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_scan_bwd_dz_kernel[grid_dz]( + dout, out, z, x, D, outz if recompute_output else None, + dz, dout_x, dD, ddA_cumsum if has_ddAcs else None, + chunk_size, headdim, + batch, seqlen, + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + out.stride(0), out.stride(1), out.stride(2), out.stride(3), + z.stride(0), z.stride(1), z.stride(2), z.stride(3), + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + D.stride(0) if D is not None else 0, + *((outz.stride(0), outz.stride(1), outz.stride(2), outz.stride(3)) if recompute_output else (0, 0, 0, 0)), + dz.stride(0), dz.stride(1), dz.stride(2), dz.stride(3), + dout_x.stride(0), dout_x.stride(1), dout_x.stride(2), dout_x.stride(3), + dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4], + *((ddA_cumsum.stride(0), 
ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3)) + if has_ddAcs else (0, 0, 0, 0)), + D is not None, + D.dim() == 2 if D is not None else True, + has_ddAcs, + BLOCK_SIZE_N=max(triton.next_power_of_2(headdim), 16), + RECOMPUTE_OUTPUT=recompute_output, + ) + if D is not None: + BLOCK_SIZE_actual = _chunk_scan_bwd_dz_kernel.best_config.kwargs["BLOCK_SIZE_M"] + n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual + dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) + if D.dim() == 1: + dD = rearrange(dD, "h 1 -> h") + return_vals = (dz, dout_x, dD, ddA_cumsum) if has_ddAcs else (dz, dout_x, dD) + return return_vals if not recompute_output else (*return_vals, outz) + + +def _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=None, dtype=None): + batch, seqlen, nheads, headdim = dout.shape + _, _, nchunks, chunk_size = dA_cumsum.shape + _, _, ngroups, dstate = C.shape + assert nheads % ngroups == 0 + assert C.shape == (batch, seqlen, ngroups, dstate) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + dtype = C.dtype if dtype is None else dtype + dprev_states = torch.empty(batch, nchunks, nheads, headdim, dstate, device=C.device, dtype=dtype) + grid_dstates = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + with torch.cuda.device(C.device.index): + _chunk_scan_bwd_dstates_kernel[grid_dstates]( + dout, C, dprev_states, dA_cumsum, seq_idx, + headdim, dstate, chunk_size, + batch, seqlen, nchunks, nheads // ngroups, + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + C.stride(0), C.stride(1), C.stride(2), C.stride(3), + dprev_states.stride(0), dprev_states.stride(1), dprev_states.stride(2), dprev_states.stride(3), dprev_states.stride(4), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + HAS_SEQ_IDX=seq_idx is not None, + ) + return dprev_states + + +def _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, seq_idx=None, C=None, ngroups=1): + batch, nchunks, nheads, headdim, dstate = prev_states.shape + _, seqlen, _, _ = dout.shape + _, _, _, chunk_size = dA_cumsum.shape + assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert dout.shape == (batch, seqlen, nheads, headdim) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if C is not None: + assert C.shape == (batch, seqlen, ngroups, dstate) + C_strides = (C.stride(0), C.stride(1), C.stride(2), C.stride(3)) + ddA_cumsum_prev = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) + ddA_cumsum_prev_strides = (ddA_cumsum_prev.stride(0), ddA_cumsum_prev.stride(2), ddA_cumsum_prev.stride(1), ddA_cumsum_prev.stride(3)) + else: + C_strides = (0, 0, 0, 0) + ddA_cumsum_prev = None + ddA_cumsum_prev_strides = (0, 0, 0, 0) + nheads_ngroups_ratio = nheads // ngroups + sm_count = torch.cuda.get_device_properties(dout.device).multi_processor_count + nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1) + nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program) + dC = torch.empty(batch, seqlen, nsplits, ngroups, dstate, device=dout.device, dtype=torch.float32) + grid_dc = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) 
* triton.cdiv(dstate, META['BLOCK_SIZE_N']), + batch * nchunks, nsplits * ngroups) + with torch.cuda.device(dout.device.index): + _chunk_scan_bwd_dc_kernel[grid_dc]( + dout, prev_states, C, dA_cumsum, seq_idx, dC, ddA_cumsum_prev, + chunk_size, dstate, headdim, + batch, seqlen, nheads, nheads_per_program, ngroups, + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + prev_states.stride(0), prev_states.stride(1), prev_states.stride(2), prev_states.stride(3), prev_states.stride(4), + *C_strides, + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + dC.stride(0), dC.stride(1), dC.stride(2), dC.stride(3), dC.stride(4), + *ddA_cumsum_prev_strides, + HAS_DDA_CS=ddA_cumsum_prev is not None, + HAS_SEQ_IDX=seq_idx is not None, + BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), + ) + dC = dC.sum(2) + return dC if C is None else (dC, ddA_cumsum_prev) + + +def _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=None, CB=None, ngroups=1): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert dout.shape == x.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if CB is not None: + assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + CB_strides = (CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(3), CB.stride(4)) + BLOCK_SIZE_M_min = 16 + ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min), + chunk_size, device=x.device, dtype=torch.float32) + ddA_cumsum_strides = (ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4)) + else: + CB_strides = (0, 0, 0, 0, 0) + ddA_cumsum = None + ddA_cumsum_strides = (0, 0, 0, 0, 0) + nheads_ngroups_ratio = nheads // ngroups + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1) + nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program) + dcb = torch.empty(batch, nchunks, nsplits, ngroups, chunk_size, chunk_size, device=x.device, dtype=torch.float32) + grid_dcb = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']), + batch * nchunks, nsplits * ngroups) + with torch.cuda.device(x.device.index): + _chunk_scan_bwd_dcb_kernel[grid_dcb]( + x, dout, CB, dt, dA_cumsum, seq_idx, dcb, ddA_cumsum, + chunk_size, headdim, + batch, seqlen, nheads, nheads_per_program, ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + *CB_strides, + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + dcb.stride(0), dcb.stride(1), dcb.stride(2), dcb.stride(3), dcb.stride(4), dcb.stride(5), + *ddA_cumsum_strides, + HAS_DDA_CS=ddA_cumsum is not None, + HAS_SEQ_IDX=seq_idx is not None, + BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), + ) + dcb = dcb.sum(2) + if ddA_cumsum is not None: + BLOCK_SIZE_M_actual = _chunk_scan_bwd_dcb_kernel.best_config.kwargs["BLOCK_SIZE_M"] + n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual + ddA_cumsum = 
ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3) + return dcb if CB is None else (dcb, ddA_cumsum) + + +def _chunk_scan_bwd_dx(cb, x, dt, dA_cumsum, dout, D=None): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + ngroups = cb.shape[2] + assert nheads % ngroups == 0 + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert dout.shape == x.shape + # if D is not None: + # BLOCK_SIZE_M_min = 32 + # dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_M_min), batch, nchunks, nheads, headdim, device=D.device, dtype=torch.float32) + # else: + # dD = None + dx = torch.empty_like(x) + ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) + grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_scan_bwd_dx_kernel[grid_dx]( + x, cb, dout, dt, dA_cumsum, D, dx, ddt, # dD, + chunk_size, headdim, + batch, seqlen, nheads // ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(-1), cb.stride(-2), + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + D.stride(0) if D is not None else 0, + dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3), + ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), + # dD.stride(1) if dD is not None else 0, dD.stride(2) if dD is not None else 0, dD.stride(3) if dD is not None else 0, dD.stride(4) if dD is not None else 0, dD.stride(0) if dD is not None else 0, + D is not None, + D.dim() == 2 if D is not None else True, + ) + # if D is not None: + # BLOCK_SIZE_actual = _chunk_scan_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"] + # n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual + # dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) + return dx, ddt.to(dtype=dt.dtype) + + +def _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=None, subtract_ddtdt=True): + """Not numerically stable and should not be used. Leaving here for reference. 
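+    (The numerically stable alternative is _chunk_scan_bwd_ddAcs_stable below.)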
+ """ + + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert ddt.shape == dt.shape + assert out.shape == x.shape + assert dout.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + ddA_cumsum = torch.empty_like(dt) + grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads) + if D is not None: # Triton gives wrong results if we write to the same location + BLOCK_SIZE_min = 32 + dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads, + headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32) + else: + dD = None + dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4)) + if D is not None else (0, 0, 0, 0, 0)) + with torch.cuda.device(x.device.index): + _chunk_scan_bwd_ddAcs_unstable_kernel[grid_ddtcs]( + dout, out, dt, ddt, x, D, ddA_cumsum, dD, + chunk_size, headdim, + batch, seqlen, + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + out.stride(0), out.stride(1), out.stride(2), out.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + D.stride(0) if D is not None else 0, + ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), + dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4], + D is not None, + D.dim() == 2 if D is not None else True, + subtract_ddtdt, + BLOCK_SIZE_N=max(triton.next_power_of_2(headdim), 16), + ) + if D is not None: + BLOCK_SIZE_actual = _chunk_scan_bwd_ddAcs_unstable_kernel.best_config.kwargs["BLOCK_SIZE_M"] + n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual + dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) + if D.dim() == 1: + dD = rearrange(dD, "h 1 -> h") + return ddA_cumsum, dD + + +def _chunk_scan_bwd_ddAcs_stable_old(x, dt, dA_cumsum, dout, cb): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dout.shape == x.shape + assert dA_cumsum.shape == dt.shape + ngroups = cb.shape[2] + assert nheads % ngroups == 0 + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + BLOCK_SIZE_M_min = 16 + ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min), + chunk_size, device=x.device, dtype=torch.float32) + grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_scan_bwd_ddAcs_stable_kernel_old[grid_ddtcs]( + x, dout, dt, dA_cumsum, cb, ddA_cumsum, + chunk_size, headdim, + batch, seqlen, nheads // ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4), + ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4), + BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), + BLOCK_SIZE_N=max(triton.next_power_of_2(chunk_size), 16), + ) + BLOCK_SIZE_M_actual = 
_chunk_scan_bwd_ddAcs_stable_kernel_old.best_config.kwargs["BLOCK_SIZE_M"] + n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual + ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3) + return ddA_cumsum + + +def _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, cb): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dout.shape == x.shape + assert dA_cumsum.shape == dt.shape + ngroups = cb.shape[2] + assert nheads % ngroups == 0 + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + BLOCK_SIZE_M_min = 32 + ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min), + chunk_size, device=x.device, dtype=torch.float32) + grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_scan_bwd_ddAcs_stable_kernel[grid_ddtcs]( + x, dout, dt, dA_cumsum, cb, ddA_cumsum, + chunk_size, headdim, + batch, seqlen, nheads // ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4), + ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4), + BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), + ) + BLOCK_SIZE_M_actual = _chunk_scan_bwd_ddAcs_stable_kernel.best_config.kwargs["BLOCK_SIZE_M"] + n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual + ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3) + return ddA_cumsum + + +def _chunk_scan_bwd_ddAcs_prev(prev_states, C, dout, dA_cumsum, seq_idx=None): + batch, nchunks, nheads, headdim, dstate = prev_states.shape + _, seqlen, _, _ = dout.shape + _, _, _, chunk_size = dA_cumsum.shape + assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert dout.shape == (batch, seqlen, nheads, headdim) + ngroups = C.shape[2] + assert nheads % ngroups == 0 + assert C.shape == (batch, seqlen, ngroups, dstate) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + ddA_cumsum_prev = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) + grid_ddAcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + with torch.cuda.device(dout.device.index): + _chunk_scan_bwd_ddAcs_prev_kernel[grid_ddAcs]( + dout, prev_states, C, dA_cumsum, seq_idx, ddA_cumsum_prev, + chunk_size, dstate, headdim, + batch, seqlen, nchunks, nheads // ngroups, + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + prev_states.stride(0), prev_states.stride(1), prev_states.stride(2), prev_states.stride(3), prev_states.stride(4), + C.stride(0), C.stride(1), C.stride(2), C.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + ddA_cumsum_prev.stride(0), ddA_cumsum_prev.stride(2), ddA_cumsum_prev.stride(1), ddA_cumsum_prev.stride(3), + HAS_SEQ_IDX=seq_idx is not None, + BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), + ) 
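+    # Note: ddA_cumsum_prev carries only the gradient that reaches dA_cumsum through the
+    # inter-chunk term (C @ prev_states) * exp(dA_cumsum); callers are expected to add it
+    # to the intra-chunk ddA contributions.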
+    return ddA_cumsum_prev
+
+
+class ChunkScanFn(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, B, C, x, dt, dA_cumsum, prev_states, D=None, z=None):
+        # Check constraints.
+        batch, seqlen, nheads, headdim = x.shape
+        _, _, ngroups, dstate = B.shape
+        assert B.shape == (batch, seqlen, ngroups, dstate)
+        _, _, nchunks, chunk_size = dt.shape
+        assert seqlen == nchunks * chunk_size
+        assert C.shape == B.shape
+        if z is not None:
+            assert z.shape == x.shape
+        if D is not None:
+            assert D.shape == (nheads, headdim) or D.shape == (nheads,)
+        assert dt.shape == (batch, nheads, nchunks, chunk_size)
+        assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
+        assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate)
+        if B.stride(-1) != 1:
+            B = B.contiguous()
+        if C.stride(-1) != 1:
+            C = C.contiguous()
+        if x.stride(-1) != 1 and x.stride(1) != 1:  # Either M or K dimension should be contiguous
+            x = x.contiguous()
+        if z is not None and z.stride(-1) != 1 and z.stride(1) != 1:  # Either M or K dimension should be contiguous
+            z = z.contiguous()
+        if D is not None and D.stride(-1) != 1:
+            D = D.contiguous()
+        CB = _bmm_chunk_fwd(C, B, chunk_size)
+        out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, prev_states, D=D, z=z)
+        ctx.save_for_backward(out if z is None else out_x, B, C, CB, x, dt, dA_cumsum, prev_states, D, z)
+        return out
+
+    @staticmethod
+    def backward(ctx, dout):
+        if dout.stride(-1) != 1:
+            dout = dout.contiguous()
+        out, B, C, CB, x, dt, dA_cumsum, prev_states, D, z = ctx.saved_tensors
+        batch, seqlen, nheads, headdim = x.shape
+        _, _, nchunks, chunk_size = dt.shape
+        _, _, ngroups, dstate = B.shape
+        assert dout.shape == (batch, seqlen, nheads, headdim)
+        if z is not None:
+            dz, dout, dD, ddA_cumsum = _chunk_scan_bwd_dz(x, z, out, dout, chunk_size=chunk_size, D=D)
+        else:
+            dz = None
+        dprev_states = _chunk_scan_bwd_dstates(C, dA_cumsum, dout, dtype=prev_states.dtype)
+        dC = _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, ngroups=ngroups)
+        dC = dC.to(C.dtype)
+        dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, ngroups=ngroups)
+        dCB = dCB.to(CB.dtype)
+        dB = _bmm_chunk_bwd(C, dCB)
+        dC = _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC)
+        dx, ddt = _chunk_scan_bwd_dx(CB, x, dt, dA_cumsum, dout, D=D)
+        # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D.
+        # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt
+        if z is not None:  # ddA_cumsum and dD were already computed above as part of _chunk_scan_bwd_dz
+            ddA_cumsum -= ddt * dt
+        else:
+            ddA_cumsum, dD = _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=D)
+        ddA_cumsum = ddA_cumsum.to(dA_cumsum.dtype)
+        return dB, dC, dx, ddt, ddA_cumsum, dprev_states, dD, dz
+
+
+def chunk_scan(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None):
+    """
+    prev_states contains the initial_states at index 0, and the state for the next-to-last chunk at index -1.
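+    Per chunk, out combines the causally masked intra-chunk scan (C @ B^T weighted by the
+    dt / dA_cumsum decay) with C @ prev_states carried in from earlier chunks; chunk_scan_ref
+    below spells out the exact einsums.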
+ Argument: + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + prev_states: (batch, nchunks, nheads, headdim, dstate) + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + Return: + out: (batch, seqlen, nheads, headdim) + """ + return ChunkScanFn.apply(B, C, x, dt, dA_cumsum, prev_states, D, z) + + +def chunk_scan_ref(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None): + """ + Argument: + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + prev_states: (batch, nchunks, nheads, headdim, dstate) + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + Return: + out: (batch, seqlen, nheads, headdim) + """ + batch, seqlen, nheads, headdim = x.shape + _, _, ngroups, dstate = B.shape + assert B.shape == (batch, seqlen, ngroups, dstate) + _, _, nchunks, chunk_size = dt.shape + assert seqlen == nchunks * chunk_size + assert C.shape == B.shape + B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups) + C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups) + CB = torch.einsum("bclhn,bcshn->bchls", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), + rearrange(B, "b (c s) h n -> b c s h n", c=nchunks)) + # (batch, nheads, nchunks, chunksize, chunksize) + dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :] + decay = torch.exp(dt_segment_sum) + scores_decay = CB * rearrange(decay, "b h c l s -> b c h l s") + causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0) + scores_decay = scores_decay.masked_fill(~causal_mask, 0) + out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype), + rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)) + state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1")) + out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), + prev_states.to(C.dtype)) * state_decay_out + out = out + out_prev + out = rearrange(out, "b c l h p -> b (c l) h p") + if D is not None: + if D.dim() == 1: + D = rearrange(D, "h -> h 1") + out = out + x * D + return out if z is None else out * F.silu(z) diff --git a/src/transformers/models/zamba2/ops/triton/ssd_chunk_state.py b/src/transformers/models/zamba2/ops/triton/ssd_chunk_state.py new file mode 100755 index 0000000000..4333e6a748 --- /dev/null +++ b/src/transformers/models/zamba2/ops/triton/ssd_chunk_state.py @@ -0,0 +1,866 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. 
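+# Chunk-state kernels of the Mamba2 SSD algorithm: for each chunk they accumulate
+# states[h, p, n] = sum over positions l of x_l[h, p] * B_l[g(h), n] * dt_l[h] * exp(dA_cs_last - dA_cs_l),
+# together with the corresponding backward passes for x, B, dt and dA_cumsum.
+# chunk_state_ref at the bottom of this file is a pure-PyTorch reference for the fused forward.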
+ +"""We want triton==2.1.0 or 2.2.0 for this +""" + +import math +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + +from einops import rearrange, repeat + + +def init_to_zero(names): + return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_H': 1}), + triton.Config({'BLOCK_SIZE_H': 2}), + triton.Config({'BLOCK_SIZE_H': 4}), + triton.Config({'BLOCK_SIZE_H': 8}), + triton.Config({'BLOCK_SIZE_H': 16}), + triton.Config({'BLOCK_SIZE_H': 32}), + triton.Config({'BLOCK_SIZE_H': 64}), + ], + key=['chunk_size', 'nheads'], +) +@triton.jit +def _chunk_cumsum_fwd_kernel( + # Pointers to matrices + dt_ptr, A_ptr, dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr, + # Matrix dimension + batch, seqlen, nheads, chunk_size, + dt_min, dt_max, + # Strides + stride_dt_batch, stride_dt_seqlen, stride_dt_head, + stride_A_head, + stride_dt_bias_head, + stride_dt_out_batch, stride_dt_out_chunk, stride_dt_out_head, stride_dt_out_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr, +): + pid_b = tl.program_id(axis=0) + pid_c = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen + dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + + offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) + offs_c = tl.arange(0, BLOCK_SIZE_CHUNK) + dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen) + A_ptrs = A_ptr + offs_h * stride_A_head + dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize) + dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32) + dt += dt_bias[:, None] + if DT_SOFTPLUS: + dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) + # As of Triton 2.2.0, tl.clamp is not available yet + # dt = tl.clamp(dt, dt_min, dt_max) + dt = tl.minimum(tl.maximum(dt, dt_min), dt_max) + dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0) + tl.store(dt_out_ptrs, dt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) + dA = dt * A[:, None] + dA_cs = tl.cumsum(dA, axis=1) + tl.store(dA_cs_ptrs, dA_cs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_H': 1}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), + triton.Config({'BLOCK_SIZE_H': 2}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), + triton.Config({'BLOCK_SIZE_H': 4}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), + triton.Config({'BLOCK_SIZE_H': 8}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), + triton.Config({'BLOCK_SIZE_H': 16}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), + triton.Config({'BLOCK_SIZE_H': 32}, 
pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), + triton.Config({'BLOCK_SIZE_H': 64}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), + ], + key=['chunk_size', 'nheads'], +) +@triton.jit +def _chunk_cumsum_bwd_kernel( + # Pointers to matrices + ddA_ptr, ddt_out_ptr, dt_ptr, A_ptr, dt_bias_ptr, + ddt_ptr, dA_ptr, ddt_bias_ptr, + # Matrix dimensions + batch, seqlen, nheads, chunk_size, + dt_min, dt_max, + # Strides + stride_ddA_batch, stride_ddA_chunk, stride_ddA_head, stride_ddA_csize, + stride_ddt_out_batch, stride_ddt_out_chunk, stride_ddt_out_head, stride_ddt_out_csize, + stride_dt_batch, stride_dt_seqlen, stride_dt_head, + stride_A_head, + stride_dt_bias_head, + stride_ddt_batch, stride_ddt_seqlen, stride_ddt_head, + stride_dA_head, + stride_ddt_bias_head, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr, +): + pid_b = tl.program_id(axis=0) + pid_c = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk + ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk + dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen + ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen + + offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) + offs_c = tl.arange(0, BLOCK_SIZE_CHUNK) + ddt_out_ptrs = ddt_out_ptr + (offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize) + ddA_ptrs = ddA_ptr + (offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize) + dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen) + ddt_ptrs = ddt_ptr + (offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen) + A_ptrs = A_ptr + offs_h * stride_A_head + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + ddA = tl.load(ddA_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32) + ddt_out = tl.load(ddt_out_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32) + A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) + ddt = ddA * A[:, None] + ddt_out + dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32) + dt += dt_bias[:, None] + if DT_SOFTPLUS: + dt_presoftplus = dt + dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), ddt) + clamp_mask = (dt < dt_min) | (dt > dt_max) + # As of Triton 2.2.0, tl.clamp is not available yet + # dt = tl.clamp(dt, dt_min, dt_max) + dt = tl.minimum(tl.maximum(dt, dt_min), dt_max) + dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0) + ddt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0) + ddt = tl.where(clamp_mask, 0.0, ddt) + if DT_SOFTPLUS: + ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt) + tl.store(ddt_ptrs, ddt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit)) + dA = tl.sum(ddA * dt, axis=1) + tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads) + if HAS_DT_BIAS: + ddt_bias = tl.sum(ddt, axis=1) + tl.atomic_add(ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads) + + +@triton.autotune( + 
configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), + ], + key=['hdim', 'dstate', 'chunk_size'], +) +@triton.jit +def _chunk_state_fwd_kernel( + # Pointers to matrices + x_ptr, b_ptr, states_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, + # Matrix dimensions + hdim, dstate, chunk_size, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, + stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + # Meta-parameters + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen) + b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + if HAS_SEQ_IDX: + seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + x = tl.load(x_ptrs, 
mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0) + b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) + if HAS_SEQ_IDX: + seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) + if not HAS_SEQ_IDX: + scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k + else: + scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0) + b *= scale[:, None] + b = b.to(x_ptr.dtype.element_ty) + acc += tl.dot(x, b) + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + b_ptrs += BLOCK_SIZE_K * stride_b_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen + states = acc.to(states_ptr.dtype.element_ty) + + states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate) + c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) + tl.store(states_ptrs, states, mask=c_mask) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), + ], + key=['chunk_size', 'hdim', 'dstate'], +) +@triton.jit +def _chunk_state_bwd_dx_kernel( + # Pointers to matrices + x_ptr, b_ptr, dstates_ptr, dt_ptr, dA_cumsum_ptr, + dx_ptr, ddt_ptr, ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, hdim, dstate, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, + stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, + stride_dt_batch, 
stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim, + stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head + dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + offs_k = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate) + dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate) + if BLOCK_SIZE_DSTATE <= 128: + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate), other=0.0) + dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + dstates = dstates.to(b_ptr.dtype.element_ty) + acc = tl.dot(b, dstates) + else: + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, dstate, BLOCK_SIZE_K): + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate - k), other=0.0) + dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) + dstates = dstates.to(b_ptr.dtype.element_ty) + acc += tl.dot(b, dstates) + b_ptrs += BLOCK_SIZE_K * stride_b_dstate + dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + dt_ptrs = dt_ptr + offs_m * stride_dt_csize + dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize + dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None] + + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & 
(offs_n[None, :] < hdim), other=0.0).to(tl.float32) + ddt = tl.sum(acc * x, axis=1) + ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize + tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) + ddA_cs = -(ddt * dt_m) + ddA_cs_last = -tl.sum(ddA_cs) + ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize + tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) + tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last) + + dx = (acc * dt_m[:, None]).to(dx_ptr.dtype.element_ty) + dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head + dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim) + tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + ], + key=['chunk_size', 'dstate', 'hdim'], +) +@triton.jit +def _chunk_state_bwd_db_kernel( + # Pointers to matrices + x_ptr, dstates_ptr, b_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, + db_ptr, ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, dstate, hdim, + batch, seqlen, nheads, nheads_per_program, ngroups, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, + stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_db_batch, stride_db_seqlen, stride_db_split, stride_db_group, stride_db_dstate, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + # Meta-parameters + HAS_DDA_CS: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_sg = tl.program_id(axis=2) + pid_s = pid_sg // ngroups + pid_g = pid_sg - pid_s * ngroups + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head + db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_g * 
stride_db_group + pid_s * stride_db_split + dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_states_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head + if HAS_DDA_CS: + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_g * stride_b_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_k[None, :] * stride_x_hdim) + dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_dstate + offs_k[:, None] * stride_states_hdim) + dt_ptrs = dt_ptr + offs_m * stride_dt_csize + dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize + if HAS_DDA_CS: + b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_n[None, :] * stride_b_dstate) + ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + if HAS_DDA_CS: + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) + if HAS_SEQ_IDX: + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program) + for h in range(nheads_iter): + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0) + dstates = dstates.to(x_ptrs.dtype.element_ty) + db = tl.dot(x, dstates) + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + if not HAS_SEQ_IDX: + scale = tl.exp(dA_cs_last - dA_cs_m) + else: + scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) + db *= (scale * dt_m)[:, None] + if HAS_DDA_CS: + # This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. 
the exclusive reverse cumsum + ddA_cs = tl.sum(db * b, axis=1) + tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1) + acc += db + x_ptrs += stride_x_head + dstates_ptrs += stride_states_head + dt_ptrs += stride_dt_head + dA_cumsum_ptr += stride_dA_cs_head + dA_cumsum_ptrs += stride_dA_cs_head + if HAS_DDA_CS: + ddA_cumsum_ptrs += stride_ddA_cs_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + # if HAS_SEQ_IDX: + # seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + # seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + # acc = tl.where(seq_idx_m[:, None] == seq_idx_last, acc, 0.0) + db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_dstate) + tl.store(db_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate)) + + +@triton.autotune( + configs=[ + # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + ], + key=['chunk_size', 'hdim', 'dstate'], +) +@triton.jit +def _chunk_state_bwd_ddAcs_stable_kernel( + # Pointers to 
matrices + x_ptr, b_ptr, dstates_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, + ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, hdim, dstate, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, + stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + # Meta-parameters + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head + dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + offs_k = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate) + dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate) + if BLOCK_SIZE_DSTATE <= 128: + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate), other=0.0) + dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + dstates = dstates.to(b_ptr.dtype.element_ty) + acc = tl.dot(b, dstates) + else: + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, dstate, BLOCK_SIZE_K): + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate - k), other=0.0) + dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) + dstates = dstates.to(b_ptr.dtype.element_ty) + acc += tl.dot(b, dstates) + b_ptrs += BLOCK_SIZE_K * stride_b_dstate + dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size 
- 1) * stride_dA_cs_csize).to(tl.float32) + if not HAS_SEQ_IDX: + scale = tl.exp(dA_cs_last - dA_cs_m) + else: + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) + acc *= scale[:, None] + + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + dt_ptrs = dt_ptr + offs_m * stride_dt_csize + dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + ddt = tl.sum(acc * x, axis=1) + # ddA_cs = -(ddt * dt_m) + # Triton 2.2.0 errors if we have the cumsum here, so we just write it out + # then call torch.cumsum outside this kernel. + # ddA_cs = tl.cumsum(ddt * dt_m) + ddA_cs = ddt * dt_m + ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize + # tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) + tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1) + + +def _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): + batch, seqlen, nheads = dt.shape + assert A.shape == (nheads,) + if dt_bias is not None: + assert dt_bias.shape == (nheads,) + nchunks = math.ceil(seqlen / chunk_size) + dt_out = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) + dA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) + grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H'])) + with torch.cuda.device(dt.device.index): + _chunk_cumsum_fwd_kernel[grid_chunk_cs]( + dt, A, dt_bias, dt_out, dA_cumsum, + batch, seqlen, nheads, chunk_size, + dt_limit[0], dt_limit[1], + dt.stride(0), dt.stride(1), dt.stride(2), + A.stride(0), + dt_bias.stride(0) if dt_bias is not None else 0, + dt_out.stride(0), dt_out.stride(2), dt_out.stride(1), dt_out.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + dt_softplus, + HAS_DT_BIAS=dt_bias is not None, + BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), + ) + return dA_cumsum, dt_out + + +def _chunk_cumsum_bwd(ddA, ddt_out, dt, A, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")), ddt=None): + batch, seqlen, nheads = dt.shape + _, _, nchunks, chunk_size = ddA.shape + assert ddA.shape == (batch, nheads, nchunks, chunk_size) + assert ddt_out.shape == (batch, nheads, nchunks, chunk_size) + assert A.shape == (nheads,) + if dt_bias is not None: + assert dt_bias.shape == (nheads,) + ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32) + else: + ddt_bias = None + if ddt is not None: + assert ddt.shape == dt.shape + else: + ddt = torch.empty_like(dt) + dA = torch.empty_like(A, dtype=torch.float32) + grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H'])) + with torch.cuda.device(dt.device.index): + _chunk_cumsum_bwd_kernel[grid_chunk_cs]( + ddA, ddt_out, dt, A, dt_bias, ddt, dA, ddt_bias, + batch, seqlen, nheads, chunk_size, + dt_limit[0], dt_limit[1], + ddA.stride(0), ddA.stride(2), ddA.stride(1), ddA.stride(3), + ddt_out.stride(0), ddt_out.stride(2), ddt_out.stride(1), ddt_out.stride(3), + dt.stride(0), dt.stride(1), dt.stride(2), + A.stride(0), + dt_bias.stride(0) if dt_bias is not None else 
0, + ddt.stride(0), ddt.stride(1), ddt.stride(2), + dA.stride(0), + ddt_bias.stride(0) if ddt_bias is not None else 0, + dt_softplus, + HAS_DT_BIAS=dt_bias is not None, + BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), + ) + return ddt, dA, ddt_bias + + +def _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if states is not None: + assert states.shape == (batch, nchunks, nheads, headdim, dstate) + else: + states_dtype = torch.float32 if states_in_fp32 else B.dtype + states = torch.empty((batch, nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype) + grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_state_fwd_kernel[grid]( + x, B, states, dt, dA_cumsum, seq_idx, + headdim, dstate, chunk_size, + batch, seqlen, nheads // ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + B.stride(0), B.stride(1), B.stride(2), B.stride(-1), + states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + HAS_SEQ_IDX=seq_idx is not None, + ) + return states + + +def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) + if dx is not None: + assert dx.shape == x.shape + else: + dx = torch.empty_like(x) + ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) + ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dA_cumsum.device, dtype=torch.float32) + grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_state_bwd_dx_kernel[grid_dx]( + x, B, dstates, dt, dA_cumsum, dx, ddt, ddA_cumsum, + chunk_size, headdim, dstate, + batch, seqlen, nheads // ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + B.stride(0), B.stride(1), B.stride(2), B.stride(-1), + dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3), + ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), + ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + ) + return dx, ddt.to(dt.dtype), ddA_cumsum.to(dA_cumsum.dtype) + + +def 
_chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + dstate = dstates.shape[-1] + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if B is not None: + assert B.shape == (batch, seqlen, ngroups, dstate) + B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(3)) + # Use torch.empty since the Triton kernel will call init_to_zero + ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32) + ddA_cumsum_strides = (ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3)) + else: + B_strides = (0, 0, 0, 0) + ddA_cumsum = None + ddA_cumsum_strides = (0, 0, 0, 0) + nheads_ngroups_ratio = nheads // ngroups + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1) + nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program) + dB = torch.empty(batch, seqlen, nsplits, ngroups, dstate, device=x.device, dtype=torch.float32) + grid_db = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), + batch * nchunks, nsplits * ngroups) + with torch.cuda.device(x.device.index): + _chunk_state_bwd_db_kernel[grid_db]( + x, dstates, B, dt, dA_cumsum, seq_idx, dB, ddA_cumsum, + chunk_size, dstate, headdim, + batch, seqlen, nheads, nheads_per_program, ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4), + *B_strides, + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + dB.stride(0), dB.stride(1), dB.stride(2), dB.stride(3), dB.stride(4), + *ddA_cumsum_strides, + HAS_DDA_CS=ddA_cumsum is not None, + HAS_SEQ_IDX=seq_idx is not None, + BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), + ) + dB = dB.sum(2) + if ddA_cumsum is not None: + # The first element of ddA_cumsum is always zero, since that dA_cumsum does not contribute + # to the state of the chunk. + # torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:]) + # But it's easier to just do the cumsum for all elements, the result will be the same. 
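+        # Example with chunk_size 4: the kernel wrote [0, t0, t1, t2] (each term shifted one
+        # slot right, slot 0 left at zero), so a full inclusive cumsum gives
+        # [0, t0, t0 + t1, t0 + t1 + t2], i.e. the exclusive cumsum of [t0, t1, t2, t3].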
+        torch.cumsum(ddA_cumsum, dim=-1, out=ddA_cumsum)
+    return dB if B is None else (dB, ddA_cumsum)
+
+
+def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None):
+    batch, seqlen, nheads, headdim = x.shape
+    _, _, nchunks, chunk_size = dt.shape
+    _, _, ngroups, dstate = B.shape
+    assert nheads % ngroups == 0
+    assert B.shape == (batch, seqlen, ngroups, dstate)
+    assert dt.shape == (batch, nheads, nchunks, chunk_size)
+    assert dA_cumsum.shape == dt.shape
+    assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
+    if seq_idx is not None:
+        assert seq_idx.shape == (batch, seqlen)
+    # Use torch.empty since the Triton kernel will call init_to_zero
+    ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32)
+    grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),
+                               batch * nchunks, nheads)
+    with torch.cuda.device(x.device.index):
+        _chunk_state_bwd_ddAcs_stable_kernel[grid_ddtcs](
+            x, B, dstates, dt, dA_cumsum, seq_idx, ddA_cumsum,
+            chunk_size, headdim, dstate,
+            batch, seqlen, nheads // ngroups,
+            x.stride(0), x.stride(1), x.stride(2), x.stride(3),
+            B.stride(0), B.stride(1), B.stride(2), B.stride(-1),
+            dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),
+            dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
+            dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
+            *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
+            ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3),
+            HAS_SEQ_IDX=seq_idx is not None,
+            BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16),
+            BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
+        )
+    torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
+    return ddA_cumsum
+
+
+class ChunkStateFn(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, B, x, dt, dA_cumsum, states_in_fp32=True):
+        batch, seqlen, nheads, headdim = x.shape
+        _, _, nchunks, chunk_size = dt.shape
+        assert seqlen <= nchunks * chunk_size
+        _, _, ngroups, dstate = B.shape
+        assert B.shape == (batch, seqlen, ngroups, dstate)
+        assert dt.shape == (batch, nheads, nchunks, chunk_size)
+        assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
+        if B.stride(-1) != 1:
+            B = B.contiguous()
+        if x.stride(-1) != 1 and x.stride(1) != 1:  # Either M or K dimension should be contiguous
+            x = x.contiguous()
+        states = _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=states_in_fp32)
+        ctx.save_for_backward(B, x, dt, dA_cumsum)
+        return states
+
+    @staticmethod
+    def backward(ctx, dstates):
+        B, x, dt, dA_cumsum = ctx.saved_tensors
+        batch, seqlen, nheads, headdim = x.shape
+        _, _, nchunks, chunk_size = dt.shape
+        _, _, ngroups, dstate = B.shape
+        assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
+        if dstates.stride(-1) != 1:
+            dstates = dstates.contiguous()
+        dx, ddt, ddA_cumsum = _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates)
+        dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, ngroups=ngroups)
+        dB = dB.to(B.dtype)
+        return dB, dx, ddt, ddA_cumsum, None
+
+
+def chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True):
+    """
+    Argument:
+        B: (batch, seqlen, ngroups, dstate)
+        x: (batch, seqlen, nheads, headdim)
+        dt: (batch, nheads, nchunks, chunk_size)
+        dA_cumsum: (batch, nheads, nchunks, chunk_size)
+    Return:
+        states: (batch, nchunks, nheads, headdim, dstate)
+    """
+    return ChunkStateFn.apply(B, x, dt, dA_cumsum, states_in_fp32)
+
+
+def chunk_state_ref(B, x, dt, dA_cumsum):
+    """
+    Argument:
+        B: (batch, seqlen, ngroups, dstate)
+        x: (batch, seqlen, nheads, headdim)
+        dt: (batch, nheads, nchunks, chunk_size)
+        dA_cumsum: (batch, nheads, nchunks, chunk_size)
+    Return:
+        states: (batch, nchunks, nheads, headdim, dstate)
+    """
+    # Check constraints.
+    batch, seqlen, nheads, headdim = x.shape
+    dstate = B.shape[-1]
+    _, _, nchunks, chunk_size = dt.shape
+    assert seqlen <= nchunks * chunk_size
+    assert x.shape == (batch, seqlen, nheads, headdim)
+    assert dt.shape == (batch, nheads, nchunks, chunk_size)
+    ngroups = B.shape[2]
+    assert nheads % ngroups == 0
+    assert B.shape == (batch, seqlen, ngroups, dstate)
+    B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
+    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
+    if seqlen < nchunks * chunk_size:
+        x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
+        B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
+    x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size)
+    B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size)
+    decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum))
+    return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype), dt.to(x.dtype), x)
diff --git a/src/transformers/models/zamba2/ops/triton/ssd_combined.py b/src/transformers/models/zamba2/ops/triton/ssd_combined.py
new file mode 100755
index 0000000000..c93bf57bfa
--- /dev/null
+++ b/src/transformers/models/zamba2/ops/triton/ssd_combined.py
@@ -0,0 +1,959 @@
+# Copyright (c) 2024, Tri Dao, Albert Gu.
+
+"""We want triton==2.1.0 or 2.2.0 for this
+"""
+
+from typing import Optional
+
+import math
+from packaging import version
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+import triton
+import triton.language as tl
+
+from einops import rearrange, repeat
+
+try:
+    from causal_conv1d import causal_conv1d_fn
+    import causal_conv1d_cuda
+except ImportError:
+    causal_conv1d_fn, causal_conv1d_cuda = None, None
+
+from ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
+from ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd
+from ops.triton.ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db
+from ops.triton.ssd_chunk_state import _chunk_state_bwd_ddAcs_stable
+from ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref
+from ops.triton.ssd_state_passing import _state_passing_fwd, _state_passing_bwd
+from ops.triton.ssd_state_passing import state_passing, state_passing_ref
+from ops.triton.ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates
+from ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb
+from ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable
+from ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref
+from ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev
+from ops.triton.layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd
+from ops.triton.k_activations import _swiglu_fwd, _swiglu_bwd
+
+TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
+
+
+def init_to_zero(names):
+    return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8,
pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + ], + key=['chunk_size', 'hdim', 'dstate'], +) +@triton.jit +def _chunk_scan_chunk_state_bwd_dx_kernel( + # Pointers to matrices + x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, D_ptr, + b_ptr, dstates_ptr, + dx_ptr, ddt_ptr, dD_ptr, + # Matrix dimensions + chunk_size, hdim, dstate, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_D_head, + stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, + stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_hdim, stride_dstates_dstate, + stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim, + stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, + stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim, + # Meta-parameters + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, + IS_TRITON_22: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * 
stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head + dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_dstates_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + if not HAS_SEQ_IDX: + scale = tl.exp(dA_cs_last - dA_cs_m) + else: + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) + # Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + # However, we're getting error with the Triton compiler 2.1.0 for that code path: + # Unexpected mma -> mma layout conversion + # Triton 2.2.0 fixes this + offs_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate) + dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_dstates_hdim + offs_dstate[:, None] * stride_dstates_dstate) + if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128: + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate), other=0.0) + dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + dstates = dstates.to(b_ptr.dtype.element_ty) + acc = tl.dot(b, dstates) * scale[:, None] + else: + for k in range(0, dstate, BLOCK_SIZE_K): + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate - k), other=0.0) + dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) + dstates = dstates.to(b_ptr.dtype.element_ty) + acc += tl.dot(b, dstates) + b_ptrs += BLOCK_SIZE_K * stride_b_dstate + dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate + acc *= scale[:, None] + + # x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + # dt_ptrs = dt_ptr + offs_m * stride_dt_csize + # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + # ddt = tl.sum(acc * x, axis=1) * dt_m + # ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize + # tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) + + offs_k = tl.arange(0, BLOCK_SIZE_K) + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) + dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + K_MAX = chunk_size_limit + K_MIN = pid_m * BLOCK_SIZE_M + cb_ptrs += K_MIN * stride_cb_csize_k + dout_ptrs += K_MIN * stride_dout_seqlen + dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize + for k in range(K_MIN, K_MAX, BLOCK_SIZE_K): + k = 
tl.multiple_of(k, BLOCK_SIZE_K) + # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower + cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0) + dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32) + cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None]) + mask = k + offs_k[None, :] >= offs_m[:, None] + cb = tl.where(mask, cb, 0.0) + cb = cb.to(dout_ptr.dtype.element_ty) + acc += tl.dot(cb, dout) + cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k + dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dt_ptrs = dt_ptr + offs_m * stride_dt_csize + dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + dx = acc * dt_m[:, None] + dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head + dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim) + if HAS_D: + dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) + dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + if D_HAS_HDIM: + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + dx += dout_res * D + tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + if HAS_D: + dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize + if D_HAS_HDIM: + dD_ptrs = dD_ptr + offs_n * stride_dD_hdim + dD = tl.sum(dout_res * x, axis=0) + tl.store(dD_ptrs, dD, mask=offs_n < hdim) + else: + dD = tl.sum(dout_res * x) + tl.store(dD_ptr, dD) + ddt = tl.sum(acc * x, axis=1) + ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize + tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) + + +def _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert dout.shape == x.shape + assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + assert D.stride(-1) == 1 + BLOCK_SIZE_min = 32 + dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads, + headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32) + else: + dD = None + dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4)) + if D is not None else (0, 0, 0, 0, 0)) + if dx is None: + 
dx = torch.empty_like(x) + else: + assert dx.shape == x.shape + ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) + grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx]( + x, CB, dout, dt, dA_cumsum, seq_idx, D, B, dstates, dx, ddt, dD, + chunk_size, headdim, dstate, + batch, seqlen, nheads // ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(-1), CB.stride(-2), + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + D.stride(0) if D is not None else 0, + B.stride(0), B.stride(1), B.stride(2), B.stride(3), + dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4), + dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3), + ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), + dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4], + D is not None, + D.dim() == 2 if D is not None else True, + HAS_SEQ_IDX=seq_idx is not None, + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + IS_TRITON_22=TRITON_22 + ) + if D is not None: + BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"] + n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual + dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) + if D.dim() == 1: + dD = rearrange(dD, "h 1 -> h") + return dx, ddt.to(dtype=dt.dtype), dD + + +def _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): + batch, seqlen, nheads, headdim = x.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert x.shape == (batch, seqlen, nheads, headdim) + assert dt.shape == (batch, seqlen, nheads) + assert A.shape == (nheads,) + assert C.shape == B.shape + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous + x = x.contiguous() + if z is not None and z.stride(-1) != 1 and z.stride(1) != 1: # Either M or K dimension should be contiguous + z = z.contiguous() + if D is not None and D.stride(-1) != 1: + D = D.contiguous() + if initial_states is not None: + assert initial_states.shape == (batch, nheads, headdim, dstate) + # # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size) + # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus) + # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus) + # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus) + dA_cumsum, dt = 
_chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
+    states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
+    # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True)
+    # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True)
+    # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True)
+    states, final_states = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1],
+                                              initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None,
+                                              seq_idx=seq_idx, chunk_size=chunk_size, out_dtype=C.dtype)
+    states, final_states = [rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]]
+    # states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
+    # states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
+    CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
+    out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx)
+    return out, out_x, dt, dA_cumsum, states, final_states
+
+
+def _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size, D=None, z=None,
+                                   dt_bias=None, initial_states=None, dfinal_states=None, seq_idx=None, dt_softplus=False,
+                                   dt_limit=(0.0, float("inf")),
+                                   dx=None, ddt=None, dB=None, dC=None, dz=None, recompute_output=False):
+    if dout.stride(-1) != 1:
+        dout = dout.contiguous()
+    batch, seqlen, nheads, headdim = x.shape
+    nchunks = math.ceil(seqlen / chunk_size)
+    _, _, ngroups, dstate = B.shape
+    assert dout.shape == (batch, seqlen, nheads, headdim)
+    assert dt.shape == (batch, seqlen, nheads)
+    assert A.shape == (nheads,)
+    assert nheads % ngroups == 0
+    assert B.shape == (batch, seqlen, ngroups, dstate)
+    assert C.shape == B.shape
+    assert out.shape == x.shape
+    if initial_states is not None:
+        assert initial_states.shape == (batch, nheads, headdim, dstate)
+    if seq_idx is not None:
+        assert seq_idx.shape == (batch, seqlen)
+    if dx is not None:
+        assert dx.shape == x.shape
+    if dB is not None:
+        assert dB.shape == B.shape
+        dB_given = dB
+    else:
+        dB_given = torch.empty_like(B)
+    if dC is not None:
+        assert dC.shape == C.shape
+        dC_given = dC
+    else:
+        dC_given = torch.empty_like(C)
+    if dz is not None:
+        assert z is not None
+        assert dz.shape == z.shape
+    if ddt is not None:
+        assert ddt.shape == dt.shape
+        ddt_given = ddt
+    else:
+        ddt_given = torch.empty_like(dt)
+    # TD: For some reason Triton (2.1.0 and 2.2.0) errors with
+    # "[CUDA]: invalid device context" (e.g. during the varlen test), and cloning makes it work. It's unclear why.
+    dt_in = dt.clone()
+    dA_cumsum, dt = _chunk_cumsum_fwd(dt_in, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus,
+                                      dt_limit=dt_limit)
+    CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
+    states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
+    states, _ = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1],
+                                   initial_states=rearrange(initial_states, "... p n -> ... 
(p n)") if initial_states is not None else None, + seq_idx=seq_idx, chunk_size=chunk_size) + states = rearrange(states, "... (p n) -> ... p n", n=dstate) + if z is not None: + dz, dout, dD, *rest = _chunk_scan_bwd_dz(x, z, out, dout, chunk_size=chunk_size, has_ddAcs=False, D=D, dz=dz, recompute_output=recompute_output) + outz = rest[0] if recompute_output else out + else: + dz = None + outz = out + dstates = _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=seq_idx, dtype=states.dtype) + # dstates has length nchunks, containing the gradient to initial states at index 0 and + # gradient to the states of chunk (nchunks - 2) at index (nchunks - 1) + # Do computation in fp32 but convert dstates and states to fp16/bf16 since dstates and states + # will be used in matmul in the next kernels. + dstates, ddA_chunk_cumsum, dinitial_states, states = _state_passing_bwd( + rearrange(states, "... p n -> ... (p n)"), + dA_cumsum[:, :, :, -1], + rearrange(dstates, "... p n -> ... (p n)"), + dfinal_states=rearrange(dfinal_states, "... p n -> ... (p n)") if dfinal_states is not None else None, + seq_idx=seq_idx, + has_initial_states=initial_states is not None, + dstates_dtype=x.dtype, + states_dtype=x.dtype, + chunk_size=chunk_size, + ) + # dstates has length nchunks, containing the gradient to states of chunk 0 at index 0 and + # gradient to the final states at index (nchunks - 1) + # states has length nchunks, containing the initial states at index 0 and the state for chunk (nchunks - 2) at index (nchunks - 1) + # The final states is not stored. + states = rearrange(states, "... (p n) -> ... p n", n=dstate) + dstates = rearrange(dstates, "... (p n) -> ... p n", n=dstate) + dinitial_states = rearrange(dinitial_states, "... (p n) -> ... p n", n=dstate) if dinitial_states is not None else None + dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx) + # dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, ngroups=ngroups) + dB, ddA_next = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, B=B, ngroups=ngroups) + # dC = _chunk_scan_bwd_dC(states[:, :-1].to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups) + dC, ddA_cumsum_prev = _chunk_scan_bwd_dC(states.to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, C=C, ngroups=ngroups) + # Computing ddA with the dcb kernel is much slower, so we're not using it for now + dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups) + # dCB, ddA_tmp = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, CB=CB, ngroups=ngroups) + dCB = dCB.to(CB.dtype) + _bmm_chunk_bwd(C, dCB, residual=dB, out=dB_given) + _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC, out=dC_given) + # If we have z, then dout_x is recomputed in fp32 so dD = (dout_x * x).sum() is more accurate + # than dD_from_x = (dout_x * x).sum() where dout_x is in fp16/bf16 + if z is None: + dD = dD_from_x + # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D. + # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt + # However, this is numerically unstable: when we do the reverse cumsum on ddA_cumsum, there might + # be a lot of underflow. 
+ + # This is already done as part of bwd_dC kernel + # ddA_cumsum_prev = _chunk_scan_bwd_ddAcs_prev(states[:, :-1], C, dout, dA_cumsum, seq_idx=seq_idx) + ddA_cumsum_prev[..., -1] += ddA_chunk_cumsum + ddA_prev = ddA_cumsum_prev.flip([-1]).cumsum(dim=-1).flip([-1]) + # This is already done as part of bwd_dB kernel + # ddA_next = _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=seq_idx) + # We don't need to pass in seq_idx because CB also zeros out entries where seq_idx[i] != seq_idx[j] + ddA = _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, CB) + ddA += ddA_next + ddA_prev + + ddt_given, dA, ddt_bias = _chunk_cumsum_bwd(ddA, ddt, dt_in, A, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit, ddt=ddt_given) + + # These 2 lines are just to test ddt and dA being computed by old code + # _, dA = selective_scan_bwd(dout, x, dt, A, B, C, D=D.float(), z=z) + # ddt_given.copy_(ddt) + + return_vals = (dx, ddt_given, dA, dB_given, dC_given, dD, dz, ddt_bias, dinitial_states) + return return_vals if not recompute_output else (*return_vals, outz) + + +def selective_scan_bwd(dout, x, dt, A, B, C, D=None, z=None): + """ + Argument: + dout: (batch, seqlen, nheads, headdim) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) or (batch, nheads, headdim, nchunks, chunk_size) + A: (nheads) or (dim, dstate) + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + Return: + out: (batch, seqlen, nheads, headdim) + """ + import selective_scan + + batch, seqlen, nheads, headdim = x.shape + chunk_size = dt.shape[-1] + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + x = rearrange(x, "b l h p -> b (h p) l") + squeeze_dt = dt.dim() == 4 + if dt.dim() == 4: + dt = repeat(dt, "b h c l -> b h p c l", p=headdim) + dt = rearrange(dt, "b h p c l -> b (h p) (c l)", p=headdim) + squeeze_A = A.dim() == 1 + if A.dim() == 1: + A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32) + else: + A = A.to(dtype=torch.float32) + B = rearrange(B, "b l g n -> b g n l") + C = rearrange(C, "b l g n -> b g n l") + if D is not None: + if D.dim() == 2: + D = rearrange(D, "h p -> (h p)") + else: + D = repeat(D, "h -> (h p)", p=headdim) + if z is not None: + z = rearrange(z, "b l h p -> b (h p) l") + + if x.stride(-1) != 1: + x = x.contiguous() + if dt.stride(-1) != 1: + dt = dt.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + _, intermediate, *rest = selective_scan.fwd(x, dt.to(dtype=x.dtype), A, B, C, D, z, None, False) + if z is not None: + out = rest[0] + else: + out = None + + dout = rearrange(dout, "b l h p -> b (h p) l") + + if dout.stride(-1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the + # backward of selective_scan with the backward of chunk). + # Here we just pass in None and dz will be allocated in the C++ code. 
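+    # Layout reminder (illustrative): rearrange("b l h p -> b (h p) l") flattens the
+    # (nheads, headdim) axes into one channel axis, head-major, so a (batch, seqlen, 2, 3)
+    # tensor becomes (batch, 6, seqlen) with channels ordered (h0p0, h0p1, h0p2, h1p0, ...),
+    # which is the (b, d, l) layout the fused kernel consumes.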
+    _, ddt, dA, *rest = selective_scan.bwd(
+        x, dt.to(dtype=x.dtype), A, B, C, D, z, None, dout, intermediate, out, None, False,
+        False  # option to recompute out_z, not used here
+    )
+    ddt = rearrange(ddt, "b (h p) (c l) -> b h p c l", p=headdim, l=chunk_size)
+    if squeeze_dt:
+        ddt = ddt.float().sum(dim=2)
+    if squeeze_A:
+        dA = rearrange(dA, "(h p) n -> h p n", p=headdim).sum(dim=(1, 2))
+    return ddt, dA
+
+
+class MambaChunkScanCombinedFn(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False):
+        ctx.dt_dtype = dt.dtype
+        out, out_x, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=dt_softplus, dt_limit=dt_limit)
+        ctx.save_for_backward(out if z is None else out_x, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx)
+        ctx.dt_softplus = dt_softplus
+        ctx.chunk_size = chunk_size
+        ctx.dt_limit = dt_limit
+        ctx.return_final_states = return_final_states
+        return out if not return_final_states else (out, final_states)
+
+    @staticmethod
+    def backward(ctx, dout, *args):
+        out, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx = ctx.saved_tensors
+        dfinal_states = args[0] if ctx.return_final_states else None
+        dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=ctx.dt_softplus, dt_limit=ctx.dt_limit)
+        return dx, ddt, dA, dB, dC, None, dD, dz, ddt_bias, dinitial_states, None, None, None, None
+
+
+def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False):
+    """
+    Argument:
+        x: (batch, seqlen, nheads, headdim)
+        dt: (batch, seqlen, nheads)
+        A: (nheads)
+        B: (batch, seqlen, ngroups, dstate)
+        C: (batch, seqlen, ngroups, dstate)
+        chunk_size: int
+        D: (nheads, headdim) or (nheads,)
+        z: (batch, seqlen, nheads, headdim)
+        dt_bias: (nheads,)
+        initial_states: (batch, nheads, headdim, dstate)
+        seq_idx: (batch, seqlen)
+        dt_softplus: Whether to apply softplus to dt
+        dt_limit: (min, max) clamp applied to dt after dt_bias and softplus
+        return_final_states: Whether to also return the final states
+    Return:
+        out: (batch, seqlen, nheads, headdim)
+    """
+    return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, dt_softplus, dt_limit, return_final_states)
+
+
+def mamba_chunk_scan(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False):
+    """
+    Argument:
+        x: (batch, seqlen, nheads, headdim)
+        dt: (batch, seqlen, nheads)
+        A: (nheads)
+        B: (batch, seqlen, ngroups, dstate)
+        C: (batch, seqlen, ngroups, dstate)
+        chunk_size: int
+        D: (nheads, headdim) or (nheads,)
+        z: (batch, seqlen, nheads, headdim)
+        dt_bias: (nheads,)
+    Return:
+        out: (batch, seqlen, nheads, headdim)
+    """
+    batch, seqlen, nheads, headdim = x.shape
+    dstate = B.shape[-1]
+    if seqlen % chunk_size != 0:
+        dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
+    dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
+    dt = dt.float()  # We want high precision for this before cumsum
+    if dt_bias is not None:
+        dt = dt + rearrange(dt_bias, "h -> h 1 1")
+    if dt_softplus:
+        dt = F.softplus(dt)
+    dA = dt * rearrange(A, "h -> h 1 1")
+    dA_cumsum = torch.cumsum(dA, dim=-1)
+    # 1. Compute the state for each chunk
+    states = chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True)
+    # 2. Pass the state to all the chunks by weighted cumsum.
+    states = rearrange(state_passing(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1])[0],
+                       "... (p n) -> ... p n", n=dstate)
+    # 3. Compute the output for each chunk
+    out = chunk_scan(B, C, x, dt, dA_cumsum, states, D=D, z=z)
+    return out
+
+
+def ssd_chunk_scan_combined_ref(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False):
+    """
+    Argument:
+        x: (batch, seqlen, nheads, headdim)
+        dt: (batch, seqlen, nheads)
+        A: (nheads)
+        B: (batch, seqlen, ngroups, dstate)
+        C: (batch, seqlen, ngroups, dstate)
+        chunk_size: int
+        D: (nheads, headdim) or (nheads,)
+        z: (batch, seqlen, nheads, headdim)
+        dt_bias: (nheads,)
+    Return:
+        out: (batch, seqlen, nheads, headdim)
+    """
+    batch, seqlen, nheads, headdim = x.shape
+    dstate = B.shape[-1]
+    if seqlen % chunk_size != 0:
+        dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
+    dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
+    dt = dt.float()  # We want high precision for this before cumsum
+    if dt_bias is not None:
+        dt = dt + rearrange(dt_bias, "h -> h 1 1")
+    if dt_softplus:
+        dt = F.softplus(dt)
+    dA = dt * rearrange(A, "h -> h 1 1")
+    dA_cumsum = torch.cumsum(dA, dim=-1)
+    # 1. Compute the state for each chunk
+    states = chunk_state_ref(B, x, dt, dA_cumsum)
+    states_dtype = states.dtype
+    if states.dtype not in [torch.float32, torch.float64]:
+        states = states.to(torch.float32)
+    # 2. Pass the state to all the chunks by weighted cumsum.
+    # state_passing_ref is much less numerically stable
+    states = rearrange(state_passing_ref(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1])[0],
+                       "... (p n) -> ... p n", n=dstate)
+    states = states.to(states_dtype)
+    # 3. 
Compute the output for each chunk + out = chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z) + return out + + +def ssd_selective_scan(x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): + """ + Argument: + x: (batch, seqlen, nheads, headdim) + dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim) + A: (nheads) or (dim, dstate) + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + dt_bias: (nheads,) or (nheads, headdim) + Return: + out: (batch, seqlen, nheads, headdim) + """ + from mamba_ssm.ops.selective_scan_interface import selective_scan_fn + + batch, seqlen, nheads, headdim = x.shape + _, _, ngroups, dstate = B.shape + x = rearrange(x, "b l h p -> b (h p) l") + if dt.dim() == 3: + dt = repeat(dt, "b l h -> b l h p", p=headdim) + dt = rearrange(dt, "b l h p -> b (h p) l") + if A.dim() == 1: + A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32) + else: + A = A.to(dtype=torch.float32) + B = rearrange(B, "b l g n -> b g n l") + C = rearrange(C, "b l g n -> b g n l") + if D is not None: + if D.dim() == 2: + D = rearrange(D, "h p -> (h p)") + else: + D = repeat(D, "h -> (h p)", p=headdim) + if z is not None: + z = rearrange(z, "b l h p -> b (h p) l") + if dt_bias is not None: + if dt_bias.dim() == 1: + dt_bias = repeat(dt_bias, "h -> h p", p=headdim) + dt_bias = rearrange(dt_bias, "h p -> (h p)") + if dt_limit != (0.0, float("inf")): + if dt_bias is not None: + dt = dt + rearrange(dt_bias, "d -> d 1") + if dt_softplus: + dt = F.softplus(dt) + dt = dt.clamp(min=dt_limit[0], max=dt_limit[1]).to(x.dtype) + dt_bias = None + dt_softplus = None + out = selective_scan_fn(x, dt, A, B, C, D=D, z=z, delta_bias=dt_bias, delta_softplus=dt_softplus) + return rearrange(out, "b (h p) l -> b l h p", p=headdim) + + +def mamba_conv1d_scan_ref(xBC, conv1d_weight, conv1d_bias, dt, A, chunk_size, D=None, z=None, + dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")), + activation="silu", headdim=None, ngroups=1): + """ + Argument: + xBC: (batch, seqlen, dim + 2 * ngroups * dstate) where dim == nheads * headdim + conv1d_weight: (dim + 2 * ngroups * dstate, width) + conv1d_bias: (dim + 2 * ngroups * dstate,) + dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim) + A: (nheads) + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, dim) + dt_bias: (nheads) or (nheads, headdim) + headdim: if D is 1D and z is None, headdim must be passed in + Return: + out: (batch, seqlen, dim) + """ + batch, seqlen, nheads = dt.shape[:3] + assert nheads % ngroups == 0 + if z is not None: + dim = z.shape[-1] + assert dim % nheads == 0 + headdim = dim // nheads + else: + if D.dim() == 1: + assert headdim is not None + else: + headdim = D.shape[1] + dim = nheads * headdim + xBC = rearrange(causal_conv1d_fn(rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, activation=activation), + "b d s -> b s d") + dstate = (xBC.shape[-1] - dim) // ngroups // 2 + x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1) + x = rearrange(x, "b l (h p) -> b l h p", h=nheads) + B = rearrange(B, "b l (g n) -> b l g n", g=ngroups) + C = rearrange(C, "b l (g n) -> b l g n", g=ngroups) + z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None + out = ssd_selective_scan(x, dt.to(x.dtype), A, B, C, D=D.float(), z=z, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit) + return rearrange(out, "b s h p 
-> b s (h p)") + + +class MambaSplitConv1dScanCombinedFn(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu", + rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, + ngroups=1, norm_before_gate=True): + assert activation in [None, "silu", "swish"] + if D.dim() == 1: + assert headdim is not None + nheads, = D.shape + else: + nheads, headdim = D.shape + batch, seqlen, _ = zxbcdt.shape + dim = nheads * headdim + assert nheads % ngroups == 0 + dstate = (conv1d_weight.shape[0] - dim) // ngroups // 2 + d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ngroups * dstate - nheads) // 2 + assert d_nonssm >= 0 + assert zxbcdt.shape == (batch, seqlen, 2 * d_nonssm + 2 * dim + 2 * ngroups * dstate + nheads) + assert dt_bias.shape == (nheads,) + assert A.shape == (nheads,) + zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1) + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + xBC_conv = rearrange( + causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"), + conv1d_weight, conv1d_bias, seq_idx, None, None, activation in ["silu", "swish"]), + "b d s -> b s d" + ) + x, B, C = torch.split(xBC_conv, [dim, ngroups * dstate, ngroups * dstate], dim=-1) + x = rearrange(x, "b l (h p) -> b l h p", h=nheads) + B = rearrange(B, "b l (g n) -> b l g n", g=ngroups) + C = rearrange(C, "b l (g n) -> b l g n", g=ngroups) + z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None + if rmsnorm_weight is None: + out, out_x, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit) + out = rearrange(out, "b s h p -> b s (h p)") + rstd = None + if d_nonssm > 0: + out = torch.cat([_swiglu_fwd(zx0), out], dim=-1) + else: + out_x, _, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit) + # reshape input data into 2D tensor + x_rms = rearrange(out_x, "b s h p -> (b s) (h p)") + z_rms = rearrange(z, "b s h p -> (b s) (h p)") + rmsnorm_weight = rmsnorm_weight.contiguous() + if d_nonssm == 0: + out = None + else: + out01 = torch.empty((batch, seqlen, d_nonssm + dim), dtype=x_rms.dtype, device=x_rms.device) + out = rearrange(out01[..., d_nonssm:], "b s d -> (b s) d") + _swiglu_fwd(zx0, out=out01[..., :d_nonssm]) + out, _, rstd = _layer_norm_fwd(x_rms, rmsnorm_weight, None, rmsnorm_eps, z_rms, out=out, + group_size=dim // ngroups, + norm_before_gate=norm_before_gate, is_rms_norm=True) + if d_nonssm == 0: + out = rearrange(out, "(b s) d -> b s d", b=batch) + else: + out = out01 + ctx.outproj_weight_dtype = outproj_weight.dtype if outproj_weight is not None else None + if outproj_weight is not None: + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_gpu_dtype() + out, outproj_weight = out.to(dtype), outproj_weight.to(dtype) + outproj_bias = outproj_bias.to(dtype) if outproj_bias is not None else None + out = F.linear(out, outproj_weight, outproj_bias) + else: + assert outproj_bias is None + ctx.save_for_backward(zxbcdt, conv1d_weight, conv1d_bias, + out_x, A, D, 
dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias) + ctx.dt_limit = dt_limit + ctx.return_final_states = return_final_states + ctx.activation = activation + ctx.rmsnorm_eps = rmsnorm_eps + ctx.norm_before_gate = norm_before_gate + ctx.chunk_size = chunk_size + ctx.headdim = headdim + ctx.ngroups = ngroups + return out if not return_final_states else (out, final_states) + + @staticmethod + @custom_bwd + def backward(ctx, dout, *args): + zxbcdt, conv1d_weight, conv1d_bias, out, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + headdim = ctx.headdim + nheads = D.shape[0] + dim = nheads * headdim + assert nheads % ctx.ngroups == 0 + dstate = (conv1d_weight.shape[0] - dim) // ctx.ngroups // 2 + d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ctx.ngroups * dstate - nheads) // 2 + assert d_nonssm >= 0 + recompute_output = outproj_weight is not None + if recompute_output: + out_recompute = torch.empty(*out.shape[:2], d_nonssm + dim, device=out.device, dtype=out.dtype) + out0_recompute, out1_recompute = out_recompute.split([d_nonssm, dim], dim=-1) + zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1) + # Recompute x, B, C + xBC_conv = rearrange( + causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"), + conv1d_weight, conv1d_bias, seq_idx, None, None, ctx.activation in ["silu", "swish"]), + "b d s -> b s d" + ) + x, B, C = torch.split(xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1) + x = rearrange(x, "b l (h p) -> b l h p", h=nheads) + B = rearrange(B, "b l (g n) -> b l g n", g=ctx.ngroups) + C = rearrange(C, "b l (g n) -> b l g n", g=ctx.ngroups) + dzxbcdt = torch.empty_like(zxbcdt) + dzx0, dz, dxBC_given, ddt_given = torch.split(dzxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1) + dxBC = torch.empty_like(xBC) + dx, dB, dC = torch.split(dxBC, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1) + z = rearrange(z, "b l (h p) -> b l h p", h=nheads) + dx = rearrange(dx, "b l (h p) -> b l h p", h=nheads) + dB = rearrange(dB, "b l (g n) -> b l g n", g=ctx.ngroups) + dC = rearrange(dC, "b l (g n) -> b l g n", g=ctx.ngroups) + if outproj_weight is not None: + dout_og = dout + dout = F.linear(dout, outproj_weight.t()) + if d_nonssm > 0: + dout0, dout = dout.split([d_nonssm, dim], dim=-1) + _swiglu_bwd(zx0, dout0, dxy=dzx0, recompute_output=True, out=out0_recompute) + dout = rearrange(dout, "b s (h p) -> b s h p", p=headdim) + if rmsnorm_weight is None: + dz = rearrange(dz, "b l (h p) -> b l h p", h=nheads) + dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states, *rest = _mamba_chunk_scan_combined_bwd( + dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=ctx.dt_limit, dx=dx, ddt=ddt_given, dB=dB, dC=dC, dz=dz, recompute_output=recompute_output + ) + out_for_linear = rearrange(rest[0], "b s h p -> b s (h p)") if recompute_output else None + drmsnorm_weight = None + else: + batch = dout.shape[0] + dy_rms = rearrange(dout, "b s h p -> (b s) (h p)") + dz = rearrange(dz, "b l d -> (b l) d") + x_rms = rearrange(out, "b s h p -> (b s) (h p)") + z_rms = rearrange(z, "b s h p -> (b s) (h p)") + out1_recompute = rearrange(out1_recompute, "b s d -> (b s) d") if recompute_output else None + dout, drmsnorm_weight, _, 
dz, *rest = _layer_norm_bwd(dy_rms, x_rms, rmsnorm_weight, None, ctx.rmsnorm_eps, None, rstd, z_rms, norm_before_gate=ctx.norm_before_gate, is_rms_norm=True, recompute_output=recompute_output, dz=dz, out=out1_recompute if recompute_output else None) + out_for_linear = out_recompute if recompute_output else None + dout = rearrange(dout, "(b s) (h p) -> b s h p", b=batch, p=headdim) + dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd( + dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=ctx.dt_limit, dx=dx, ddt=ddt_given, dB=dB, dC=dC + ) + + if outproj_weight is not None: + doutproj_weight = torch.einsum("bso,bsd->od", dout_og, out_for_linear) + doutproj_bias = dout_og.sum(dim=(0, 1)) if outproj_bias is not None else None + else: + doutproj_weight, doutproj_bias = None, None + dxBC_given = rearrange(dxBC_given, "b s d -> b d s") + dxBC_given, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd( + rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, + rearrange(dxBC, "b s d -> b d s"), seq_idx, None, None, dxBC_given, False, ctx.activation in ["silu", "swish"] + ) + dxBC_given = rearrange(dxBC_given, "b d s -> b s d") + return dzxbcdt, dweight, dbias, ddt_bias, dA, dD, None, dinitial_states, None, None, None, None, drmsnorm_weight, None, doutproj_weight, doutproj_bias, None, None, None + + +def mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True): + """ + Argument: + zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim + conv1d_weight: (dim + 2 * ngroups * dstate, width) + conv1d_bias: (dim + 2 * ngroups * dstate,) + dt_bias: (nheads,) + A: (nheads) + D: (nheads, headdim) or (nheads,) + initial_states: (batch, nheads, headdim, dstate) + seq_idx: (batch, seqlen), int32 + rmsnorm_weight: (dim,) + outproj_weight: (out_dim, dim) + outproj_bias: (out_dim,) + headdim: if D is 1D, headdim must be passed in + norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z)) + Return: + out: (batch, seqlen, dim) + """ + return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate) + + +def mamba_split_conv1d_scan_ref(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, dt_limit=(0.0, float("inf")), activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True): + """ + Argument: + zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim + conv1d_weight: (dim + 2 * ngroups * dstate, width) + conv1d_bias: (dim + 2 * ngroups * dstate,) + dt_bias: (nheads,) + A: (nheads) + D: (nheads, headdim) or (nheads,) + rmsnorm_weight: (dim,) + outproj_weight: (out_dim, dim) + outproj_bias: (out_dim,) + headdim: if D is 1D, headdim must be passed in + norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). 
If False, we do RMSNorm(x * F.silu(z)) + Return: + out: (batch, seqlen, dim) + """ + if D.dim() == 1: + assert headdim is not None + nheads, = D.shape + else: + nheads, headdim = D.shape + assert nheads % ngroups == 0 + batch, seqlen, _ = zxbcdt.shape + dim = nheads * headdim + dstate = (zxbcdt.shape[-1] - 2 * dim - nheads) // ngroups // 2 + assert zxbcdt.shape == (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) + assert dt_bias.shape == (nheads,) + assert A.shape == (nheads,) + if rmsnorm_weight is not None: + assert rmsnorm_weight.shape == (dim,) + z, xBC, dt = torch.split(zxbcdt, [dim, dim + 2 * ngroups * dstate, nheads], dim=-1) + xBC = rearrange(causal_conv1d_fn(rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, activation=activation), + "b d s -> b s d") + x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1) + x = rearrange(x, "b l (h p) -> b l h p", h=nheads) + B = rearrange(B, "b l (g n) -> b l g n", g=ngroups) + C = rearrange(C, "b l (g n) -> b l g n", g=ngroups) + z = rearrange(z, "b l (h p) -> b l h p", h=nheads) + out = ssd_selective_scan(x, dt.to(x.dtype), A, B, C, D=D.float(), + z=z if rmsnorm_weight is None else None, dt_bias=dt_bias, dt_softplus=True, dt_limit=dt_limit) + out = rearrange(out, "b s h p -> b s (h p)") + if rmsnorm_weight is not None: + out = rmsnorm_fn(out, rmsnorm_weight, None, z=rearrange(z, "b l h p -> b l (h p)"), eps=rmsnorm_eps, + norm_before_gate=norm_before_gate) + if outproj_weight is not None: + out = F.linear(out, outproj_weight, outproj_bias) + return out + diff --git a/src/transformers/models/zamba2/ops/triton/ssd_state_passing.py b/src/transformers/models/zamba2/ops/triton/ssd_state_passing.py new file mode 100755 index 0000000000..63863b8236 --- /dev/null +++ b/src/transformers/models/zamba2/ops/triton/ssd_state_passing.py @@ -0,0 +1,348 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. 
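+# Reference semantics (see state_passing_ref at the end of this file): with per-chunk
+# decay a_c = exp(dA_chunk_cumsum[..., c]) and per-chunk states s_c, the forward pass
+# computes out_0 = initial_states (zeros if None) and out_{c+1} = a_c * out_c + s_c,
+# returning (out_0, ..., out_{nchunks-1}) and final_states = out_{nchunks}; with seq_idx,
+# the decay is zeroed across sequence boundaries.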
+ +"""We want triton==2.1.0 or 2.2.0 for this +""" + +import math +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + +from einops import rearrange, repeat + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 64}), + triton.Config({'BLOCK_SIZE': 128}), + triton.Config({'BLOCK_SIZE': 256}), + triton.Config({'BLOCK_SIZE': 512}), + triton.Config({'BLOCK_SIZE': 1024}), + triton.Config({'BLOCK_SIZE': 2048}), + ], + key=['dim'], +) +@triton.jit +def _state_passing_fwd_kernel( + # Pointers to matrices + states_ptr, out_ptr, final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr, + # Matrix dimensions + dim, nchunks, seqlen, chunk_size, + # Strides + stride_states_batch, stride_states_chunk, stride_states_head, stride_states_dim, + stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim, + stride_final_states_batch, stride_final_states_head, stride_final_states_dim, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, + stride_initstates_batch, stride_initstates_head, stride_initstates_dim, + stride_seq_idx_batch, stride_seq_idx_seqlen, + # Meta-parameters + HAS_INITSTATES: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + pid_m = tl.program_id(axis=0) + states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head + dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head + if HAS_INITSTATES: + initstates_ptr += pid_b * stride_initstates_batch + pid_h * stride_initstates_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + states_ptrs = states_ptr + offs_m * stride_states_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim + + if not HAS_INITSTATES: + states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) + else: + initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim + states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + tl.store(out_ptrs, states, mask=offs_m < dim) + out_ptrs += stride_out_chunk + seq_idx = 0 + for c in range(nchunks): + new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + dA_cs = tl.load(dA_cs_ptr).to(tl.float32) + scale = tl.exp(dA_cs) + if HAS_SEQ_IDX: + seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen) + scale = tl.where(seq_idx_new == seq_idx, scale, 0.0) + seq_idx = seq_idx_new + states = scale * states + new_states + if c < nchunks - 1: + tl.store(out_ptrs, states, mask=offs_m < dim) + else: + tl.store(final_states_ptrs, states, mask=offs_m < dim) + states_ptrs += stride_states_chunk + dA_cs_ptr += stride_dA_cs_chunk + out_ptrs += stride_out_chunk + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 64}), + triton.Config({'BLOCK_SIZE': 128}), + triton.Config({'BLOCK_SIZE': 256}), + triton.Config({'BLOCK_SIZE': 512}), + triton.Config({'BLOCK_SIZE': 1024}), + triton.Config({'BLOCK_SIZE': 2048}), + ], + key=['dim'], +) +@triton.jit +def _state_passing_bwd_kernel( + # Pointers to matrices + dout_ptr, out_ptr, dA_cs_ptr, dfinal_states_ptr, seq_idx_ptr, + dstates_ptr, ddA_cs_ptr, dinitstates_ptr, states_converted_ptr, + # Matrix dimensions + dim, nchunks, seqlen, 
chunk_size, + # Strides + stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_dim, + stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, + stride_dfinal_states_batch, stride_dfinal_states_head, stride_dfinal_states_dim, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_dim, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, + stride_dinitstates_batch, stride_dinitstates_head, stride_dinitstates_dim, + # Meta-parameters + CONVERT_STATES: tl.constexpr, + HAS_DFINAL_STATES: tl.constexpr, + HAS_DINITSTATES: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + pid_m = tl.program_id(axis=0) + dstates_ptr += pid_b * stride_dstates_batch + pid_h * stride_dstates_head + (nchunks - 1) * stride_dstates_chunk + dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (nchunks - 1) * stride_dA_cs_chunk + ddA_cs_ptr += pid_b * stride_ddA_cs_batch + pid_h * stride_ddA_cs_head + (nchunks - 1) * stride_ddA_cs_chunk + pid_m + out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk + dout_ptr += pid_b * stride_dout_batch + pid_h * stride_dout_head + (nchunks - 1) * stride_dout_chunk + if CONVERT_STATES: + states_converted_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk + if HAS_DFINAL_STATES: + dfinal_states_ptr += pid_b * stride_dfinal_states_batch + pid_h * stride_dfinal_states_head + if HAS_DINITSTATES: + dinitstates_ptr += pid_b * stride_dinitstates_batch + pid_h * stride_dinitstates_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + dstates_ptrs = dstates_ptr + offs_m * stride_dstates_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + dout_ptrs = dout_ptr + offs_m * stride_dout_dim + if CONVERT_STATES: + states_converted_ptrs = states_converted_ptr + offs_m * stride_out_dim + + if HAS_DFINAL_STATES: + dstates = tl.load(dfinal_states_ptr + offs_m * stride_dfinal_states_dim, mask=offs_m < dim, other=0.0).to(tl.float32) + else: + dstates = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) + tl.store(dstates_ptrs, dstates, mask=offs_m < dim) + if HAS_SEQ_IDX: + seq_idx = tl.load(seq_idx_ptr + (seqlen - 1) * stride_seq_idx_seqlen) + dstates_ptrs -= stride_dstates_chunk + for c in range(nchunks - 1): + dA_cs = tl.load(dA_cs_ptr).to(tl.float32) + scale = tl.exp(dA_cs) + if HAS_SEQ_IDX: + seq_idx_new = tl.load(seq_idx_ptr + (((nchunks - c - 1) * chunk_size - 1) * stride_seq_idx_seqlen)) + scale = tl.where(seq_idx_new == seq_idx, scale, 0.0) + seq_idx = seq_idx_new + out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if CONVERT_STATES: + tl.store(states_converted_ptrs, out, mask=offs_m < dim) + ddA = tl.sum(out * dstates) * scale + tl.store(ddA_cs_ptr, ddA) + dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + dstates = scale * dstates + dout + tl.store(dstates_ptrs, dstates, mask=offs_m < dim) + dout_ptrs -= stride_dout_chunk + dstates_ptrs -= stride_dstates_chunk + dA_cs_ptr -= stride_dA_cs_chunk + ddA_cs_ptr -= stride_ddA_cs_chunk + out_ptrs -= stride_out_chunk + if CONVERT_STATES: + states_converted_ptrs -= stride_out_chunk + if CONVERT_STATES: + out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + 
tl.store(states_converted_ptrs, out, mask=offs_m < dim) + if not HAS_DINITSTATES: + tl.store(ddA_cs_ptr, 0.0) + else: + dA_cs = tl.load(dA_cs_ptr).to(tl.float32) + scale = tl.exp(dA_cs) + if HAS_SEQ_IDX: + scale = tl.where(seq_idx == 0, scale, 0.0) + out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + ddA = tl.sum(out * dstates) * scale + tl.store(ddA_cs_ptr, ddA) + dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + dstates = scale * dstates + dout + tl.store(dinitstates_ptr + offs_m * stride_dinitstates_dim, dstates, mask=offs_m < dim) + + +def _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None, + out_dtype=None): + batch, nchunks, nheads, dim = states.shape + assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) + if initial_states is not None: + assert initial_states.shape == (batch, nheads, dim) + if seq_idx is not None: + assert chunk_size is not None + seqlen = seq_idx.shape[-1] + assert seq_idx.shape == (batch, seqlen) + out_dtype = states.dtype if out_dtype is None else out_dtype + out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype) + final_states = torch.empty((batch, nheads, dim), device=states.device, dtype=torch.float32) + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads) + with torch.cuda.device(states.device.index): + _state_passing_fwd_kernel[grid]( + states, out, final_states, dA_chunk_cumsum, initial_states, seq_idx, + dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0, + states.stride(0), states.stride(1), states.stride(2), states.stride(3), + out.stride(0), out.stride(1), out.stride(2), out.stride(3), + final_states.stride(0), final_states.stride(1), final_states.stride(2), + dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1), + *((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2)) + if initial_states is not None else (0, 0, 0)), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + HAS_INITSTATES=initial_states is not None, + HAS_SEQ_IDX=seq_idx is not None, + ) + return out, final_states + + +def _state_passing_bwd( + states, dA_chunk_cumsum, dout, dfinal_states=None, seq_idx=None, has_initial_states=None, + dstates_dtype=None, states_dtype=None, chunk_size=None +): + """ + states contains the initial_states at index 0. The final states are not included in states. 
+ """ + batch, nchunks, nheads, dim = states.shape + assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) + assert dout.shape == (batch, nchunks, nheads, dim) + if seq_idx is not None: + assert chunk_size is not None + seqlen = seq_idx.shape[-1] + assert seq_idx.shape == (batch, seqlen) + dstates = torch.empty_like(dout, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype) + if states_dtype is not None and states_dtype != states.dtype: + states_converted = torch.empty_like(states, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype) + assert states_converted.stride() == states.stride() + else: + states_converted = None + if has_initial_states: + dinitstates = torch.empty_like(dstates[:, 0]) + else: + dinitstates = None + if dfinal_states is not None: + assert dfinal_states.shape == (batch, nheads, dim) + BLOCK_SIZE_min = 64 + n_blocks = (dim + BLOCK_SIZE_min - 1) // BLOCK_SIZE_min + ddA_chunk_cumsum = torch.empty(batch, nheads, nchunks, n_blocks, + dtype=torch.float32, device=dA_chunk_cumsum.device) + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads) + with torch.cuda.device(dout.device.index): + _state_passing_bwd_kernel[grid]( + dout, states, dA_chunk_cumsum, dfinal_states, seq_idx, + dstates, ddA_chunk_cumsum, dinitstates, states_converted, + dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0, + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + states.stride(0), states.stride(1), states.stride(2), states.stride(3), + dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1), + *((dfinal_states.stride(0), dfinal_states.stride(1), dfinal_states.stride(2)) + if dfinal_states is not None else (0, 0, 0)), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), + ddA_chunk_cumsum.stride(0), ddA_chunk_cumsum.stride(2), ddA_chunk_cumsum.stride(1), + *((dinitstates.stride(0), dinitstates.stride(1), dinitstates.stride(2)) + if dinitstates is not None else (0, 0, 0)), + CONVERT_STATES=states_converted is not None, + HAS_DFINAL_STATES=dfinal_states is not None, + HAS_DINITSTATES=dinitstates is not None, + HAS_SEQ_IDX=seq_idx is not None, + ) + BLOCK_SIZE_actual = _state_passing_bwd_kernel.best_config.kwargs["BLOCK_SIZE"] + n_valid_blocks = (dim + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual + ddA_chunk_cumsum = ddA_chunk_cumsum[..., :n_valid_blocks].sum(dim=-1).to(dtype=dA_chunk_cumsum.dtype) + if states_dtype is not None and states_dtype == states.dtype: + states_converted = states + return (dstates, ddA_chunk_cumsum, dinitstates) if states_dtype is None else (dstates, ddA_chunk_cumsum, dinitstates, states_converted) + + +class StatePassingFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, states, dA_chunk_cumsum, initial_states=None): + batch, nchunks, nheads, dim = states.shape + assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) + if states.stride(-1) != 1: + states = states.contiguous() + out, final_states = _state_passing_fwd(states, dA_chunk_cumsum, initial_states) + ctx.save_for_backward(out, dA_chunk_cumsum) + ctx.has_initial_states = initial_states is not None + return out, final_states + + @staticmethod + def backward(ctx, dout, dfinal_states): + out, dA_chunk_cumsum = ctx.saved_tensors + batch, nchunks, nheads, dim = out.shape + assert dout.shape == (batch, nchunks, nheads, dim) + assert dA_chunk_cumsum.shape == (batch, nheads, 
nchunks) + assert dfinal_states.shape == (batch, nheads, dim) + if dout.stride(-1) != 1: + dout = dout.contiguous() + dstates, ddA_chunk_cumsum, dinitstates = _state_passing_bwd( + out, dA_chunk_cumsum, dout, dfinal_states=dfinal_states , has_initial_states=ctx.has_initial_states + ) + return dstates, ddA_chunk_cumsum, dinitstates + + +def state_passing(states, dA_chunk_cumsum, initial_states=None): + """ + Argument: + states: (batch, nchunks, nheads, dim) + dA_chunk_cumsum: (batch, nheads, nchunks) + initial_states: (batch, nheads, dim) + Return: + out: (batch, nchunks, nheads, dim) + final_states: (batch, nheads, dim) + """ + return StatePassingFn.apply(states, dA_chunk_cumsum, initial_states) + + +def state_passing_ref(states, dA_chunk_cumsum, initial_states=None): + """ + Argument: + states: (batch, nchunks, nheads, dim) + dA_chunk_cumsum: (batch, nheads, nchunks) + initial_states: (batch, nheads, dim) + Return: + out: (batch, nchunks, nheads, dim) + final_states: (batch, nheads, dim) + """ + if initial_states is None: + initial_states = torch.zeros_like(states[:, 0]) + states = torch.cat([rearrange(initial_states, "b h d -> b 1 h d"), states], dim=1) + dA_chunk_cumsum = F.pad(dA_chunk_cumsum, (1, 0)) + dA_chunk_cumsum = torch.cumsum(dA_chunk_cumsum, dim=-1) + nchunks = dA_chunk_cumsum.shape[-1] + # (batch, nheads, nchunks, nchunks) + dt_chunk_segment_sum = dA_chunk_cumsum[:, :, :, None] - dA_chunk_cumsum[:, :, None, :] + # (batch, nheads, nchunks, nchunks) + decay_chunk = torch.exp(dt_chunk_segment_sum) + causal_mask = torch.tril(torch.ones(nchunks, nchunks, device=states.device, dtype=bool), diagonal=0) + decay_chunk = decay_chunk.masked_fill(~causal_mask, 0) + out = torch.einsum("bhzc,bchd->bzhd", decay_chunk.to(dtype=states.dtype), states) + return out[:, :-1], out[:, -1] diff --git a/src/transformers/models/zamba2/requirements.txt b/src/transformers/models/zamba2/requirements.txt new file mode 100644 index 0000000000..90254899b3 --- /dev/null +++ b/src/transformers/models/zamba2/requirements.txt @@ -0,0 +1,40 @@ +causal-conv1d==1.2.0.post2 +certifi==2024.2.2 +charset-normalizer==3.3.2 +einops==0.7.0 +filelock==3.13.4 +fsspec==2024.3.1 +huggingface-hub==0.22.2 +idna==3.7 +Jinja2==3.1.3 +mamba-ssm==1.2.0.post1 +MarkupSafe==2.1.5 +mpmath==1.3.0 +networkx==3.3 +ninja==1.11.1.1 +numpy==1.26.4 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu12==8.9.2.26 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-nccl-cu12==2.19.3 +nvidia-nvjitlink-cu12==12.4.127 +nvidia-nvtx-cu12==12.1.105 +packaging==24.0 +PyYAML==6.0.1 +regex==2023.12.25 +requests==2.31.0 +safetensors==0.4.2 +sympy==1.12 +tokenizers==0.15.2 +torch==2.2.2 +tqdm==4.66.2 +transformers==4.39.3 +triton==2.2.0 +typing_extensions==4.11.0 +urllib3==2.2.1 diff --git a/src/transformers/models/zamba2/rotary.py b/src/transformers/models/zamba2/rotary.py new file mode 100644 index 0000000000..a905ffae42 --- /dev/null +++ b/src/transformers/models/zamba2/rotary.py @@ -0,0 +1,76 @@ +from __future__ import annotations +import torch +from torch import Tensor, nn + +class RotaryEmbedding(nn.Module): + + def __init__( + self, kv_channels: int, rotary_percent: float, seq_len_interpolation_factor: float = None + ) -> None: + super().__init__() + + dim = kv_channels + if rotary_percent < 1.0: + dim = int(dim * rotary_percent) + + self.seq_len_interpolation_factor 
= seq_len_interpolation_factor + self.inv_freq = 1.0 / ( + 10000 + ** ( + torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device()) + / dim + ) + ) + + def forward(self, max_seq_len: int, offset: int = 0) -> Tensor: + seq = ( + torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + + offset + ) + + if self.seq_len_interpolation_factor is not None: + seq *= 1 / self.seq_len_interpolation_factor + + freqs = torch.outer(seq, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + emb = emb[:, None, None, :] + return emb + + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + state_dict.pop(f'{prefix}inv_freq', None) + return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + def get_rotary_seq_len( + self, + inference_params, + transformer: TransformerBlock, + transformer_input: Tensor, + ) -> float: + if inference_params is not None: + rotary_seq_len = inference_params.max_sequence_length + else: + if transformer.input_tensor is not None: + rotary_seq_len = transformer.input_tensor.size(0) + else: + rotary_seq_len = transformer_input.size(0) + + return rotary_seq_len + + +def _rotate_half(x: Tensor) -> Tensor: + x1, x2 = torch.chunk(x, 2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(t: Tensor, freqs: Tensor) -> Tensor: + # t.shape: (s, b, h, d), freqs.shape: (s, 1, 1, d) + rot_dim = freqs.shape[-1] + + t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + + cos_ = torch.cos(freqs).to(t.dtype) + sin_ = torch.sin(freqs).to(t.dtype) + + t = (t * cos_) + (_rotate_half(t) * sin_) + + return torch.cat((t, t_pass), dim=-1) \ No newline at end of file diff --git a/src/transformers/models/zamba2/selective_scan_interface.py b/src/transformers/models/zamba2/selective_scan_interface.py new file mode 100755 index 0000000000..676971fef0 --- /dev/null +++ b/src/transformers/models/zamba2/selective_scan_interface.py @@ -0,0 +1,389 @@ +# Copyright (c) 2023, Tri Dao, Albert Gu. 
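+#
+# Selective-scan autograd functions wrapping the `selective_scan_cuda` extension,
+# plus pure-PyTorch reference implementations (`selective_scan_ref`,
+# `mamba_inner_ref`) kept for correctness checks.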
+ +import torch +import torch.nn.functional as F +from torch.cuda.amp import custom_bwd, custom_fwd + +from einops import rearrange, repeat + +from causal_conv1d import causal_conv1d_fn +import causal_conv1d_cuda +global SELECTIVE_SCAN_CUDA_IMPORT_FAILED +SELECTIVE_SCAN_CUDA_IMPORT_FAILED = False +try: + import selective_scan_cuda +except Exception as e: + SELECTIVE_SCAN_CUDA_IMPORT_FAILED = True + print("SELECTIVE SCAN CUDA IMPORT FAILED: ", e) + + +class SelectiveScanFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False): + if u.stride(-1) != 1: + u = u.contiguous() + if delta.stride(-1) != 1: + delta = delta.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + if B.dim() == 3: + B = rearrange(B, "b dstate l -> b 1 dstate l") + ctx.squeeze_B = True + if C.dim() == 3: + C = rearrange(C, "b dstate l -> b 1 dstate l") + ctx.squeeze_C = True + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) + ctx.delta_softplus = delta_softplus + ctx.has_z = z is not None + last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) + if not ctx.has_z: + ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) + return out if not return_last_state else (out, last_state) + else: + ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) + out_z = rest[0] + return out_z if not return_last_state else (out_z, last_state) + + @staticmethod + def backward(ctx, dout, *args): + if not ctx.has_z: + u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors + z = None + out = None + else: + u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors + if dout.stride(-1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the + # backward of selective_scan_cuda with the backward of chunk). + # Here we just pass in None and dz will be allocated in the C++ code. + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, + False # option to recompute out_z, not used here + ) + dz = rest[0] if ctx.has_z else None + dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB + dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC + return (du, ddelta, dA, dB, dC, + dD if D is not None else None, + dz, + ddelta_bias if delta_bias is not None else None, + None, + None) + + +def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False): + """if return_last_state is True, returns (out, last_state) + last_state has shape (batch, dim, dstate). Note that the gradient of the last state is + not considered in the backward pass. 
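+
+    A minimal shape sketch (assuming the `selective_scan_cuda` extension is built
+    and a CUDA device is available; values are random):
+        u = torch.randn(2, 64, 128, device="cuda")    # (batch, dim, seqlen)
+        delta = torch.rand(2, 64, 128, device="cuda")
+        A = -torch.rand(64, 16, device="cuda")        # (dim, dstate)
+        B = torch.randn(2, 16, 128, device="cuda")    # (batch, dstate, seqlen)
+        C = torch.randn(2, 16, 128, device="cuda")
+        out, last_state = selective_scan_fn(u, delta, A, B, C, return_last_state=True)
+        # out: (2, 64, 128); last_state: (2, 64, 16)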
+ """ + return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) + + +def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False): + """ + u: r(B D L) + delta: r(B D L) + A: c(D N) or r(D N) + B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + D: r(D) + z: r(B D L) + delta_bias: r(D), fp32 + + out: r(B D L) + last_state (optional): r(B D dstate) or c(B D dstate) + """ + dtype_in = u.dtype + u = u.float() + delta = delta.float() + if delta_bias is not None: + delta = delta + delta_bias[..., None].float() + if delta_softplus: + delta = F.softplus(delta) + batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] + is_variable_B = B.dim() >= 3 + is_variable_C = C.dim() >= 3 + if A.is_complex(): + if is_variable_B: + B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) + if is_variable_C: + C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2)) + else: + B = B.float() + C = C.float() + x = A.new_zeros((batch, dim, dstate)) + ys = [] + deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + if not is_variable_B: + deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) + else: + if B.dim() == 3: + deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) + else: + B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) + deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) + if is_variable_C and C.dim() == 4: + C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) + last_state = None + for i in range(u.shape[2]): + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if not is_variable_C: + y = torch.einsum('bdn,dn->bd', x, C) + else: + if C.dim() == 3: + y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) + else: + y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) + if i == u.shape[2] - 1: + last_state = x + if y.is_complex(): + y = y.real * 2 + ys.append(y) + y = torch.stack(ys, dim=2) # (batch dim L) + out = y if D is None else y + u * rearrange(D, "d -> d 1") + if z is not None: + out = out * F.silu(z) + out = out.to(dtype=dtype_in) + return out if not return_last_state else (out, last_state) + + +class MambaInnerFn(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, d_inner, + xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, + #out_proj_weight, out_proj_bias, + A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, + C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1): + """ + xz: (batch, dim, seqlen) + """ + assert checkpoint_lvl in [0, 1] + batch, L = xz.shape[0], xz.shape[-1] + # print("xz",xz.shape, "conv1dweight", conv1d_weight.shape, "conv1dbias: ", conv1d_bias.shape, "xproj: ", x_proj_weight.shape, "delta proj: ", delta_proj_weight.shape) + # print("DELTA PROJ WEIGHT: ", delta_proj_weight.shape) + delta_rank = delta_proj_weight.shape[1] + d_state = A.shape[-1] * (1 if not A.is_complex() else 2) + if torch.is_autocast_enabled(): + x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) + delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) + ## out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) + ## out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype()) + ## if out_proj_bias is not None else None) + if xz.stride(-1) != 1: + xz = xz.contiguous() + conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") 
+ + # x, z = xz.chunk(2, dim=1) + x, z = xz.view(batch, d_inner, 2, L).chunk(2, dim=2) + x = x.squeeze(2) + z = z.squeeze(2) + + conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True) + # We're being very careful here about the layout, to avoid extra transposes. + # We want delta to have d as the slowest moving dimension + # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. + x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d) + delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) + ctx.is_variable_B = B is None + ctx.is_variable_C = C is None + ctx.B_proj_bias_is_None = B_proj_bias is None + ctx.C_proj_bias_is_None = C_proj_bias is None + if B is None: # variable B + B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate) + if B_proj_bias is not None: + B = B + B_proj_bias.to(dtype=B.dtype) + if not A.is_complex(): + # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() + B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous() + else: + B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() + else: + if B.stride(-1) != 1: + B = B.contiguous() + if C is None: # variable C + C = x_dbl[:, -d_state:] # (bl dstate) + if C_proj_bias is not None: + C = C + C_proj_bias.to(dtype=C.dtype) + if not A.is_complex(): + # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() + C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous() + else: + C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() + else: + if C.stride(-1) != 1: + C = C.contiguous() + if D is not None: + D = D.contiguous() + out, scan_intermediates, out_z = selective_scan_cuda.fwd( + conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus + ) + ctx.delta_softplus = delta_softplus + ## ctx.out_proj_bias_is_None = out_proj_bias is None + ctx.checkpoint_lvl = checkpoint_lvl + if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass + conv1d_out, delta = None, None + ctx.d_inner = d_inner + ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, + delta_proj_weight, ## out_proj_weight, + conv1d_out, delta, + A, B, C, D, delta_bias, scan_intermediates, out) + return rearrange(out_z, "b d l -> b l d") + ## return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) + + @staticmethod + @custom_bwd + def backward(ctx, dout): + # dout: (batch, seqlen, dim) + (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, ## out_proj_weight, + conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors + batch, L = xz.shape[0], xz.shape[-1] + delta_rank = delta_proj_weight.shape[1] + d_state = A.shape[-1] * (1 if not A.is_complex() else 2) + + # x, z = xz.chunk(2, dim=1) + x, z = xz.view(batch, ctx.d_inner, 2, L).chunk(2, dim=2) + x = x.squeeze(2) + z = z.squeeze(2) + # if dout.stride(-1) != 1: + # dout = dout.contiguous() + if ctx.checkpoint_lvl == 1: + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True) + delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), + "d (b l) -> b d l", l = L) + # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the + # backward of selective_scan_cuda with the backward of chunk). 
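+        # Unlike SelectiveScanFn.backward, dz is pre-allocated here (as a view
+        # into dxz below) and handed to the kernel instead of being allocated in C++.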
+ dxz = torch.empty_like(xz) # (batch, dim, seqlen) + # dx, dz = dxz.chunk(2, dim=1) + dx, dz = dxz.view(batch, ctx.d_inner, 2, L).chunk(2, dim=2) + dx = dx.squeeze(2) + dz = dz.squeeze(2) + ## dout = rearrange(dout, "b l e -> e (b l)") + ## dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L) + dout = rearrange(dout, "b l e -> b e l") + if dout.stride(-1) != 1: + dout = dout.contiguous() + dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd( + conv1d_out, delta, A, B, C, D, z, delta_bias, dout, ## dout_y, + scan_intermediates, out, dz, + ctx.delta_softplus, + True # option to recompute out_z + ) + ## dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")) + #dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None + ## dout_proj_bias = dout.sum(dim=(1,2)) if not ctx.out_proj_bias_is_None else None + dD = dD if D is not None else None + dx_dbl = torch.empty_like(x_dbl) + dB_proj_bias = None + if ctx.is_variable_B: + if not A.is_complex(): + dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous() + else: + dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() + dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None + dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d) + dB = None + dC_proj_bias = None + if ctx.is_variable_C: + if not A.is_complex(): + dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous() + else: + dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() + dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None + dx_dbl[:, -d_state:] = dC # (bl d) + dC = None + ddelta = rearrange(ddelta, "b d l -> d (b l)") + ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank]) + dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight) + dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)") + dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")) + dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out) + dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]) + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). 
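+        # dx is a view into dxz, so causal_conv1d_bwd writes the input gradient in place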
+        dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
+            x, conv1d_weight, conv1d_bias, dconv1d_out, dx, True
+        )
+        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
+        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
+        return (None, None,
+                dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
+                ## dout_proj_weight, dout_proj_bias,
+                dA, dB, dC, dD,
+                ddelta_bias if delta_bias is not None else None,
+                dB_proj_bias, dC_proj_bias, None)
+
+
+def mamba_inner_fn(
+    d_inner,
+    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
+    #out_proj_weight, out_proj_bias,
+    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
+    C_proj_bias=None, delta_softplus=True
+):
+    return MambaInnerFn.apply(d_inner,
+                              xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
+                              #out_proj_weight, out_proj_bias,
+                              A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
+
+
+def mamba_inner_ref(
+    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
+    out_proj_weight, out_proj_bias,
+    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
+    C_proj_bias=None, delta_softplus=True
+):
+    L = xz.shape[-1]
+    delta_rank = delta_proj_weight.shape[1]
+    d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
+    x, z = xz.chunk(2, dim=1)
+    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu")
+    # We're being very careful here about the layout, to avoid extra transposes.
+    # We want delta to have d as the slowest moving dimension
+    # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
+ x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d) + delta = delta_proj_weight @ x_dbl[:, :delta_rank].t() + delta = rearrange(delta, "d (b l) -> b d l", l=L) + if B is None: # variable B + B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d) + if B_proj_bias is not None: + B = B + B_proj_bias.to(dtype=B.dtype) + if not A.is_complex(): + B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() + else: + B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() + if C is None: # variable B + C = x_dbl[:, -d_state:] # (bl d) + if C_proj_bias is not None: + C = C + C_proj_bias.to(dtype=C.dtype) + if not A.is_complex(): + C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() + else: + C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() + y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True) + return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias) diff --git a/src/transformers/models/zamba2/setup.py b/src/transformers/models/zamba2/setup.py new file mode 100644 index 0000000000..0569e9e21f --- /dev/null +++ b/src/transformers/models/zamba2/setup.py @@ -0,0 +1,158 @@ +import warnings +import os +from pathlib import Path + +from packaging.version import parse, Version +from setuptools import setup, find_packages +import subprocess + + +import torch +from torch.utils.cpp_extension import ( + BuildExtension, + CppExtension, + CUDAExtension, + CUDA_HOME, +) + +PACKAGE_NAME = "Zamba" +VERSION = "0.0.1" + +with open("README.md", "r", encoding="utf-8") as fh: + long_description = fh.read() + + +# ninja build does not work unless include_dirs are abs path +this_dir = os.path.dirname(os.path.abspath(__file__)) + +# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels +# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation +FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "TRUE" +SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE" +# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI +FORCE_CXX11_ABI = os.getenv("MAMBA_FORCE_CXX11_ABI", "FALSE") == "TRUE" + + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output( + [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True + ) + output = raw_output.split() + release_idx = output.index("release") + 1 + bare_metal_version = parse(output[release_idx].split(",")[0]) + + return raw_output, bare_metal_version + + +def check_if_cuda_home_none(global_option: str) -> None: + if CUDA_HOME is not None: + return + # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary + # in that case. + warnings.warn( + f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " + "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " + "only images whose names contain 'devel' will provide nvcc." 
+ ) + + +def append_nvcc_threads(nvcc_extra_args): + return nvcc_extra_args + ["--threads", "4"] + + +ext_modules = [] +if not SKIP_CUDA_BUILD: + print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) + TORCH_MAJOR = int(torch.__version__.split(".")[0]) + TORCH_MINOR = int(torch.__version__.split(".")[1]) + + check_if_cuda_home_none(PACKAGE_NAME) + # Check, if CUDA11 is installed for compute capability 8.0 + cc_flag = [] + if CUDA_HOME is not None: + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version < Version("11.6"): + raise RuntimeError( + f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. " + "Note: make sure nvcc has a supported version by running nvcc -V." + ) + + cc_flag.append("-gencode") + cc_flag.append("arch=compute_70,code=sm_70") + cc_flag.append("-gencode") + cc_flag.append("arch=compute_80,code=sm_80") + if bare_metal_version >= Version("11.8"): + cc_flag.append("-gencode") + cc_flag.append("arch=compute_90,code=sm_90") + + # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as + # torch._C._GLIBCXX_USE_CXX11_ABI + # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 + if FORCE_CXX11_ABI: + torch._C._GLIBCXX_USE_CXX11_ABI = True + + ext_modules.append( + CUDAExtension( + name="selective_scan_cuda", + sources=[ + "csrc/selective_scan/selective_scan.cpp", + "csrc/selective_scan/selective_scan_fwd_fp32.cu", + "csrc/selective_scan/selective_scan_fwd_fp16.cu", + "csrc/selective_scan/selective_scan_fwd_bf16.cu", + "csrc/selective_scan/selective_scan_bwd_fp32_real.cu", + "csrc/selective_scan/selective_scan_bwd_fp32_complex.cu", + "csrc/selective_scan/selective_scan_bwd_fp16_real.cu", + "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu", + "csrc/selective_scan/selective_scan_bwd_bf16_real.cu", + "csrc/selective_scan/selective_scan_bwd_bf16_complex.cu", + ], + extra_compile_args={ + "cxx": ["-O3", "-std=c++17"], + "nvcc": append_nvcc_threads( + [ + "-O3", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT162_OPERATORS__", + "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "--ptxas-options=-v", + "-lineinfo", + ] + + cc_flag + ), + }, + include_dirs=[Path(this_dir) / "csrc" / "selective_scan"], + ) + ) + + +setup( + name=PACKAGE_NAME, + version=VERSION, + description="Zamba state space model with global shared attention block", + long_description=long_description, + long_description_content_type="text/markdown", + packages=find_packages(include=['ops'],), + exclude=( + "csrc", + "zamba.egg-info", + ), + ext_modules=ext_modules, + cmdclass={"build_ext": BuildExtension}, + python_requires=">=3.7", + install_requires=[ + "torch", + "packaging", + "ninja", + "einops", + "triton", + "transformers", + "causal-conv1d==1.2.0.post2", + ], +) \ No newline at end of file diff --git a/src/transformers/models/zamba2/switch_mlp.py b/src/transformers/models/zamba2/switch_mlp.py new file mode 100644 index 0000000000..6f5b4dc00f --- /dev/null +++ b/src/transformers/models/zamba2/switch_mlp.py @@ -0,0 +1,85 @@ +import torch +import torch.nn as nn +import pickle +import os +import torch.nn.functional as F + +from mamba_config import MambaConfig +from mlp import MLP + +def sinkhorn(cost, tol=0.0001): + "Sinkhorn based MoE routing function" + cost = 
torch.exp(2.0 * cost) + d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype) + d1 = 1 / (cost.size(1) * torch.sum(cost, 0)) + + eps = 0.00000001 + error = 1e9 + d1_old = d1 + while error > tol: + d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps) + d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps) + error = torch.mean(torch.abs(d1_old - d1)) + d1_old = d1 + return d1 * cost * d0.unsqueeze(1) + + +class SwitchMLP(nn.Module): + + def __init__(self, config: MambaConfig, layer_idx=None): + super().__init__() + + self.layer = layer_idx + self.config: MambaConfig = config + if config.layer_mapping: + self.num_moe_experts = int(config.layer_mapping[layer_idx-1][-1]) + else: + self.num_moe_experts = self.config.num_moe_experts + self.router = torch.nn.Linear(self.config.hidden_size, self.num_moe_experts) + self.add_bias = config.add_bias_linear + self.routing = config.routing_mode # 'sinkhorn', 'top1', 'top2', 'sinkhorn_top2' + self.route_algo = sinkhorn + self.router_activation = torch.sigmoid + + self.num_local_experts = self.num_moe_experts + self.local_expert_indices = [i for i in range(self.num_local_experts)] + + self.local_experts = torch.nn.ModuleList() + for _ in range(self.num_local_experts): + expert = MLP(self.config, is_expert=True, layer_idx=layer_idx) + self.local_experts.append(expert) + + def gather_indices(self, local_indices): + return local_indices + + def forward(self, hidden_states, inference_params=None): + + hidden_shape = hidden_states.shape + route = self.router(hidden_states) + route = route.view(-1, self.num_moe_experts) + + if self.routing == 'sinkhorn': + route = self.router_activation(route) + max_prob, max_ind = torch.max(route, dim=1) + else: + route = torch.softmax(route, dim=1) + max_prob, max_ind = torch.max(route, dim=1) + + max_prob = torch.unsqueeze(max_prob, 1) + hidden_states = hidden_states.view(-1, hidden_shape[-1]) + + global_hidden_states = hidden_states + global_indices = max_ind + output_total = torch.zeros_like(global_hidden_states) + + for expert_num, expert in enumerate(self.local_experts): + local_expert_index = self.local_expert_indices[expert_num] + local_indices = (global_indices == local_expert_index).nonzero() + hidden = global_hidden_states[local_indices, :] + output = expert(hidden) + output_total[local_indices, :] = output + + output_total = output_total * max_prob + output_total = output_total.view(hidden_shape) + + return output_total \ No newline at end of file diff --git a/src/transformers/models/zamba2/test.py b/src/transformers/models/zamba2/test.py new file mode 100644 index 0000000000..262fa2b188 --- /dev/null +++ b/src/transformers/models/zamba2/test.py @@ -0,0 +1,28 @@ +from mamba_model import MambaModel +from mamba_config import MambaConfig +import torch + +config = MambaConfig( + num_layers = 5, + hidden_size = 1024, + mamba_headdim = 64, + mamba_ngroups = 1, + state_size = 16, + conv_dimension = 4, + expansion_factor = 2, + rms_norm = True, + bias = False, + use_mem_mlp = True, + num_attention_heads = 16, + num_mem_heads = 16, + num_mem_blocks = 2, + vocab_size = 50000, + layer_mapping = ["m", "m", "g", "m", "m"] + +) +#layer_mapping = ["r", "r", "g", "r", "r"] +model = MambaModel(config = config, max_sequence_length = 4096) +model = model.cuda().half() +inputs = torch.tensor([1, 2]).cuda().long().unsqueeze(0) +out = model(inputs) +print("OUT", out) \ No newline at end of file diff --git a/src/transformers/models/zamba2/transformers_tests.ipynb 
b/src/transformers/models/zamba2/transformers_tests.ipynb new file mode 100644 index 0000000000..b28bf3af73 --- /dev/null +++ b/src/transformers/models/zamba2/transformers_tests.ipynb @@ -0,0 +1,363 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "43fcda9ded7e4ec08f5f7ebff304bea7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "config.json: 0%| | 0.00/608 [00:00)\n" + ] + } + ], + "source": [ + "text = \"Sample input.\"\n", + "\n", + "input_ids = tokenizer(text, return_tensors=\"pt\").input_ids\n", + "%time output = model(input_ids)\n", + "print(output.logits)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers.models.zamba2 import Zamba2Config, Zamba2Model\n", + "config = Zamba2Config(# hidden_size=256,\n", + " use_cache=False,\n", + " use_mamba_kernels=True)\n", + "\n", + "model = Zamba2Model(config)\n", + "model = model.cuda()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Zamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was provided, so no cache will be returned.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + ".glu at 0x720f6c155120>\n", + "tensor([[[ 0.2320, -0.0671, 0.4281, ..., 0.2731, -0.2485, -0.1961],\n", + " [-0.3529, -0.3817, 0.1831, ..., 0.4370, 0.1053, 0.2115]]],\n", + " device='cuda:0', grad_fn=)\n", + ".glu at 0x720e33f02b90>\n", + "tensor([[[-0.2456, -0.1861, -0.3472, ..., 0.2924, -0.3361, 0.0805],\n", + " [-0.2854, -0.3549, 0.2099, ..., 0.2326, -0.2619, 0.0898]]],\n", + " device='cuda:0', grad_fn=)\n", + ".glu at 0x720f6c155120>\n", + "tensor([[[ 0.3210, 0.0659, 0.5195, ..., 0.2646, -0.4936, -0.3867],\n", + " [-0.0319, 0.0588, 0.0368, ..., 0.2150, -0.6552, -0.3597]]],\n", + " device='cuda:0', grad_fn=)\n", + ".glu at 0x720e33f02b90>\n", + "tensor([[[-0.2532, -0.1881, -0.3141, ..., 0.6968, -0.3180, -0.1367],\n", + " [-0.5472, -0.2829, 0.1529, ..., 0.4406, -0.1009, -0.1136]]],\n", + " device='cuda:0', grad_fn=)\n", + ".glu at 0x720f6c155120>\n", + "tensor([[[ 0.2028, 0.5285, 0.4683, ..., 0.4778, -0.1321, -0.1004],\n", + " [-0.1195, 0.1437, -0.2211, ..., 0.3486, -0.2087, 0.0910]]],\n", + " device='cuda:0', grad_fn=)\n", + ".glu at 0x720e33f02b90>\n", + "tensor([[[-0.0541, -0.1016, -0.1723, ..., 0.2633, 0.2065, -0.2020],\n", + " [-0.2931, -0.6119, 0.3172, ..., 0.2568, 0.1691, -0.4037]]],\n", + " device='cuda:0', grad_fn=)\n", + ".glu at 0x720f6c155120>\n", + "tensor([[[ 0.0046, 0.8110, 0.4117, ..., 0.4044, -0.2058, 0.0540],\n", + " [-0.1387, 0.6083, -0.0321, ..., 0.3768, -0.2296, -0.0283]]],\n", + " device='cuda:0', grad_fn=)\n", + ".glu at 0x720e33f02b90>\n", + "tensor([[[-0.4597, 0.0150, 0.1303, ..., 0.2195, -0.0543, -0.1967],\n", + " [-0.6164, -0.3430, 0.2792, ..., 0.0399, 0.0801, -0.6702]]],\n", + " device='cuda:0', grad_fn=)\n", + ".glu at 0x720f6c155120>\n", + "tensor([[[ 0.0241, 0.6656, 0.1877, ..., 0.0557, -0.2189, 0.1564],\n", + " [-0.3077, 0.5369, -0.1785, ..., 0.1834, -0.2878, 0.1950]]],\n", + " device='cuda:0', grad_fn=)\n", + ".glu at 0x720e33f02b90>\n", + "tensor([[[-0.2831, 0.1347, 0.1499, ..., 0.0057, 0.1145, -0.0469],\n", + " [-0.3922, -0.2633, 0.0517, ..., -0.1507, -0.0124, -0.4252]]],\n", + " device='cuda:0', grad_fn=)\n", + ".glu at 
0x720f6c155120>\n", + "tensor([[[-0.0641, 0.8043, 0.4550, ..., 0.2877, -0.0341, 0.3485],\n", + " [-0.2425, 0.7751, 0.0793, ..., 0.4592, 0.1067, 0.2118]]],\n", + " device='cuda:0', grad_fn=)\n", + ".glu at 0x720e33f02b90>\n", + "tensor([[[-0.0197, 0.0272, 0.3330, ..., 0.0538, -0.0269, -0.1854],\n", + " [-0.3365, -0.2644, 0.1057, ..., -0.1096, 0.0154, -0.5892]]],\n", + " device='cuda:0', grad_fn=)\n", + ".glu at 0x720f6c155120>\n", + "tensor([[[-0.4278, 0.5207, 0.1612, ..., 0.1028, 0.2516, 0.2261],\n", + " [-0.5387, 0.3975, 0.0550, ..., 0.2629, 0.2234, 0.3841]]],\n", + " device='cuda:0', grad_fn=)\n" + ] + }, + { + "data": { + "text/plain": [ + "BaseModelOutputWithPast(last_hidden_state=tensor([[[ 1.5610, -0.2107, -0.1899, -0.6792, 0.1004, -1.4207, 0.4914,\n", + " -2.6973, -1.5301, 0.8072, 0.1398, 0.3223, 0.7193, -0.6386,\n", + " -0.0224, -0.0305, 1.0239, 1.1427, 0.5567, -1.2266, 1.0427,\n", + " -0.2688, 0.3386, -0.3193, -0.7158, 0.4627, -0.0887, -0.3950,\n", + " 0.4347, 0.6391, -0.1739, -1.6352, 1.1399, 0.4062, 0.2984,\n", + " -1.9093, 0.3013, 2.0630, -0.8852, -0.8784, 0.8980, -0.7977,\n", + " 0.8236, -0.5900, -0.1444, 1.0169, -0.0153, 2.7408, -1.3365,\n", + " -0.1727, -0.4123, -2.5188, 0.4768, -1.0458, 0.2904, -0.3563,\n", + " 1.6769, -0.3954, -0.3084, -0.5172, -1.7214, 1.7523, 0.7369,\n", + " 1.1858, 0.1346, 0.0534, 0.9420, 1.2791, -0.2385, 0.3122,\n", + " 0.8081, -1.1902, 1.1936, 0.7331, -1.4534, 0.5935, 0.7391,\n", + " -0.9888, -0.2075, 0.6612, 0.8746, -0.3961, -0.5769, -1.5450,\n", + " 1.4662, -0.3326, 1.2070, -0.0269, -0.3110, 1.7675, 0.9386,\n", + " 0.0231, -0.5033, 0.4606, 0.0376, -0.6718, -2.2408, -0.4152,\n", + " -0.0150, 1.6197, -0.2152, -0.9927, -0.0088, -0.1774, -0.8367,\n", + " -0.2404, 1.7812, 0.6681, -1.5299, -1.3167, 1.1307, 2.4472,\n", + " 0.9099, -0.4268, -0.6490, 0.0145, 1.7519, 0.6995, 0.2697,\n", + " 0.2662, 0.0143, -1.0016, -0.6653, -0.6398, -0.4175, -0.8453,\n", + " 1.8385, -0.2844, -1.6901, 0.6662, -0.4793, -0.1196, 1.4028,\n", + " 1.1849, -1.4733, 0.5708, 0.1397, 0.7665, 1.4205, -1.6913,\n", + " 0.2976, 1.3293, 0.4003, 1.6835, 0.1577, 0.0323, 0.2776,\n", + " -0.9257, 0.5199, 0.3320, 0.2912, -1.1751, -0.0050, -0.2424,\n", + " -1.1217, -0.6652, -0.0812, -1.1894, -0.1066, 0.4753, -0.6588,\n", + " 1.1274, -2.2381, 0.1637, -0.8203, 0.4781, 0.0622, -0.7859,\n", + " -1.7128, 1.2858, -0.1512, -1.1712, -0.5319, 0.5512, -1.4268,\n", + " 0.2548, 2.9298, 0.7842, -0.1890, -0.8952, 0.1479, -0.0825,\n", + " -0.2476, -0.0274, 1.9373, 1.2561, -0.3870, -0.5059, 1.0771,\n", + " 0.7161, 0.2510, -0.7403, 0.8498, -0.2442, 1.5790, 0.2348,\n", + " 0.0841, 2.0146, 0.8968, 0.2489, -0.9310, 1.7749, -1.3079,\n", + " 0.7295, 0.8599, 0.0189, 1.6884, -0.5464, 1.6790, -0.7889,\n", + " 0.2329, 0.8090, -1.8518, -0.1529, -0.9284, -0.2809, 0.7448,\n", + " -0.4769, -2.0393, -1.5564, 0.1528, 1.8674, -0.0044, 1.1340,\n", + " 0.9800, -0.9253, -0.4846, 1.1367, 1.0322, 1.7122, 1.1982,\n", + " -0.0808, 1.1228, -0.9897, -0.7876, -0.2427, 1.1697, 0.0605,\n", + " 0.8717, -0.1591, -0.1623, -0.9304, -1.2984, 1.0778, 0.2340,\n", + " 0.4645, 0.5720, -0.0596, 1.1844, 1.2052, -0.6137, 1.2002,\n", + " 0.0232, -1.2396, -0.4822, 1.4322],\n", + " [ 0.6106, -1.1350, 1.0885, -0.6590, 1.1972, 1.4557, 0.4810,\n", + " 1.6009, -0.8303, 1.0261, 0.8802, -1.5162, -0.0897, 1.2317,\n", + " -0.1939, -0.1571, -0.1656, -0.7842, 0.8784, -0.0210, 0.4554,\n", + " -1.6171, 2.7235, 1.7553, 0.9623, -1.7567, 1.3346, -1.1969,\n", + " -0.0813, -0.7047, 0.2750, -1.3842, -0.2213, 0.2764, 0.1043,\n", + " -0.5828, -1.7812, 1.3242, 0.5183, 0.6814, 
0.2051, 1.0487,\n", + " -1.4249, -0.5231, 0.8032, -0.7084, -0.3434, 0.2496, 1.2076,\n", + " -1.8979, 1.0625, 1.8654, -1.5041, 1.3455, 0.9795, 0.0224,\n", + " 1.3700, -1.2286, -0.1811, 0.3205, -0.8475, 0.5566, 3.2508,\n", + " 0.0792, -1.6544, -0.1888, 0.4788, -0.4847, 1.1109, -0.4203,\n", + " 1.2064, 1.8857, 1.7834, -1.5793, -1.9296, 1.1148, 0.9699,\n", + " 0.5867, 0.4809, -0.7841, -1.2050, 1.0947, 0.4143, -1.0028,\n", + " -0.4719, -1.3144, -0.3814, 0.7144, 0.9841, -0.6150, 1.0012,\n", + " -0.3843, 1.4516, -0.1925, 0.5752, -0.3950, 1.8073, 0.4228,\n", + " 0.1050, -0.2622, -0.6314, -0.9772, -1.6765, 1.1153, -0.1649,\n", + " 0.2459, 0.7048, -0.7455, -0.9480, 1.5142, 0.9590, -2.4351,\n", + " 0.6375, -0.8491, -0.3248, -1.4186, 0.4533, 2.3914, 0.2982,\n", + " -0.6675, -1.3668, 0.6553, 1.5135, -0.9250, -0.0102, 1.2277,\n", + " -0.7793, -0.7797, 0.3316, -0.6598, 0.0769, -0.8516, -0.1348,\n", + " -0.2179, -0.1757, 0.6741, -0.8688, 2.0086, -1.6593, -0.2141,\n", + " 0.3545, 0.7880, -1.1431, 0.8881, -1.2641, -0.1263, -0.3227,\n", + " -0.2829, -0.4671, -0.2100, 1.2006, -0.3661, 0.2646, 0.5560,\n", + " 0.5605, -0.6277, 0.2558, -0.4180, -0.6891, 0.5311, 0.7269,\n", + " -0.0822, -0.1792, -0.7969, -0.3833, 0.8326, -1.1468, -0.4691,\n", + " -2.6068, 0.0915, 1.0277, 0.1904, -0.4176, -0.8701, 1.1830,\n", + " -1.3595, -0.2752, 0.1232, 1.0058, 1.6270, 0.6452, 0.3897,\n", + " -0.8015, -0.2148, 1.2683, -0.4499, -0.1746, 0.8048, 0.7027,\n", + " 1.0295, 0.5364, -0.6913, -1.3428, -0.3413, -0.1128, -0.4508,\n", + " -0.2579, -0.3001, -1.0598, 0.1322, 1.2603, -0.5473, -0.0878,\n", + " 1.2313, -0.9989, -0.2355, -0.9956, 1.5185, 1.7814, -0.2523,\n", + " -0.7115, -0.2909, -0.5441, -0.3842, -1.5203, 0.4054, 0.3819,\n", + " 0.4151, 0.7575, 0.0335, 0.2037, -1.5995, -0.7745, 1.0326,\n", + " -1.2516, -0.6565, 0.0148, -0.7770, -1.0470, 0.6955, -1.2700,\n", + " 1.1997, 1.7305, -1.0536, 0.3510, -0.8766, -0.6916, 0.4194,\n", + " 0.9248, -1.7707, -0.9460, -0.9431, 0.5791, 0.6802, -0.2858,\n", + " -0.2782, 0.9059, -1.2144, -0.3521, -1.0666, 0.5210, 2.3597,\n", + " -0.6880, 1.1041, -1.7320, 1.5964]]], device='cuda:0',\n", + " grad_fn=), past_key_values=None, hidden_states=None, attentions=None)" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "inp = torch.tensor([[1, 2]]).cuda()\n", + "model(inp)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/transformers/models/zamba2/utils.py b/src/transformers/models/zamba2/utils.py new file mode 100644 index 0000000000..7e8dd139c2 --- /dev/null +++ b/src/transformers/models/zamba2/utils.py @@ -0,0 +1,71 @@ +from operator import itemgetter +from typing import Any, Dict, Iterable, Optional, Tuple, Union +import math +import torch + +def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + +@torch.jit.script +def gelu_impl(x): + """OpenAI's gelu implementation.""" + return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 
0.044715 * x * x))) + +def openai_gelu(x): + return gelu_impl(x) + +@torch.jit.script +def bias_gelu(bias, y): + x = bias + y + return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) + +# gradient of tanh approximation of gelu +# gradient of actual gelu is: +# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) +@torch.jit.script +def bias_gelu_back(g, bias, y): + x = bias + y + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * ( + 1 + tanh_out + ) + return ff * g + +class GeLUFunction(torch.autograd.Function): + @staticmethod + # bias is an optional argument + def forward(ctx, input, bias): + ctx.save_for_backward(input, bias) + return bias_gelu(bias, input) + + @staticmethod + def backward(ctx, grad_output): + input, bias = ctx.saved_tensors + tmp = bias_gelu_back(grad_output, bias, input) + return tmp, tmp + +bias_gelu_impl = GeLUFunction.apply + +# This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter +@torch.jit.script +def erf_gelu(x): + return ( + x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype)) + ) + +def init_method_normal(sigma): + + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=sigma) + + return init_ + +def scaled_init_method_normal(sigma, num_layers): + std = sigma / math.sqrt(2.0 * num_layers) + + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=std) + + return init_ \ No newline at end of file diff --git a/src/transformers/models/zamba2/zamba2_torch b/src/transformers/models/zamba2/zamba2_torch new file mode 160000 index 0000000000..c3a10f3973 --- /dev/null +++ b/src/transformers/models/zamba2/zamba2_torch @@ -0,0 +1 @@ +Subproject commit c3a10f3973ae7f83335b82598fb4a9e9089bfed3 diff --git a/src/transformers/models/zamba2/zamba2_torch_generation.ipynb b/src/transformers/models/zamba2/zamba2_torch_generation.ipynb new file mode 100644 index 0000000000..73d8fa4ff5 --- /dev/null +++ b/src/transformers/models/zamba2/zamba2_torch_generation.ipynb @@ -0,0 +1,790 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Instantiate small random model" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Zarr-based strategies will not be registered because of missing packages\n", + "/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. 
If you want to force a new download, use `force_download=True`.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "from mamba_model import MambaModel\n", + "from mamba_config import MambaConfig\n", + "import torch\n", + "\n", + "# Random init model\n", + "\n", + "config = MambaConfig(\n", + " num_layers = 3, # 5\n", + " hidden_size = 256, # 1024\n", + " mamba_headdim = 64,\n", + " mamba_ngroups = 1,\n", + " state_size = 16,\n", + " conv_dimension = 4,\n", + " expansion_factor = 2,\n", + " rms_norm = True,\n", + " bias = False,\n", + " use_mem_mlp = False, #True,\n", + " use_mem_rope = False, #True,\n", + " num_attention_heads = 16,\n", + " num_mem_heads = 0, #16,\n", + " num_mem_blocks = 0, #2,\n", + " vocab_size = 32000,\n", + " layer_mapping = [\"m\", \"m\", \"m\"]\n", + ")\n", + "\n", + "model = MambaModel(config = config, max_sequence_length = 4096)\n", + "model = model.cuda().half()\n", + "# Tokenizer\n", + "from megatron.tokenizer.tokenizer import _HFAutoTokenizer\n", + "from megatron.text_generation.tokenization import tokenize_prompts\n", + "tokenizer = _HFAutoTokenizer(\"mistralai/Mistral-7B-v0.1\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Load 3B checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['--base-model-type', 'mamba', '--num-layers', '54', '--mamba-headdim', '64', '--mamba-ngroups', '1', '--hidden-size', '2560', '--state-size', '64', '--conv-dimension', '4', '--expansion-factor', '2', '--seq-length', '4096', '--max-position-embeddings', '4096', '--micro-batch-size', '3', '--global-batch-size', '384', '--lr', '2.0e-4', '--train-iters', '1_700_000', '--lr-decay-iters', '1_700_000', '--lr-decay-style', 'cosine', '--min-lr', '5.0e-5', '--lr-warmup-init', '0.0', '--weight-decay', '0.1', '--adam-beta2', '0.95', '--lr-warmup-iters', '5000', '--clip-grad', '1.0', '--bf16', '--recompute-granularity', 'selective', '--accumulate-allreduce-grads-in-fp32', '--attention-dropout', '0.0', '--hidden-dropout', '0.0', '--rms-norm', 'True', '--disable-bias-linear', '--num-mem-heads', '32', '--use-mem-mlp', '--num-attention-heads', '32', '--num-mem-blocks', '2', '--use-shared-block-lora', '--lora-rank', '128', '--use-distributed-optimizer', '--mamba-moe-layers', 'm', 'm', 'm', 'm', 'm', 'm', 'g', 'm', 'm', 'm', 'm', 'm', 'g', 'm', 'm', 'm', 'm', 'm', 'g', 'm', 'm', 'm', 'm', 'm', 'g', 'm', 'm', 'm', 'm', 'm', 'g', 'm', 'm', 'm', 'm', 'm', 'g', 'm', 'm', 'm', 'm', 'm', 'g', 'm', 'm', 'm', 'm', 'g', 'm', 'm', 'm', 'g', 'm', 'm', '--buffer-params', '96']\n", + "odict_keys(['embedding.weight', 'decoder.layers.0.mixer.dt_bias', 'decoder.layers.0.mixer.A_log', 'decoder.layers.0.mixer.D', 'decoder.layers.0.mixer.in_proj.0.weight', 'decoder.layers.0.mixer.conv1d.weight', 'decoder.layers.0.mixer.conv1d.bias', 'decoder.layers.0.mixer.norm.weight', 'decoder.layers.0.mixer.out_proj.weight', 'decoder.layers.0.norm.weight', 'decoder.layers.1.mixer.dt_bias', 'decoder.layers.1.mixer.A_log', 'decoder.layers.1.mixer.D', 'decoder.layers.1.mixer.in_proj.0.weight', 'decoder.layers.1.mixer.conv1d.weight', 'decoder.layers.1.mixer.conv1d.bias', 'decoder.layers.1.mixer.norm.weight', 'decoder.layers.1.mixer.out_proj.weight', 'decoder.layers.1.norm.weight', 'decoder.layers.2.mixer.dt_bias', 'decoder.layers.2.mixer.A_log', 'decoder.layers.2.mixer.D', 'decoder.layers.2.mixer.in_proj.0.weight', 'decoder.layers.2.mixer.conv1d.weight', 
'decoder.layers.2.mixer.conv1d.bias', 'decoder.layers.2.mixer.norm.weight', 'decoder.layers.2.mixer.out_proj.weight', 'decoder.layers.2.norm.weight', 'decoder.layers.3.mixer.dt_bias', 'decoder.layers.3.mixer.A_log', 'decoder.layers.3.mixer.D', 'decoder.layers.3.mixer.in_proj.0.weight', 'decoder.layers.3.mixer.conv1d.weight', 'decoder.layers.3.mixer.conv1d.bias', 'decoder.layers.3.mixer.norm.weight', 'decoder.layers.3.mixer.out_proj.weight', 'decoder.layers.3.norm.weight', 'decoder.layers.4.mixer.dt_bias', 'decoder.layers.4.mixer.A_log', 'decoder.layers.4.mixer.D', 'decoder.layers.4.mixer.in_proj.0.weight', 'decoder.layers.4.mixer.conv1d.weight', 'decoder.layers.4.mixer.conv1d.bias', 'decoder.layers.4.mixer.norm.weight', 'decoder.layers.4.mixer.out_proj.weight', 'decoder.layers.4.norm.weight', 'decoder.layers.5.mixer.dt_bias', 'decoder.layers.5.mixer.A_log', 'decoder.layers.5.mixer.D', 'decoder.layers.5.mixer.in_proj.0.weight', 'decoder.layers.5.mixer.conv1d.weight', 'decoder.layers.5.mixer.conv1d.bias', 'decoder.layers.5.mixer.norm.weight', 'decoder.layers.5.mixer.out_proj.weight', 'decoder.layers.5.norm.weight', 'decoder.layers.6.mixer.dt_bias', 'decoder.layers.6.mixer.A_log', 'decoder.layers.6.mixer.D', 'decoder.layers.6.mixer.in_proj.0.weight', 'decoder.layers.6.mixer.conv1d.weight', 'decoder.layers.6.mixer.conv1d.bias', 'decoder.layers.6.mixer.norm.weight', 'decoder.layers.6.mixer.out_proj.weight', 'decoder.layers.6.norm.weight', 'decoder.layers.7.mixer.dt_bias', 'decoder.layers.7.mixer.A_log', 'decoder.layers.7.mixer.D', 'decoder.layers.7.mixer.in_proj.0.weight', 'decoder.layers.7.mixer.conv1d.weight', 'decoder.layers.7.mixer.conv1d.bias', 'decoder.layers.7.mixer.norm.weight', 'decoder.layers.7.mixer.out_proj.weight', 'decoder.layers.7.norm.weight', 'decoder.layers.8.mixer.dt_bias', 'decoder.layers.8.mixer.A_log', 'decoder.layers.8.mixer.D', 'decoder.layers.8.mixer.in_proj.0.weight', 'decoder.layers.8.mixer.conv1d.weight', 'decoder.layers.8.mixer.conv1d.bias', 'decoder.layers.8.mixer.norm.weight', 'decoder.layers.8.mixer.out_proj.weight', 'decoder.layers.8.norm.weight', 'decoder.layers.9.mixer.dt_bias', 'decoder.layers.9.mixer.A_log', 'decoder.layers.9.mixer.D', 'decoder.layers.9.mixer.in_proj.0.weight', 'decoder.layers.9.mixer.conv1d.weight', 'decoder.layers.9.mixer.conv1d.bias', 'decoder.layers.9.mixer.norm.weight', 'decoder.layers.9.mixer.out_proj.weight', 'decoder.layers.9.norm.weight', 'decoder.layers.10.mixer.dt_bias', 'decoder.layers.10.mixer.A_log', 'decoder.layers.10.mixer.D', 'decoder.layers.10.mixer.in_proj.0.weight', 'decoder.layers.10.mixer.conv1d.weight', 'decoder.layers.10.mixer.conv1d.bias', 'decoder.layers.10.mixer.norm.weight', 'decoder.layers.10.mixer.out_proj.weight', 'decoder.layers.10.norm.weight', 'decoder.layers.11.mixer.dt_bias', 'decoder.layers.11.mixer.A_log', 'decoder.layers.11.mixer.D', 'decoder.layers.11.mixer.in_proj.0.weight', 'decoder.layers.11.mixer.conv1d.weight', 'decoder.layers.11.mixer.conv1d.bias', 'decoder.layers.11.mixer.norm.weight', 'decoder.layers.11.mixer.out_proj.weight', 'decoder.layers.11.norm.weight', 'decoder.layers.12.mixer.dt_bias', 'decoder.layers.12.mixer.A_log', 'decoder.layers.12.mixer.D', 'decoder.layers.12.mixer.in_proj.0.weight', 'decoder.layers.12.mixer.conv1d.weight', 'decoder.layers.12.mixer.conv1d.bias', 'decoder.layers.12.mixer.norm.weight', 'decoder.layers.12.mixer.out_proj.weight', 'decoder.layers.12.norm.weight', 'decoder.layers.13.mixer.dt_bias', 'decoder.layers.13.mixer.A_log', 'decoder.layers.13.mixer.D', 
'decoder.layers.13.mixer.in_proj.0.weight', 'decoder.layers.13.mixer.conv1d.weight', 'decoder.layers.13.mixer.conv1d.bias', 'decoder.layers.13.mixer.norm.weight', 'decoder.layers.13.mixer.out_proj.weight', 'decoder.layers.13.norm.weight', 'decoder.layers.14.mixer.dt_bias', 'decoder.layers.14.mixer.A_log', 'decoder.layers.14.mixer.D', 'decoder.layers.14.mixer.in_proj.0.weight', 'decoder.layers.14.mixer.conv1d.weight', 'decoder.layers.14.mixer.conv1d.bias', 'decoder.layers.14.mixer.norm.weight', 'decoder.layers.14.mixer.out_proj.weight', 'decoder.layers.14.norm.weight', 'decoder.layers.15.mixer.dt_bias', 'decoder.layers.15.mixer.A_log', 'decoder.layers.15.mixer.D', 'decoder.layers.15.mixer.in_proj.0.weight', 'decoder.layers.15.mixer.conv1d.weight', 'decoder.layers.15.mixer.conv1d.bias', 'decoder.layers.15.mixer.norm.weight', 'decoder.layers.15.mixer.out_proj.weight', 'decoder.layers.15.norm.weight', 'decoder.layers.16.mixer.dt_bias', 'decoder.layers.16.mixer.A_log', 'decoder.layers.16.mixer.D', 'decoder.layers.16.mixer.in_proj.0.weight', 'decoder.layers.16.mixer.conv1d.weight', 'decoder.layers.16.mixer.conv1d.bias', 'decoder.layers.16.mixer.norm.weight', 'decoder.layers.16.mixer.out_proj.weight', 'decoder.layers.16.norm.weight', 'decoder.layers.17.mixer.dt_bias', 'decoder.layers.17.mixer.A_log', 'decoder.layers.17.mixer.D', 'decoder.layers.17.mixer.in_proj.0.weight', 'decoder.layers.17.mixer.conv1d.weight', 'decoder.layers.17.mixer.conv1d.bias', 'decoder.layers.17.mixer.norm.weight', 'decoder.layers.17.mixer.out_proj.weight', 'decoder.layers.17.norm.weight', 'decoder.layers.18.mixer.dt_bias', 'decoder.layers.18.mixer.A_log', 'decoder.layers.18.mixer.D', 'decoder.layers.18.mixer.in_proj.0.weight', 'decoder.layers.18.mixer.conv1d.weight', 'decoder.layers.18.mixer.conv1d.bias', 'decoder.layers.18.mixer.norm.weight', 'decoder.layers.18.mixer.out_proj.weight', 'decoder.layers.18.norm.weight', 'decoder.layers.19.mixer.dt_bias', 'decoder.layers.19.mixer.A_log', 'decoder.layers.19.mixer.D', 'decoder.layers.19.mixer.in_proj.0.weight', 'decoder.layers.19.mixer.conv1d.weight', 'decoder.layers.19.mixer.conv1d.bias', 'decoder.layers.19.mixer.norm.weight', 'decoder.layers.19.mixer.out_proj.weight', 'decoder.layers.19.norm.weight', 'decoder.layers.20.mixer.dt_bias', 'decoder.layers.20.mixer.A_log', 'decoder.layers.20.mixer.D', 'decoder.layers.20.mixer.in_proj.0.weight', 'decoder.layers.20.mixer.conv1d.weight', 'decoder.layers.20.mixer.conv1d.bias', 'decoder.layers.20.mixer.norm.weight', 'decoder.layers.20.mixer.out_proj.weight', 'decoder.layers.20.norm.weight', 'decoder.layers.21.mixer.dt_bias', 'decoder.layers.21.mixer.A_log', 'decoder.layers.21.mixer.D', 'decoder.layers.21.mixer.in_proj.0.weight', 'decoder.layers.21.mixer.conv1d.weight', 'decoder.layers.21.mixer.conv1d.bias', 'decoder.layers.21.mixer.norm.weight', 'decoder.layers.21.mixer.out_proj.weight', 'decoder.layers.21.norm.weight', 'decoder.layers.22.mixer.dt_bias', 'decoder.layers.22.mixer.A_log', 'decoder.layers.22.mixer.D', 'decoder.layers.22.mixer.in_proj.0.weight', 'decoder.layers.22.mixer.conv1d.weight', 'decoder.layers.22.mixer.conv1d.bias', 'decoder.layers.22.mixer.norm.weight', 'decoder.layers.22.mixer.out_proj.weight', 'decoder.layers.22.norm.weight', 'decoder.layers.23.mixer.dt_bias', 'decoder.layers.23.mixer.A_log', 'decoder.layers.23.mixer.D', 'decoder.layers.23.mixer.in_proj.0.weight', 'decoder.layers.23.mixer.conv1d.weight', 'decoder.layers.23.mixer.conv1d.bias', 'decoder.layers.23.mixer.norm.weight', 
'decoder.layers.23.mixer.out_proj.weight', 'decoder.layers.23.norm.weight', 'decoder.layers.24.mixer.dt_bias', 'decoder.layers.24.mixer.A_log', 'decoder.layers.24.mixer.D', 'decoder.layers.24.mixer.in_proj.0.weight', 'decoder.layers.24.mixer.conv1d.weight', 'decoder.layers.24.mixer.conv1d.bias', 'decoder.layers.24.mixer.norm.weight', 'decoder.layers.24.mixer.out_proj.weight', 'decoder.layers.24.norm.weight', 'decoder.layers.25.mixer.dt_bias', 'decoder.layers.25.mixer.A_log', 'decoder.layers.25.mixer.D', 'decoder.layers.25.mixer.in_proj.0.weight', 'decoder.layers.25.mixer.conv1d.weight', 'decoder.layers.25.mixer.conv1d.bias', 'decoder.layers.25.mixer.norm.weight', 'decoder.layers.25.mixer.out_proj.weight', 'decoder.layers.25.norm.weight', 'decoder.layers.26.mixer.dt_bias', 'decoder.layers.26.mixer.A_log', 'decoder.layers.26.mixer.D', 'decoder.layers.26.mixer.in_proj.0.weight', 'decoder.layers.26.mixer.conv1d.weight', 'decoder.layers.26.mixer.conv1d.bias', 'decoder.layers.26.mixer.norm.weight', 'decoder.layers.26.mixer.out_proj.weight', 'decoder.layers.26.norm.weight', 'decoder.layers.27.mixer.dt_bias', 'decoder.layers.27.mixer.A_log', 'decoder.layers.27.mixer.D', 'decoder.layers.27.mixer.in_proj.0.weight', 'decoder.layers.27.mixer.conv1d.weight', 'decoder.layers.27.mixer.conv1d.bias', 'decoder.layers.27.mixer.norm.weight', 'decoder.layers.27.mixer.out_proj.weight', 'decoder.layers.27.norm.weight', 'decoder.layers.28.mixer.dt_bias', 'decoder.layers.28.mixer.A_log', 'decoder.layers.28.mixer.D', 'decoder.layers.28.mixer.in_proj.0.weight', 'decoder.layers.28.mixer.conv1d.weight', 'decoder.layers.28.mixer.conv1d.bias', 'decoder.layers.28.mixer.norm.weight', 'decoder.layers.28.mixer.out_proj.weight', 'decoder.layers.28.norm.weight', 'decoder.layers.29.mixer.dt_bias', 'decoder.layers.29.mixer.A_log', 'decoder.layers.29.mixer.D', 'decoder.layers.29.mixer.in_proj.0.weight', 'decoder.layers.29.mixer.conv1d.weight', 'decoder.layers.29.mixer.conv1d.bias', 'decoder.layers.29.mixer.norm.weight', 'decoder.layers.29.mixer.out_proj.weight', 'decoder.layers.29.norm.weight', 'decoder.layers.30.mixer.dt_bias', 'decoder.layers.30.mixer.A_log', 'decoder.layers.30.mixer.D', 'decoder.layers.30.mixer.in_proj.0.weight', 'decoder.layers.30.mixer.conv1d.weight', 'decoder.layers.30.mixer.conv1d.bias', 'decoder.layers.30.mixer.norm.weight', 'decoder.layers.30.mixer.out_proj.weight', 'decoder.layers.30.norm.weight', 'decoder.layers.31.mixer.dt_bias', 'decoder.layers.31.mixer.A_log', 'decoder.layers.31.mixer.D', 'decoder.layers.31.mixer.in_proj.0.weight', 'decoder.layers.31.mixer.conv1d.weight', 'decoder.layers.31.mixer.conv1d.bias', 'decoder.layers.31.mixer.norm.weight', 'decoder.layers.31.mixer.out_proj.weight', 'decoder.layers.31.norm.weight', 'decoder.layers.32.mixer.dt_bias', 'decoder.layers.32.mixer.A_log', 'decoder.layers.32.mixer.D', 'decoder.layers.32.mixer.in_proj.0.weight', 'decoder.layers.32.mixer.conv1d.weight', 'decoder.layers.32.mixer.conv1d.bias', 'decoder.layers.32.mixer.norm.weight', 'decoder.layers.32.mixer.out_proj.weight', 'decoder.layers.32.norm.weight', 'decoder.layers.33.mixer.dt_bias', 'decoder.layers.33.mixer.A_log', 'decoder.layers.33.mixer.D', 'decoder.layers.33.mixer.in_proj.0.weight', 'decoder.layers.33.mixer.conv1d.weight', 'decoder.layers.33.mixer.conv1d.bias', 'decoder.layers.33.mixer.norm.weight', 'decoder.layers.33.mixer.out_proj.weight', 'decoder.layers.33.norm.weight', 'decoder.layers.34.mixer.dt_bias', 'decoder.layers.34.mixer.A_log', 'decoder.layers.34.mixer.D', 
'decoder.layers.34.mixer.in_proj.0.weight', 'decoder.layers.34.mixer.conv1d.weight', 'decoder.layers.34.mixer.conv1d.bias', 'decoder.layers.34.mixer.norm.weight', 'decoder.layers.34.mixer.out_proj.weight', 'decoder.layers.34.norm.weight', 'decoder.layers.35.mixer.dt_bias', 'decoder.layers.35.mixer.A_log', 'decoder.layers.35.mixer.D', 'decoder.layers.35.mixer.in_proj.0.weight', 'decoder.layers.35.mixer.conv1d.weight', 'decoder.layers.35.mixer.conv1d.bias', 'decoder.layers.35.mixer.norm.weight', 'decoder.layers.35.mixer.out_proj.weight', 'decoder.layers.35.norm.weight', 'decoder.layers.36.mixer.dt_bias', 'decoder.layers.36.mixer.A_log', 'decoder.layers.36.mixer.D', 'decoder.layers.36.mixer.in_proj.0.weight', 'decoder.layers.36.mixer.conv1d.weight', 'decoder.layers.36.mixer.conv1d.bias', 'decoder.layers.36.mixer.norm.weight', 'decoder.layers.36.mixer.out_proj.weight', 'decoder.layers.36.norm.weight', 'decoder.layers.37.mixer.dt_bias', 'decoder.layers.37.mixer.A_log', 'decoder.layers.37.mixer.D', 'decoder.layers.37.mixer.in_proj.0.weight', 'decoder.layers.37.mixer.conv1d.weight', 'decoder.layers.37.mixer.conv1d.bias', 'decoder.layers.37.mixer.norm.weight', 'decoder.layers.37.mixer.out_proj.weight', 'decoder.layers.37.norm.weight', 'decoder.layers.38.mixer.dt_bias', 'decoder.layers.38.mixer.A_log', 'decoder.layers.38.mixer.D', 'decoder.layers.38.mixer.in_proj.0.weight', 'decoder.layers.38.mixer.conv1d.weight', 'decoder.layers.38.mixer.conv1d.bias', 'decoder.layers.38.mixer.norm.weight', 'decoder.layers.38.mixer.out_proj.weight', 'decoder.layers.38.norm.weight', 'decoder.layers.39.mixer.dt_bias', 'decoder.layers.39.mixer.A_log', 'decoder.layers.39.mixer.D', 'decoder.layers.39.mixer.in_proj.0.weight', 'decoder.layers.39.mixer.conv1d.weight', 'decoder.layers.39.mixer.conv1d.bias', 'decoder.layers.39.mixer.norm.weight', 'decoder.layers.39.mixer.out_proj.weight', 'decoder.layers.39.norm.weight', 'decoder.layers.40.mixer.dt_bias', 'decoder.layers.40.mixer.A_log', 'decoder.layers.40.mixer.D', 'decoder.layers.40.mixer.in_proj.0.weight', 'decoder.layers.40.mixer.conv1d.weight', 'decoder.layers.40.mixer.conv1d.bias', 'decoder.layers.40.mixer.norm.weight', 'decoder.layers.40.mixer.out_proj.weight', 'decoder.layers.40.norm.weight', 'decoder.layers.41.mixer.dt_bias', 'decoder.layers.41.mixer.A_log', 'decoder.layers.41.mixer.D', 'decoder.layers.41.mixer.in_proj.0.weight', 'decoder.layers.41.mixer.conv1d.weight', 'decoder.layers.41.mixer.conv1d.bias', 'decoder.layers.41.mixer.norm.weight', 'decoder.layers.41.mixer.out_proj.weight', 'decoder.layers.41.norm.weight', 'decoder.layers.42.mixer.dt_bias', 'decoder.layers.42.mixer.A_log', 'decoder.layers.42.mixer.D', 'decoder.layers.42.mixer.in_proj.0.weight', 'decoder.layers.42.mixer.conv1d.weight', 'decoder.layers.42.mixer.conv1d.bias', 'decoder.layers.42.mixer.norm.weight', 'decoder.layers.42.mixer.out_proj.weight', 'decoder.layers.42.norm.weight', 'decoder.layers.43.mixer.dt_bias', 'decoder.layers.43.mixer.A_log', 'decoder.layers.43.mixer.D', 'decoder.layers.43.mixer.in_proj.0.weight', 'decoder.layers.43.mixer.conv1d.weight', 'decoder.layers.43.mixer.conv1d.bias', 'decoder.layers.43.mixer.norm.weight', 'decoder.layers.43.mixer.out_proj.weight', 'decoder.layers.43.norm.weight', 'decoder.layers.44.mixer.dt_bias', 'decoder.layers.44.mixer.A_log', 'decoder.layers.44.mixer.D', 'decoder.layers.44.mixer.in_proj.0.weight', 'decoder.layers.44.mixer.conv1d.weight', 'decoder.layers.44.mixer.conv1d.bias', 'decoder.layers.44.mixer.norm.weight', 
'decoder.layers.44.mixer.out_proj.weight', 'decoder.layers.44.norm.weight', 'decoder.layers.45.mixer.dt_bias', 'decoder.layers.45.mixer.A_log', 'decoder.layers.45.mixer.D', 'decoder.layers.45.mixer.in_proj.0.weight', 'decoder.layers.45.mixer.conv1d.weight', 'decoder.layers.45.mixer.conv1d.bias', 'decoder.layers.45.mixer.norm.weight', 'decoder.layers.45.mixer.out_proj.weight', 'decoder.layers.45.norm.weight', 'decoder.layers.46.mixer.dt_bias', 'decoder.layers.46.mixer.A_log', 'decoder.layers.46.mixer.D', 'decoder.layers.46.mixer.in_proj.0.weight', 'decoder.layers.46.mixer.conv1d.weight', 'decoder.layers.46.mixer.conv1d.bias', 'decoder.layers.46.mixer.norm.weight', 'decoder.layers.46.mixer.out_proj.weight', 'decoder.layers.46.norm.weight', 'decoder.layers.47.mixer.dt_bias', 'decoder.layers.47.mixer.A_log', 'decoder.layers.47.mixer.D', 'decoder.layers.47.mixer.in_proj.0.weight', 'decoder.layers.47.mixer.conv1d.weight', 'decoder.layers.47.mixer.conv1d.bias', 'decoder.layers.47.mixer.norm.weight', 'decoder.layers.47.mixer.out_proj.weight', 'decoder.layers.47.norm.weight', 'decoder.layers.48.mixer.dt_bias', 'decoder.layers.48.mixer.A_log', 'decoder.layers.48.mixer.D', 'decoder.layers.48.mixer.in_proj.0.weight', 'decoder.layers.48.mixer.conv1d.weight', 'decoder.layers.48.mixer.conv1d.bias', 'decoder.layers.48.mixer.norm.weight', 'decoder.layers.48.mixer.out_proj.weight', 'decoder.layers.48.norm.weight', 'decoder.layers.49.mixer.dt_bias', 'decoder.layers.49.mixer.A_log', 'decoder.layers.49.mixer.D', 'decoder.layers.49.mixer.in_proj.0.weight', 'decoder.layers.49.mixer.conv1d.weight', 'decoder.layers.49.mixer.conv1d.bias', 'decoder.layers.49.mixer.norm.weight', 'decoder.layers.49.mixer.out_proj.weight', 'decoder.layers.49.norm.weight', 'decoder.layers.50.mixer.dt_bias', 'decoder.layers.50.mixer.A_log', 'decoder.layers.50.mixer.D', 'decoder.layers.50.mixer.in_proj.0.weight', 'decoder.layers.50.mixer.conv1d.weight', 'decoder.layers.50.mixer.conv1d.bias', 'decoder.layers.50.mixer.norm.weight', 'decoder.layers.50.mixer.out_proj.weight', 'decoder.layers.50.norm.weight', 'decoder.layers.51.mixer.dt_bias', 'decoder.layers.51.mixer.A_log', 'decoder.layers.51.mixer.D', 'decoder.layers.51.mixer.in_proj.0.weight', 'decoder.layers.51.mixer.conv1d.weight', 'decoder.layers.51.mixer.conv1d.bias', 'decoder.layers.51.mixer.norm.weight', 'decoder.layers.51.mixer.out_proj.weight', 'decoder.layers.51.norm.weight', 'decoder.layers.52.mixer.dt_bias', 'decoder.layers.52.mixer.A_log', 'decoder.layers.52.mixer.D', 'decoder.layers.52.mixer.in_proj.0.weight', 'decoder.layers.52.mixer.conv1d.weight', 'decoder.layers.52.mixer.conv1d.bias', 'decoder.layers.52.mixer.norm.weight', 'decoder.layers.52.mixer.out_proj.weight', 'decoder.layers.52.norm.weight', 'decoder.layers.53.mixer.dt_bias', 'decoder.layers.53.mixer.A_log', 'decoder.layers.53.mixer.D', 'decoder.layers.53.mixer.in_proj.0.weight', 'decoder.layers.53.mixer.conv1d.weight', 'decoder.layers.53.mixer.conv1d.bias', 'decoder.layers.53.mixer.norm.weight', 'decoder.layers.53.mixer.out_proj.weight', 'decoder.layers.53.norm.weight', 'decoder.blocks.0.sa.mixer.linear_qkv.weight', 'decoder.blocks.0.sa.mixer.linear_proj.weight', 'decoder.blocks.0.sa.norm.weight', 'decoder.blocks.0.mlp.mixer.linear_fc1.weight', 'decoder.blocks.0.mlp.mixer.linear_fc2.weight', 'decoder.blocks.0.mlp.mixer.linear_fc1_lora_A_list.0.weight', 'decoder.blocks.0.mlp.mixer.linear_fc1_lora_A_list.1.weight', 'decoder.blocks.0.mlp.mixer.linear_fc1_lora_A_list.2.weight', 
'decoder.blocks.0.mlp.mixer.linear_fc1_lora_A_list.3.weight', 'decoder.blocks.0.mlp.mixer.linear_fc1_lora_A_list.4.weight', 'decoder.blocks.0.mlp.mixer.linear_fc1_lora_A_list.5.weight', 'decoder.blocks.0.mlp.mixer.linear_fc1_lora_A_list.6.weight', 'decoder.blocks.0.mlp.mixer.linear_fc1_lora_A_list.7.weight', 'decoder.blocks.0.mlp.mixer.linear_fc1_lora_A_list.8.weight', 'decoder.blocks.0.mlp.mixer.linear_fc1_lora_B_list.0.weight', 'decoder.blocks.0.mlp.mixer.linear_fc1_lora_B_list.1.weight', 'decoder.blocks.0.mlp.mixer.linear_fc1_lora_B_list.2.weight', 'decoder.blocks.0.mlp.mixer.linear_fc1_lora_B_list.3.weight', 'decoder.blocks.0.mlp.mixer.linear_fc1_lora_B_list.4.weight', 'decoder.blocks.0.mlp.mixer.linear_fc1_lora_B_list.5.weight', 'decoder.blocks.0.mlp.mixer.linear_fc1_lora_B_list.6.weight', 'decoder.blocks.0.mlp.mixer.linear_fc1_lora_B_list.7.weight', 'decoder.blocks.0.mlp.mixer.linear_fc1_lora_B_list.8.weight', 'decoder.blocks.0.mlp.norm.weight', 'decoder.blocks.1.sa.mixer.linear_qkv.weight', 'decoder.blocks.1.sa.mixer.linear_proj.weight', 'decoder.blocks.1.sa.norm.weight', 'decoder.blocks.1.mlp.mixer.linear_fc1.weight', 'decoder.blocks.1.mlp.mixer.linear_fc2.weight', 'decoder.blocks.1.mlp.mixer.linear_fc1_lora_A_list.0.weight', 'decoder.blocks.1.mlp.mixer.linear_fc1_lora_A_list.1.weight', 'decoder.blocks.1.mlp.mixer.linear_fc1_lora_A_list.2.weight', 'decoder.blocks.1.mlp.mixer.linear_fc1_lora_A_list.3.weight', 'decoder.blocks.1.mlp.mixer.linear_fc1_lora_A_list.4.weight', 'decoder.blocks.1.mlp.mixer.linear_fc1_lora_A_list.5.weight', 'decoder.blocks.1.mlp.mixer.linear_fc1_lora_A_list.6.weight', 'decoder.blocks.1.mlp.mixer.linear_fc1_lora_A_list.7.weight', 'decoder.blocks.1.mlp.mixer.linear_fc1_lora_A_list.8.weight', 'decoder.blocks.1.mlp.mixer.linear_fc1_lora_B_list.0.weight', 'decoder.blocks.1.mlp.mixer.linear_fc1_lora_B_list.1.weight', 'decoder.blocks.1.mlp.mixer.linear_fc1_lora_B_list.2.weight', 'decoder.blocks.1.mlp.mixer.linear_fc1_lora_B_list.3.weight', 'decoder.blocks.1.mlp.mixer.linear_fc1_lora_B_list.4.weight', 'decoder.blocks.1.mlp.mixer.linear_fc1_lora_B_list.5.weight', 'decoder.blocks.1.mlp.mixer.linear_fc1_lora_B_list.6.weight', 'decoder.blocks.1.mlp.mixer.linear_fc1_lora_B_list.7.weight', 'decoder.blocks.1.mlp.mixer.linear_fc1_lora_B_list.8.weight', 'decoder.blocks.1.mlp.norm.weight', 'decoder.block_map.6.weight', 'decoder.block_map.12.weight', 'decoder.block_map.18.weight', 'decoder.block_map.24.weight', 'decoder.block_map.30.weight', 'decoder.block_map.36.weight', 'decoder.block_map.42.weight', 'decoder.block_map.47.weight', 'decoder.block_map.51.weight', 'decoder.final_layernorm.weight'])\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Zarr-based strategies will not be registered because of missing packages\n", + "/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. 
If you want to force a new download, use `force_download=True`.\n",
+      "  warnings.warn(\n"
+     ]
+    }
+   ],
+   "source": [
+    "import os\n",
+    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"7\"\n",
+    "from mamba_model import MambaModel\n",
+    "from mamba_config import MambaConfig\n",
+    "import torch\n",
+    "import copy\n",
+    "\n",
+    "# checkpoint_name = \"/workspace/zamba2_tiny_test/iter_0000015/mp_rank_00/model_optim_rng.pt\"\n",
+    "checkpoint_name = \"/checkpoints/zamba_3B_attempt_2/iter_1235000/mp_rank_00/model_optim_rng.pt\"\n",
+    "state_dict = torch.load(checkpoint_name, map_location=\"cpu\")\n",
+    "# config_path = \"/workspace/zamba2_tiny_test/config.sh\"\n",
+    "config_path = \"/workspace/Mamba-MoE/examples/mamba_memory_block/zamba_3B_attempt3.sh\"\n",
+    "\n",
+    "def extract_keyword_args(filestr, keyword):\n",
+    "    gpt_split = filestr.split(keyword)\n",
+    "    if len(gpt_split) <= 1:\n",
+    "        raise ValueError(f\"Config provided does not contain a {keyword} variable\")\n",
+    "    arg_splits = gpt_split[1].split(\"\\\"\")\n",
+    "    gpt_args = arg_splits[1]\n",
+    "    gpt_args = gpt_args.replace(\"\\n\", \"\").replace(\"\\\\\", \"\").replace(\"\\t\", \"\")\n",
+    "    gpt_args = ' '.join(gpt_args.split())\n",
+    "    return gpt_args.strip().split(\" \")\n",
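+    "# For reference, extract_keyword_args expects the config to define a quoted\n",
+    "# block after the keyword, e.g. (hypothetical):\n",
+    "#   GPT_ARGS=\"--num-layers 54 --hidden-size 2560\"\n",
+    "# and it returns the flags as a flat list: ['--num-layers', '54', ...].\n",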
+    "    \n",
+    "with open(config_path, \"r\") as f:\n",
+    "    filestr = f.read()\n",
+    "config_args = extract_keyword_args(filestr, \"GPT_ARGS\")\n",
+    "print(config_args)\n",
+    "\n",
+    "config = MambaConfig(\n",
+    "    num_layers = 54,\n",
+    "    hidden_size = 2560,\n",
+    "    mamba_headdim = 64,\n",
+    "    mamba_ngroups = 1,\n",
+    "    state_size = 64,\n",
+    "    conv_dimension = 4,\n",
+    "    expansion_factor = 2,\n",
+    "    rms_norm = True,\n",
+    "    bias = False,\n",
+    "    use_mem_mlp = True,\n",
+    "    num_attention_heads = 32,\n",
+    "    num_mem_heads = 32,\n",
+    "    num_mem_blocks = 2,\n",
+    "    use_shared_block_lora = True,\n",
+    "    lora_rank = 128,\n",
+    "    vocab_size = 32000,\n",
+    "    layer_mapping = ['m', 'm', 'm', 'm', 'm', 'm', 'g', 'm', 'm', 'm', 'm', 'm', 'g', 'm', 'm', 'm', 'm', 'm', 'g', 'm', 'm', 'm', 'm', 'm', 'g', 'm', 'm', 'm', 'm', 'm', 'g', 'm', 'm', 'm', 'm', 'm', 'g', 'm', 'm', 'm', 'm', 'm', 'g', 'm', 'm', 'm', 'm', 'g', 'm', 'm', 'm', 'g', 'm', 'm']\n",
+    "    \n",
+    ")\n",
+    "#print(config[\"num_layers\"])\n",
+    "# #layer_mapping = [\"r\", \"r\", \"g\", \"r\", \"r\"]\n",
+    "model = MambaModel(config = config, max_sequence_length = 4096)\n",
+    "\n",
+    "checkpoint_state_dict = dict(copy.deepcopy(state_dict[\"model\"]))\n",
+    "model_state_dict = model.state_dict()\n",
+    "del checkpoint_state_dict[\"embedding.position_embeddings.weight\"]  # delete pos embeddings from the state dict\n",
+    "checkpoint_state_dict[\"embedding.weight\"] = checkpoint_state_dict[\"embedding.word_embeddings.weight\"].clone()\n",
+    "del checkpoint_state_dict[\"embedding.word_embeddings.weight\"]\n",
+    "# delete the extra state\n",
+    "for k in list(checkpoint_state_dict.keys()):\n",
+    "    if \"extra_state\" in k:\n",
+    "        del checkpoint_state_dict[k]\n",
+    "    if \"buffer_params\" in k:\n",
+    "        del checkpoint_state_dict[k]\n",
+    "# Leaving the block below commented out for now; it was meant to copy weights\n",
+    "# by positional key matching rather than by name.\n",
+    "# for (k_c, k_m) in zip(checkpoint_state_dict.keys(), model_state_dict.keys()):\n",
+    "#     model_state_dict[k_m].data = checkpoint_state_dict[k_c].clone().data\n",
+    "#     #print(checkpoint_state_dict[k_c].clone().data.dtype, checkpoint_state_dict[k_c].clone().data.device)\n",
+    "print(model_state_dict.keys())\n",
+    "model.load_state_dict(checkpoint_state_dict)\n",
+    "model = model.cuda()  #.float()\n",
+    "# Tokenizer\n",
+    "from megatron.tokenizer.tokenizer import _HFAutoTokenizer\n",
+    "from megatron.text_generation.tokenization import tokenize_prompts\n",
+    "tokenizer = _HFAutoTokenizer(\"mistralai/Mistral-7B-v0.1\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Generation"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Greedy generation without cache\n",
+    "prompt = 'Hello, how are you?'\n",
+    "tokens_to_generate = 10\n",
+    "\n",
+    "prompts_tokens = [tokenizer.tokenize(prompt)]\n",
+    "text_ids = torch.cuda.LongTensor(prompts_tokens).transpose(0, 1)\n",
+    "print(text_ids.shape)\n",
+    "model.eval()\n",
+    "with torch.no_grad():\n",
+    "    for _ in range(tokens_to_generate):\n",
+    "        out = model(text_ids)\n",
+    "        out_last = out[:, -1]\n",
+    "        id = torch.argmax(out_last)[None, None]\n",
+    "        text_ids = torch.cat((text_ids, id), dim=0)\n",
+    "print(text_ids.shape)\n",
+    "text_ids = text_ids.transpose(0, 1)[0]\n",
+    "tokenizer.detokenize(text_ids.cpu().numpy().tolist())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Adapted from megatron/text_generation/forward_step.py\n",
+    "# and megatron/core/inference_params.py\n",
+    "\n",
+    "from collections.abc import Iterable\n",
+    "\n",
+    "class InferenceParams:\n",
+    "    \"\"\"Inference parameters that are passed to the main model in order\n",
+    "    to efficiently calculate and store the context during inference.\"\"\"\n",
+    "\n",
+    "    def __init__(self, max_batch_size, max_sequence_length):\n",
+    "        self.max_sequence_length = max_sequence_length\n",
+    "        self.max_batch_size = max_batch_size\n",
+    "        self.sequence_len_offset = 0\n",
+    "        self.batch_size_offset = 0\n",
+    "        self.key_value_memory_dict = {}\n",
+    "        self.key_value_memory_dict_mamba = {}\n",
+    "\n",
+    "    def swap_key_value_dict(self, batch_idx):\n",
+    "        \"swap between batches\"\n",
+    "        if len(self.key_value_memory_dict) == 0:\n",
+    "            raise ValueError(\"should not swap when the dict is empty\")\n",
+    "\n",
+    "        for layer_number in self.key_value_memory_dict.keys():\n",
+    "            inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number]\n",
+    "            assert (\n",
+    "                len(batch_idx) == inference_key_memory.shape[1]\n",
+    "            )  # make sure batch size is the same\n",
+    "            new_inference_key_memory = inference_key_memory[:, batch_idx]\n",
+    "            new_inference_value_memory = inference_value_memory[:, batch_idx]\n",
+    "            self.key_value_memory_dict[layer_number] = (\n",
+    "                new_inference_key_memory,\n",
+    "                new_inference_value_memory,\n",
+    "            )\n",
+    "\n",
+    "    \n",
+    "\n",
+    "\n",
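+    "# NOTE: key_value_memory_dict holds per-layer attention key/value caches;\n",
+    "# key_value_memory_dict_mamba presumably holds the per-layer Mamba conv/SSM\n",
+    "# states. sequence_len_offset records how many tokens are already cached.\n",
+    "\n",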
+    "class ForwardStep:\n",
+    "    \"\"\"Forward step function.\n",
+    "    We use a class here to hide the inference parameters\n",
+    "    from the outside caller.\"\"\"\n",
+    "\n",
+    "    def __init__(self, model, max_batch_size, max_sequence_length):\n",
+    "        \"\"\"Set values so we don't need to do it multiple times.\"\"\"\n",
+    "        # Make sure model is in eval mode.\n",
+    "        assert not isinstance(model, Iterable), \\\n",
+    "            'interleaving schedule is not supported for inference'\n",
+    "        model.eval()\n",
+    "        self.model = model\n",
+    "        # Initialize inference parameters.\n",
+    "        self.inference_params = InferenceParams(max_batch_size,\n",
+    "                                                max_sequence_length)\n",
+    "\n",
+    "\n",
+    "    def __call__(self, tokens):\n",
+    "        \"\"\"Invocation of the forward methods. Note that self.inference_params\n",
+    "        is being modified by the forward step.\"\"\"\n",
+    "        # Run a simple forward pass.\n",
+    "        logits = self.model(tokens, inference_params=self.inference_params)\n",
+    "        \n",
+    "        # if self.inference_params.sequence_len_offset > 0:\n",
+    "        #     for k in range(1, 50):\n",
+    "        #         temp_list = list(self.inference_params.key_value_memory_dict_mamba[k])\n",
+    "        #         temp_list[0] = temp_list[0] * 0\n",
+    "        #         temp_list[1] = temp_list[1] * 0\n",
+    "        #         self.inference_params.key_value_memory_dict_mamba[k] = tuple(temp_list)\n",
+    "        \n",
+    "        # Update the sequence length offset.\n",
+    "        self.inference_params.sequence_len_offset += tokens.size(0)\n",
+    "\n",
+    "\n",
+    "        return logits"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "torch.Size([17, 1])\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "' Hello, how are you?\\nI am fine, thank you.\\nHow'"
+      ]
+     },
+     "execution_count": 31,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# greedy generation with cache\n",
+    "\n",
+    "prompt = 'Hello, how are you?'\n",
+    "tokens_to_generate = 10\n",
+    "batch_size = 1\n",
+    "\n",
+    "prompts_tokens = [tokenizer.tokenize(prompt)]\n",
+    "text_ids = torch.cuda.LongTensor(prompts_tokens).transpose(0, 1)\n",
+    "prompt_length = text_ids.shape[0]\n",
+    "max_sequence_length = prompt_length + tokens_to_generate\n",
+    "\n",
+    "# allocate tensor for full sequence\n",
+    "tokens = torch.zeros((max_sequence_length, batch_size), dtype=torch.int64, device=torch.cuda.current_device())\n",
+    "tokens[:prompt_length].copy_(text_ids)\n",
+    "\n",
+    "forward_step = ForwardStep(model, batch_size, max_sequence_length)\n",
+    "\n",
+    "model.eval()\n",
+    "with torch.no_grad():\n",
+    "    prev_context_length = 0\n",
+    "    for context_length in range(prompt_length, max_sequence_length):\n",
+    "        # Pick the slice that we need to pass through the network.\n",
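+    "        # Only the tokens not yet in the cache are fed to the model;\n",
+    "        # InferenceParams carries the state for the prefix, so each step\n",
+    "        # processes just the new tokens.\n",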
rope\n", + "prompt = 'This is a prompt that I am writing just to check that the output logits of the new rope implementation agree with those of the old implementation'\n", + "\n", + "prompts_tokens = [tokenizer.tokenize(prompt)]\n", + "text_ids = torch.cuda.LongTensor(prompts_tokens).transpose(0, 1)\n", + "print(text_ids.shape)\n", + "model.eval()\n", + "with torch.no_grad():\n", + " out = model(text_ids)\n", + "print(out.shape)\n", + "print(out[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_116442/2377630929.py:5: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at /opt/pytorch/pytorch/torch/csrc/tensor/python_tensor.cpp:83.)\n", + " text_ids = torch.cuda.LongTensor(prompts_tokens).transpose(0, 1)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([29, 1])\n", + "torch.Size([1, 29, 32000])\n", + "tensor([[-4.7043, 4.1115, -1.6878, ..., -4.9308, -4.6628, -3.5261],\n", + " [-7.4544, 1.9833, -5.7091, ..., -6.7184, -5.4518, -6.2321],\n", + " [-7.2265, 0.9190, -5.1137, ..., -5.7556, -4.5535, -5.0144],\n", + " ...,\n", + " [-6.3463, 2.1663, -5.2459, ..., -7.2303, -5.1301, -6.2166],\n", + " [-7.8079, 2.8875, -7.0855, ..., -8.5312, -6.3525, -6.8940],\n", + " [-8.7074, 5.8500, -5.4719, ..., -8.3026, -6.2768, -6.7983]],\n", + " device='cuda:0')\n" + ] + } + ], + "source": [ + "# Forward pass with new rope\n", + "prompt = 'This is a prompt that I am writing just to check that the output logits of the new rope implementation agree with those of the old implementation'\n", + "\n", + "prompts_tokens = [tokenizer.tokenize(prompt)]\n", + "text_ids = torch.cuda.LongTensor(prompts_tokens).transpose(0, 1)\n", + "print(text_ids.shape)\n", + "model.eval()\n", + "with torch.no_grad():\n", + " out = model(text_ids)\n", + "print(out.shape)\n", + "print(out[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Scratchwork" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 7]) 7\n" + ] + } + ], + "source": [ + "prompt = 'Hello, how are you?'\n", + "tokens_to_generate = 10\n", + "return_output_log_probs = False\n", + "top_k_sampling = 1\n", + "top_p_sampling=0.0,\n", + "top_p_decay=0.0,\n", + "top_p_bound=0.0,\n", + "temperature=1.0,\n", + "use_eod_token_for_early_termination=True,\n", + "stop_on_double_eol=False,\n", + "stop_on_eol=False,\n", + "prevent_newline_after_colon=False,\n", + "random_seed=-1,\n", + "return_logits=False\n", + "\n", + "# some lines below are adapted from Mamba-MoE/tasks/msdp/prompt.py and api.py\n", + "\n", + "\n", + "# Tokenize\n", + "context_tokens_tensor = [tokenizer.tokenize(prompt)]\n", + "context_tokens_tensor = torch.cuda.LongTensor(context_tokens_tensor)\n", + "context_length_tensor = context_tokens_tensor.shape[1]\n", + "print(context_tokens_tensor.shape, context_length_tensor)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "# simplified generation\n", + "\n", + "import torch.nn.functional as F\n", + "\n", + "def generate_tokens_probs_and_return_on_first_stage(\n", + " model, tokens, lengths,\n", + " return_output_log_probs=False,\n", + " top_k=0, 
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Scratchwork"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "torch.Size([1, 7]) 7\n"
+     ]
+    }
+   ],
+   "source": [
+    "prompt = 'Hello, how are you?'\n",
+    "tokens_to_generate = 10\n",
+    "return_output_log_probs = False\n",
+    "top_k_sampling = 1\n",
+    "top_p_sampling = 0.0\n",
+    "top_p_decay = 0.0\n",
+    "top_p_bound = 0.0\n",
+    "temperature = 1.0\n",
+    "use_eod_token_for_early_termination = True\n",
+    "stop_on_double_eol = False\n",
+    "stop_on_eol = False\n",
+    "prevent_newline_after_colon = False\n",
+    "random_seed = -1\n",
+    "return_logits = False\n",
+    "\n",
+    "# some lines below are adapted from Mamba-MoE/tasks/msdp/prompt.py and api.py\n",
+    "\n",
+    "\n",
+    "# Tokenize\n",
+    "context_tokens_tensor = [tokenizer.tokenize(prompt)]\n",
+    "context_tokens_tensor = torch.cuda.LongTensor(context_tokens_tensor)\n",
+    "context_length_tensor = context_tokens_tensor.shape[1]\n",
+    "print(context_tokens_tensor.shape, context_length_tensor)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# simplified generation\n",
+    "\n",
+    "import torch.nn.functional as F\n",
+    "\n",
+    "def generate_tokens_probs_and_return_on_first_stage(\n",
+    "        model, tokens, lengths,\n",
+    "        return_output_log_probs=False,\n",
+    "        top_k=0, top_p=0.0, top_p_decay=0.0, top_p_bound=0.0,\n",
+    "        temperature=1.0,\n",
+    "        use_eod_token_for_early_termination=True,\n",
+    "        stop_on_double_eol=False,\n",
+    "        stop_on_eol=False,\n",
+    "        prevent_newline_after_colon=True\n",
+    "        ):\n",
+    "    \"\"\"Main token generation function.\n",
+    "    Arguments:\n",
+    "        model: no interleaving is supported.\n",
+    "        tokens: prompt tokens extended to be of size [b, max-sequence-length]\n",
+    "        lengths: original prompt length, size: [b]\n",
+    "        return_output_log_probs: flag to calculate the log probability of\n",
+    "            the generated tokens. Note that the log probability is the one\n",
+    "            from the original logit.\n",
+    "        top_k, top_p: top-k and top-p sampling parameters.\n",
+    "            Note that top-k = 1 is greedy. Also, these parameters are\n",
+    "            exclusive, meaning that:\n",
+    "                if top-k > 0 then we expect top-p=0.\n",
+    "                if top-p > 0 then we check for top-k=0.\n",
+    "        temperature: sampling temperature.\n",
+    "        use_eod_token_for_early_termination: if True, do early termination if\n",
+    "            all the sequences have reached this token.\n",
+    "        prevent_newline_after_colon: if True, it will disable generating new line \\n after :\n",
+    "    Outputs: Note that the size is adjusted to a lower value than\n",
+    "        max-sequence-length if generation is terminated early.\n",
+    "        tokens: prompt and generated tokens. size: [b, :]\n",
+    "        generated_sequence_lengths: total length (including prompt) of\n",
+    "            the generated sequence. size: [b]\n",
+    "        output_log_probs: log probability of the selected tokens. size: [b, s]\n",
+    "    \"\"\"\n",
+    "\n",
+    "    batch_size = tokens.size(0)\n",
+    "    min_prompt_length = lengths\n",
+    "    max_sequence_length = tokens.size(1)\n",
+    "\n",
+    "    # if max_sequence_length > args.max_position_embeddings:\n",
+    "    #     raise ValueError(\"Length of prompt + tokens_to_generate longer than allowed\")\n",
+    "\n",
+    "    forward_step = ForwardStep(model, batch_size, max_sequence_length)\n",
+    "\n",
+    "    # =============\n",
+    "    # Run inference\n",
+    "    # =============\n",
+    "\n",
+    "    with torch.no_grad():\n",
+    "        prev_context_length = 0\n",
+    "        for context_length in range(min_prompt_length, max_sequence_length):\n",
+    "            # Pick the slice that we need to pass through the network.\n",
+    "            tokens2use = tokens[:, prev_context_length:context_length]\n",
+    "            # logits will be meaningful only in the last pipeline stage.\n",
+    "            logits = forward_step(tokens2use)\n",
+    "            # Sample.\n",
+    "            last_token_logits = logits[:, -1, :]\n",
+    "            new_sample = torch.argmax(last_token_logits)\n",
+    "            # TODO: `tokens` here is just the prompt, so it is shorter than\n",
+    "            # max_sequence_length; should we concatenate new tokens instead of\n",
+    "            # indexing, or pre-allocate the full tensor as in the cell above?\n",
+    "            tokens[0, context_length] = new_sample\n",
+    "            # Update the context length for the next token generation.\n",
+    "            prev_context_length = context_length\n",
+    "\n",
+    "    return tokens"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 28,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "ValueError",
+     "evalue": "not enough values to unpack (expected 4, got 1)",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
+      "Cell \u001b[0;32mIn[28], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m model\u001b[38;5;241m.\u001b[39meval()\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m----> 4\u001b[0m tokens, lengths, output_log_probs, logits \u001b[38;5;241m=\u001b[39m generate_tokens_probs_and_return_on_first_stage(\n\u001b[1;32m 5\u001b[0m model, context_tokens_tensor, context_length_tensor,\n\u001b[1;32m 6\u001b[0m return_output_log_probs\u001b[38;5;241m=\u001b[39mreturn_output_log_probs,\n\u001b[1;32m 7\u001b[0m top_k\u001b[38;5;241m=\u001b[39mtop_k_sampling,\n\u001b[1;32m 8\u001b[0m top_p\u001b[38;5;241m=\u001b[39mtop_p_sampling,\n\u001b[1;32m 9\u001b[0m top_p_decay\u001b[38;5;241m=\u001b[39mtop_p_decay,\n\u001b[1;32m 10\u001b[0m top_p_bound\u001b[38;5;241m=\u001b[39mtop_p_bound,\n\u001b[1;32m 11\u001b[0m temperature\u001b[38;5;241m=\u001b[39mtemperature,\n\u001b[1;32m 12\u001b[0m use_eod_token_for_early_termination\u001b[38;5;241m=\u001b[39muse_eod_token_for_early_termination,\n\u001b[1;32m 13\u001b[0m stop_on_double_eol\u001b[38;5;241m=\u001b[39mstop_on_double_eol,\n\u001b[1;32m 14\u001b[0m stop_on_eol\u001b[38;5;241m=\u001b[39mstop_on_eol,\n\u001b[1;32m 15\u001b[0m prevent_newline_after_colon\u001b[38;5;241m=\u001b[39mprevent_newline_after_colon)\n\u001b[1;32m 18\u001b[0m \u001b[38;5;66;03m# tokens, prompts_plus_generations, prompts_plus_generations_segments = \\\u001b[39;00m\n\u001b[1;32m 19\u001b[0m \u001b[38;5;66;03m# detokenize_generations(tokens, lengths, True)\u001b[39;00m\n\u001b[1;32m 20\u001b[0m prompts_plus_generations \u001b[38;5;241m=\u001b[39m tokenizer\u001b[38;5;241m.\u001b[39mdetokenize(tokens\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m1\u001b[39m)[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mcpu()\u001b[38;5;241m.\u001b[39mnumpy()\u001b[38;5;241m.\u001b[39mtolist())\n",
+      "\u001b[0;31mValueError\u001b[0m: not enough values to unpack (expected 4, got 1)"
+     ]
+    }
+   ],
+   "source": [
+    "# Generate\n",
+    "model.eval()\n",
+    "with torch.no_grad():\n",
+    "    tokens, lengths, output_log_probs, logits = generate_tokens_probs_and_return_on_first_stage(\n",
+    "        model, context_tokens_tensor, context_length_tensor,\n",
+    "        return_output_log_probs=return_output_log_probs,\n",
+    "        top_k=top_k_sampling,\n",
+    "        top_p=top_p_sampling,\n",
+    "        top_p_decay=top_p_decay,\n",
+    "        top_p_bound=top_p_bound,\n",
+    "        temperature=temperature,\n",
+    "        use_eod_token_for_early_termination=use_eod_token_for_early_termination,\n",
+    "        stop_on_double_eol=stop_on_double_eol,\n",
+    "        stop_on_eol=stop_on_eol,\n",
+    "        prevent_newline_after_colon=prevent_newline_after_colon)\n",
+    "\n",
+    "\n",
+    "# tokens, prompts_plus_generations, prompts_plus_generations_segments = \\\n",
+    "#     detokenize_generations(tokens, lengths, True)\n",
+    "prompts_plus_generations = tokenizer.detokenize(tokens.transpose(0, 1)[0].cpu().numpy().tolist())\n",
+    "print(prompts_plus_generations)"
+   ]
+  },
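+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The `ValueError` above is expected at this point: the name `generate_tokens_probs_and_return_on_first_stage` was rebound to the simplified version, which returns only `tokens`, while this call unpacks the four values returned by the full megatron-style implementation defined in the next cell."
+   ]
+  },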
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# main generation function taken from megatron/text_generation/generation.py\n",
+    "\n",
+    "from megatron.text_generation.sampling import sample\n",
+    "import torch.nn.functional as F\n",
+    "\n",
+    "def generate_tokens_probs_and_return_on_first_stage(\n",
+    "        model, tokens, lengths,\n",
+    "        return_output_log_probs=False,\n",
+    "        top_k=0, top_p=0.0, top_p_decay=0.0, top_p_bound=0.0,\n",
+    "        temperature=1.0,\n",
+    "        use_eod_token_for_early_termination=True,\n",
+    "        stop_on_double_eol=False,\n",
+    "        stop_on_eol=False,\n",
+    "        prevent_newline_after_colon=True\n",
+    "        ):\n",
+    "    \"\"\"Main token generation function.\n",
+    "    Arguments:\n",
+    "        model: no interleaving is supported.\n",
+    "        tokens: prompt tokens extended to be of size [b, max-sequence-length]\n",
+    "        lengths: original prompt length, size: [b]\n",
+    "        return_output_log_probs: flag to calculate the log probability of\n",
+    "            the generated tokens. Note that the log probability is the one\n",
+    "            from the original logit.\n",
+    "        top_k, top_p: top-k and top-p sampling parameters.\n",
+    "            Note that top-k = 1 is greedy. Also, these parameters are\n",
+    "            exclusive, meaning that:\n",
+    "                if top-k > 0 then we expect top-p=0.\n",
+    "                if top-p > 0 then we check for top-k=0.\n",
+    "        temperature: sampling temperature.\n",
+    "        use_eod_token_for_early_termination: if True, do early termination if\n",
+    "            all the sequences have reached this token.\n",
+    "        prevent_newline_after_colon: if True, it will disable generating new line \\n after :\n",
+    "    Outputs: Note that the size is adjusted to a lower value than\n",
+    "        max-sequence-length if generation is terminated early.\n",
+    "        tokens: prompt and generated tokens. size: [b, :]\n",
+    "        generated_sequence_lengths: total length (including prompt) of\n",
+    "            the generated sequence. size: [b]\n",
+    "        output_log_probs: log probability of the selected tokens. size: [b, s]\n",
+    "    \"\"\"\n",
+    "\n",
+    "    batch_size = tokens.size(0)\n",
+    "    min_prompt_length = lengths\n",
+    "    max_sequence_length = tokens.size(1)\n",
+    "\n",
+    "    # if max_sequence_length > args.max_position_embeddings:\n",
+    "    #     raise ValueError(\"Length of prompt + tokens_to_generate longer than allowed\")\n",
+    "\n",
+    "    forward_step = ForwardStep(model, batch_size, max_sequence_length)\n",
+    "\n",
+    "    # if hasattr(args, 'eos_id'):\n",
+    "    #     termination_id = args.eos_id\n",
+    "    # else:\n",
+    "    termination_id = tokenizer.eod\n",
+    "\n",
+    "    # ===================\n",
+    "    # Pre-allocate memory\n",
+    "    # ===================\n",
+    "\n",
+    "    # Log probability of the sequence (prompt + generated tokens).\n",
+    "    output_log_probs = None\n",
+    "    output_log_probs_size = (batch_size, max_sequence_length - 1)\n",
+    "    # Lengths of the generated sequence, including prompts.\n",
+    "    generated_sequence_lengths = None\n",
+    "\n",
+    "    if return_output_log_probs:\n",
+    "        output_log_probs = torch.empty(output_log_probs_size,\n",
+    "                                       dtype=torch.float32,\n",
+    "                                       device=torch.cuda.current_device())\n",
+    "    # Allocated unconditionally (as in the megatron source), since it is\n",
+    "    # also used below when return_output_log_probs is False.\n",
+    "    generated_sequence_lengths = torch.ones(\n",
+    "        batch_size, dtype=torch.int64,\n",
+    "        device=torch.cuda.current_device()) * max_sequence_length\n",
+    "    \n",
+    "    # Whether we have reached a termination id.\n",
+    "    is_generation_done = torch.zeros(batch_size, dtype=torch.uint8,\n",
+    "                                     device=torch.cuda.current_device())\n",
+    "\n",
+    "    # =============\n",
+    "    # Run inference\n",
+    "    # =============\n",
+    "\n",
+    "    with torch.no_grad():\n",
+    "        prev_context_length = 0\n",
+    "        if not lengths == min_prompt_length:\n",
+    "            print(\"We want all prompt lengths to be the same, otherwise it seems from the algo below that prompts longer than `min_prompt_length` will be cut.\")\n",
+    "        for context_length in range(min_prompt_length, max_sequence_length):\n",
+    "\n",
+    "            # Pick the slice that we need to pass through the network.\n",
+    "            tokens2use = tokens[:, prev_context_length:context_length]\n",
+    "\n",
+    "            # logits will be meaningful only in the last pipeline stage.\n",
+    "            logits = forward_step(tokens2use)\n",
+    "\n",
+    "            \n",
+    "            if prevent_newline_after_colon:\n",
+    "                logits[tokens2use[:, -1] == tokenizer.tokenize(':')[0], -1, tokenizer.tokenize('\\n')[0]] = -1e10  # disable \"\\n\" after \":\"\n",
+    "            assert logits is not None\n",
+    "\n",
+    "            # Sample.\n",
+    "            last_token_logits = logits[:, -1, :]\n",
+    "            new_sample = sample(last_token_logits,\n",
+    "                                top_k=top_k,\n",
+    "                                top_p=top_p,\n",
+    "                                temperature=temperature,\n",
+    "                                vocab_size=tokenizer.vocab_size)\n",
+    "            if top_p > 0.0 and top_p_decay > 0.0:\n",
+    "                top_p = top_p * top_p_decay\n",
+    "                if top_p_bound > 0.0:\n",
+    "                    top_p = max(top_p, top_p_bound)\n",
+    "\n",
+    "            # If a prompt length is smaller than or equal to the current context\n",
+    "            # length, it means we have started generating tokens.\n",
+    "            started = lengths <= context_length\n",
+    "            # Update the tokens.\n",
+    "            tokens[started, context_length] = new_sample[started]\n",
+    "\n",
+    "            # Calculate the log probabilities.\n",
+    "            if return_output_log_probs:\n",
+    "                log_probs = F.log_softmax(logits, dim=2)\n",
+    "                # Pick the tokens that we need to get the log\n",
+    "                # probabilities for. Note that the next input token is\n",
+    "                # the token which we selected in the current logits,\n",
+    "                # so shift by 1.\n",
+    "                indices = torch.unsqueeze(\n",
+    "                    tokens[\n",
+    "                        :,\n",
+    "                        (prev_context_length + 1):(context_length + 1)],\n",
+    "                    2)\n",
+    "                output_log_probs[:,\n",
+    "                                 prev_context_length:context_length] = \\\n",
+    "                    torch.gather(log_probs, 2, indices).squeeze(2)\n",
+    "\n",
+    "            # Update the context length for the next token generation.\n",
+    "            prev_context_length = context_length\n",
+    "\n",
+    "            # Check if all the sequences have hit the termination_id.\n",
+    "            done = None\n",
+    "            \n",
+    "            # TODO(rprenger): these stopping methods are tokenizer dependent;\n",
+    "            # tokenization should instead happen in the inference loop so stop sequences can be used.\n",
+    "            if stop_on_double_eol:\n",
+    "                hit_double_eol = (new_sample == 628).byte() & started.byte()\n",
+    "                hit_two_eols = (new_sample == 198).byte() & (tokens[:, context_length-1] == 198).byte() & started.byte()\n",
+    "                done_token = hit_double_eol | hit_two_eols\n",
+    "            elif stop_on_eol:\n",
+    "                hit_double_eol = (new_sample == 628).byte() & started.byte()\n",
+    "                hit_eol = (new_sample == 198).byte() & started.byte()\n",
+    "                done_token = hit_double_eol | hit_eol\n",
+    "            else: \n",
+    "                done_token = (new_sample == termination_id).byte() & \\\n",
+    "                    started.byte()\n",
+    "            \n",
+    "            just_finished = (done_token & ~is_generation_done).bool()\n",
+    "            generated_sequence_lengths[just_finished.view(-1)] = \\\n",
+    "                context_length + 1\n",
+    "            is_generation_done = is_generation_done | done_token\n",
+    "            done = torch.all(is_generation_done)\n",
+    "            if use_eod_token_for_early_termination and done:\n",
+    "                break\n",
+    "            \n",
+    "\n",
+    "    tokens = tokens[:, :(max_sequence_length + 1)]\n",
+    "\n",
+    "    if return_output_log_probs:\n",
+    "        output_log_probs = output_log_probs[:, :context_length]\n",
+    "\n",
+    "    return tokens, generated_sequence_lengths, output_log_probs, None"
+   ]
+  },
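+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For reference, a minimal sketch of what a megatron-style `sample` does, given the docstring contract above (top-k and top-p exclusive, top-k = 1 greedy). This is an assumption-laden illustration, not the actual `megatron/text_generation/sampling.py` implementation, which differs in details such as vocab masking."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch of top-k / top-p sampling; logits: [batch, vocab].\n",
+    "def sample_sketch(logits, top_k=0, top_p=0.0, temperature=1.0):\n",
+    "    if top_k == 1:  # greedy\n",
+    "        return torch.argmax(logits, dim=-1)\n",
+    "    logits = logits / temperature\n",
+    "    if top_k > 1:\n",
+    "        # Mask everything below the k-th largest logit.\n",
+    "        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]\n",
+    "        logits = logits.masked_fill(logits < kth, float('-inf'))\n",
+    "    elif top_p > 0.0:\n",
+    "        # Nucleus sampling: keep the smallest set of tokens whose cumulative\n",
+    "        # probability exceeds top_p (the top token is always kept).\n",
+    "        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)\n",
+    "        probs = F.softmax(sorted_logits, dim=-1)\n",
+    "        mask = probs.cumsum(dim=-1) - probs > top_p\n",
+    "        sorted_logits = sorted_logits.masked_fill(mask, float('-inf'))\n",
+    "        # Scatter the filtered logits back to their original positions.\n",
+    "        logits = torch.empty_like(logits).scatter_(-1, sorted_idx, sorted_logits)\n",
+    "    return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1).squeeze(-1)"
+   ]
+  },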
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}