diff --git a/optimum/onnxruntime/configuration.py b/optimum/onnxruntime/configuration.py index db0774cb27c..c11cf58b8b0 100644 --- a/optimum/onnxruntime/configuration.py +++ b/optimum/onnxruntime/configuration.py @@ -708,6 +708,8 @@ class OptimizationConfig: Do not fuse GroupNorm. Only works for model_type=unet. disable_packed_kv (`bool`, defaults to `True`): Do not use packed kv in cross attention. Only works for model_type=unet. + disable_rotary_embeddings (`bool`, defaults to `False`): + Whether to disable Rotary Embedding fusion. """ optimization_level: int = 1 @@ -752,6 +754,9 @@ class OptimizationConfig: disable_group_norm_fusion: bool = True disable_packed_kv: bool = True + # ONNX Runtime 1.16.2 arguments + disable_rotary_embeddings: bool = False + def __post_init__(self): def deprecate_renamed_attribute(old_name, new_name, mapping_func=None): if getattr(self, old_name, None) is not None: @@ -801,6 +806,7 @@ class Box: "use_raw_attention_mask": "use_raw_attention_mask", "enable_gemm_fast_gelu_fusion": "enable_gemm_fast_gelu", "use_multi_head_attention": "use_multi_head_attention", + "disable_rotary_embeddings": "disable_rotary_embeddings", } for attr_name, fusion_attr_name in attribute_map.items(): setattr(args, fusion_attr_name, getattr(self, attr_name))