style fixes
tomeras91 committed Apr 16, 2024
1 parent a252fe0 · commit c9f094a
Showing 2 changed files with 9 additions and 11 deletions.
src/transformers/models/jamba/modeling_jamba.py (18 changes: 9 additions & 9 deletions)
@@ -20,7 +20,7 @@
 """ PyTorch Jamba model."""
 import inspect
 import math
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -29,7 +29,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache  # we need __iter__ and __len__ of pkv
+from ...cache_utils import DynamicCache  # we need __iter__ and __len__ of pkv
 from ...modeling_attn_mask_utils import (
     AttentionMaskConverter,
     _prepare_4d_causal_attention_mask,
@@ -257,11 +257,11 @@ def __init__(self, config, batch_size, dtype=torch.float16, device=None):
         self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
 
     def update(
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-        cache_kwargs: Optional[Dict[str, Any]] = None,
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Update the cache
         if self.key_cache[layer_idx].shape[-1] == 0:
@@ -1302,12 +1302,12 @@ def _init_weights(self, module):
             [What are position IDs?](../glossary#position-ids)
         past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
             A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the
-            self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see 
+            self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see
             `past_key_values` input) to speed up sequential decoding.
             Key and value cache tensors have shape `[batch_size, num_heads, seq_len, head_dim]`.
             Convolution and ssm states tensors have shape `[batch_size, d_inner, d_conv]` and
             `[batch_size, d_inner, d_state]` respectively.
-            See the `HybridMambaAttentionDynamicCache` class for more details. 
+            See the `HybridMambaAttentionDynamicCache` class for more details.
 
             If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
             don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
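For readers skimming the hunks above, here is a minimal, self-contained sketch of the update() contract and the `[batch_size, num_heads, seq_len, head_dim]` key/value cache layout that the docstring describes. The ToyLayerKVCache class, the dim=2 concatenation, and the example tensor sizes are illustrative assumptions for this sketch only, not the actual HybridMambaAttentionDynamicCache implementation.

# Sketch of the per-layer KV cache behaviour implied by the diff above.
# NOT the transformers code: it only illustrates how an attention-layer cache
# of shape [batch_size, num_heads, seq_len, head_dim] can grow along the
# sequence dimension, starting from the torch.tensor([[]] * batch_size)
# placeholders created in __init__ (so shape[-1] == 0 before the first update).
from typing import Any, Dict, List, Optional, Tuple

import torch


class ToyLayerKVCache:
    def __init__(self, num_layers: int, batch_size: int, device=None):
        # Empty placeholders of shape (batch_size, 0) for every layer.
        self.key_cache: List[torch.Tensor] = [
            torch.tensor([[]] * batch_size, device=device) for _ in range(num_layers)
        ]
        self.value_cache: List[torch.Tensor] = [
            torch.tensor([[]] * batch_size, device=device) for _ in range(num_layers)
        ]

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,  # unused in this sketch
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.key_cache[layer_idx].shape[-1] == 0:
            # First call for this layer: replace the empty placeholder outright.
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        else:
            # Later calls: append new tokens along the sequence dimension (dim=2).
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]


# Usage: prefill 5 tokens, then decode 1 more token for layer 0.
cache = ToyLayerKVCache(num_layers=2, batch_size=1)
k = torch.randn(1, 4, 5, 16)  # [batch_size, num_heads, seq_len, head_dim]
v = torch.randn(1, 4, 5, 16)
cache.update(k, v, layer_idx=0)
k_new, v_new = cache.update(torch.randn(1, 4, 1, 16), torch.randn(1, 4, 1, 16), layer_idx=0)
assert k_new.shape == (1, 4, 6, 16)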
tests/models/jamba/test_modeling_jamba.py (2 changes: 0 additions & 2 deletions)
@@ -46,8 +46,6 @@
 )
 from transformers.models.jamba.modeling_jamba import (
     HybridMambaAttentionDynamicCache,
-    JambaAttentionDecoderLayer,
-    JambaMambaDecoderLayer,
 )
 
 
