diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py
index 53b1a0e94e7..6a2ff80b728 100644
--- a/optimum/onnxruntime/modeling_seq2seq.py
+++ b/optimum/onnxruntime/modeling_seq2seq.py
@@ -571,6 +571,9 @@ class ORTModelForConditionalGeneration(ORTModel, ABC):
     # Used in from_transformers to export model to onnxORTEncoder
     base_model_prefix = "onnx_model"
 
+    _supports_cache_class = False
+    _supports_static_cache = False
+
     def __init__(
         self,
         encoder_session: ort.InferenceSession,
@@ -1142,9 +1145,6 @@ class ORTModelForSeq2SeqLM(ORTModelForConditionalGeneration, GenerationMixin):
     auto_model_class = AutoModelForSeq2SeqLM
     main_input_name = "input_ids"
 
-    _supports_cache_class = False
-    _supports_static_cache = False
-
     def __init__(
         self,
         encoder_session: ort.InferenceSession,