Skip to content

Commit

Permalink
make fixup fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
pglorio committed Nov 11, 2024
1 parent 549d4cb commit 987bba9
Show file tree
Hide file tree
Showing 5 changed files with 238 additions and 151 deletions.
3 changes: 2 additions & 1 deletion docs/source/en/perf_infer_gpu_one.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip)
* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
* [Zamba2](https://huggingface.co/docs/transformers/model_doc/zamba2)

You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.

Expand Down Expand Up @@ -304,7 +305,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel)
* [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl#transformers.XLMRobertaXLModel)
* [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel)

* [Zamba2](https://huggingface.co/docs/transformers/model_doc/zamba2)

<Tip>

FlashAttention can only be used for models with the `fp16` or `bf16` torch type, so make sure to cast your model to the appropriate type first. The memory-efficient attention backend is able to handle `fp32` models.
Expand Down
59 changes: 22 additions & 37 deletions src/transformers/models/zamba2/configuration_zamba2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math


from ...configuration_utils import PretrainedConfig

Expand All @@ -30,19 +30,16 @@ class Zamba2Config(PretrainedConfig):
Zamba2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the Zamba2 model.
[Zyphra/Zamba2-2.7B](https://huggingface.co/Zyphra/Zamba2-2.7B)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the Zamba2 model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Zamba2Model`]
max_position_embeddings (`int`, *optional*, defaults to 4096):
The maximum sequence length that this model might ever be used with.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
            Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
            model has an output word embedding layer.
hidden_size (`int`, *optional*, defaults to 2560):
Dimension of the hidden representations.
num_hidden_layers (`int`, *optional*, defaults to 54):
Expand All @@ -52,7 +49,7 @@ class Zamba2Config(PretrainedConfig):
mamba_d_state (`int`, *optional*, defaults to 64): shape of the state space latents.
mamba_d_conv (`int`, *optional*, defaults to 4): Size of the convolution kernel.
mamba_expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size.
mamba_ngroups (`int`, *optional*, defaults to 8):
mamba_ngroups (`int`, *optional*, defaults to 1):
Number of groups for the evolution matrices of mamba 2.
time_step_min (`float`, *optional*, defaults to 0.001):
Minimum `time_step` used to bound `dt_proj.bias`.
Expand All @@ -62,16 +59,10 @@ class Zamba2Config(PretrainedConfig):
Minimum clamping value of the `dt_proj.bias` layer initialization.
time_step_limit (`tuple`, *optional*):
Accepted range of time step values.
mamba_dt_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
n_mamba_heads (`int`, *optional*, defaults to 1):
Number of heads for the evolution matrices of mamba 2.
use_conv_bias (`bool`, *optional*, defaults to `True`):
Whether or not to use bias in the convolution layer of the mixer block.
mamba_proj_bias (`bool`, *optional*, defaults to `False`):
Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block
hidden_mamba_act (`str`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) adjacent to the mamba conv.
chunk_size (`int`, *optional*, defaults to 256):
Size of the chunks that will comprise the sequence.
add_bias_linear (`bool`, *optional*, defaults to `False`):
Expand Down Expand Up @@ -101,11 +92,11 @@ class Zamba2Config(PretrainedConfig):
Rank of the LoRA in the shared MLP and shared attention layers.
use_mem_rope (`bool`, *optional*, defaults to `False`):
If True, includes RoPE in the shared attention layers.
rope_theta (`float`, *optional*, defaults to 10000.0):
rope_theta (`float`, *optional*, defaults to `10000.0`):
The base period of the RoPE embeddings.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-5):
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
Expand All @@ -122,6 +113,16 @@ class Zamba2Config(PretrainedConfig):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 2):
The id of the "end-of-sequence" token.
use_long_context (`bool`, *optional*, defaults to `False`):
Activates the context-extended version of Zamba by modifying RoPE.
```python
>>> from transformers import Zamba2Model, Zamba2Config
>>> # Initializing a Zamba2-2.7B style configuration
>>> configuration = Zamba2Config()
>>> # Initializing a model from the Zamba2-2.7B style configuration
>>> model = Zamba2Model(configuration)
>>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

model_type = "zamba2"
Expand All @@ -131,7 +132,6 @@ def __init__(
self,
vocab_size=32000,
max_position_embeddings=4096,
tie_word_embeddings=True,
hidden_size=2560,
num_hidden_layers=54,
layers_block_type=None,
Expand All @@ -143,10 +143,7 @@ def __init__(
time_step_max=0.1,
time_step_floor=1e-4,
time_step_limit=None,
mamba_dt_rank="auto",
n_mamba_heads=1,
mamba_proj_bias=False,
hidden_mamba_act="silu",
use_conv_bias=True,
chunk_size=256,
add_bias_linear=False,
Expand Down Expand Up @@ -175,13 +172,10 @@ def __init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.tie_word_embeddings = tie_word_embeddings
self.hidden_size = hidden_size
if intermediate_size is None:
self.intermediate_size = 4 * hidden_size
Expand All @@ -199,17 +193,13 @@ def __init__(
self.mamba_d_state = mamba_d_state
self.mamba_d_conv = mamba_d_conv
self.mamba_expand = mamba_expand
self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank
self.add_bias_linear = add_bias_linear
self.mamba_headdim = int(mamba_expand * hidden_size) // n_mamba_heads
self.mamba_ngroups = mamba_ngroups
self.n_mamba_heads = n_mamba_heads
self.mamba_proj_bias = mamba_proj_bias
self.hidden_mamba_act = hidden_mamba_act
self.mamba_headdim = int(mamba_expand * hidden_size) // n_mamba_heads
self.use_conv_bias = use_conv_bias
self.chunk_size = chunk_size
self.time_step_limit = time_step_limit

self.use_shared_mlp_lora = use_shared_mlp_lora
self.use_shared_attention_lora = use_shared_attention_lora
self.lora_rank = lora_rank
Expand All @@ -219,21 +209,12 @@ def __init__(
self.time_step_floor = time_step_floor
if use_long_context:
self.max_position_embeddings = 16384

# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads

self.num_attention_heads = num_attention_heads
self.kv_channels = self.hidden_size // self.num_attention_heads
self.num_query_groups = self.num_attention_heads
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps

self.use_cache = use_cache
self.num_logits_to_keep = num_logits_to_keep

# Below, "mamba" stands for mamba layer, "hybrid" stands for hybrid layer (composed by a shared transformer followed by mamba layer)
if layers_block_type is None:
self.layers_block_type = (
Expand All @@ -246,4 +227,8 @@ def __init__(
+ ["mamba"] * 2
)
else:
self.layers_block_type = layers_block_type
self.layers_block_type = layers_block_type
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.num_logits_to_keep = num_logits_to_keep
Loading

0 comments on commit 987bba9

Please sign in to comment.