diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 936dcf82e57945..fd51375c349075 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1346,10 +1346,7 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa ): model.config.problem_type = "single_label_classification" - if "past_key_values" in input_names_to_trace: - model.config.use_cache = True - else: - model.config.use_cache = False + model.config.use_cache = "past_key_values" in input_names_to_trace traced_model = symbolic_trace(model, input_names_to_trace)