diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 16b7c644d3633a..8da049f8518eba 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -1801,7 +1801,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None,