diff --git a/tests/models/jamba/test_modeling_jamba.py b/tests/models/jamba/test_modeling_jamba.py index f8e9fdb77b20fa..ffe859bb59d62e 100644 --- a/tests/models/jamba/test_modeling_jamba.py +++ b/tests/models/jamba/test_modeling_jamba.py @@ -483,7 +483,7 @@ def _prepare_model_kwargs(input_ids, attention_mask, signature): return model_kwargs for model_class in decoder_only_classes: - config, input_ids, attention_mask, _ = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() signature = inspect.signature(model.forward).parameters.keys()