From c9f094a77b5c487fbe800388bbdd0a463c472453 Mon Sep 17 00:00:00 2001
From: Tomer Asida
Date: Wed, 17 Apr 2024 00:12:14 +0300
Subject: [PATCH] style fixes

---
 .../models/jamba/modeling_jamba.py        | 18 +++++++++---------
 tests/models/jamba/test_modeling_jamba.py |  2 --
 2 files changed, 9 insertions(+), 11 deletions(-)

diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py
index 15dc05cd55a4e6..20f6820dc94d9b 100755
--- a/src/transformers/models/jamba/modeling_jamba.py
+++ b/src/transformers/models/jamba/modeling_jamba.py
@@ -20,7 +20,7 @@
 """ PyTorch Jamba model."""
 import inspect
 import math
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -29,7 +29,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache  # we need __iter__ and __len__ of pkv
+from ...cache_utils import DynamicCache  # we need __iter__ and __len__ of pkv
 from ...modeling_attn_mask_utils import (
     AttentionMaskConverter,
     _prepare_4d_causal_attention_mask,
@@ -257,11 +257,11 @@ def __init__(self, config, batch_size, dtype=torch.float16, device=None):
         self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
 
     def update(
-            self,
-            key_states: torch.Tensor,
-            value_states: torch.Tensor,
-            layer_idx: int,
-            cache_kwargs: Optional[Dict[str, Any]] = None,
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Update the cache
         if self.key_cache[layer_idx].shape[-1] == 0:
@@ -1302,12 +1302,12 @@ def _init_weights(self, module):
             [What are position IDs?](../glossary#position-ids)
         past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
             A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the
-            self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see 
+            self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see
             `past_key_values` input) to speed up sequential decoding.
             Key and value cache tensors have shape `[batch_size, num_heads, seq_len, head_dim]`.
             Convolution and ssm states tensors have shape `[batch_size, d_inner, d_conv]` and
             `[batch_size, d_inner, d_state]` respectively.
-            See the `HybridMambaAttentionDynamicCache` class for more details. 
+            See the `HybridMambaAttentionDynamicCache` class for more details.
 
             If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
             have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
diff --git a/tests/models/jamba/test_modeling_jamba.py b/tests/models/jamba/test_modeling_jamba.py
index 5f935bf3deefd4..2444fa5a594f4f 100644
--- a/tests/models/jamba/test_modeling_jamba.py
+++ b/tests/models/jamba/test_modeling_jamba.py
@@ -46,8 +46,6 @@
     )
     from transformers.models.jamba.modeling_jamba import (
         HybridMambaAttentionDynamicCache,
-        JambaAttentionDecoderLayer,
-        JambaMambaDecoderLayer,
     )