Skip to content

Commit

Permalink
Add new fusion argument from onnxruntime v1.16.2
Browse files (browse the repository at this point in the history)
  • Loading branch information
echarlaix committed Nov 13, 2023
1 parent 3916aa0 commit 41e9c20
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions optimum/onnxruntime/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,8 @@ class OptimizationConfig:
Do not fuse GroupNorm. Only works for model_type=unet.
disable_packed_kv (`bool`, defaults to `True`):
Do not use packed kv in cross attention. Only works for model_type=unet.
disable_rotary_embeddings (`bool`, defaults to `False`):
Whether to disable Rotary Embedding fusion.
"""

optimization_level: int = 1
Expand Down Expand Up @@ -752,6 +754,9 @@ class OptimizationConfig:
disable_group_norm_fusion: bool = True
disable_packed_kv: bool = True

# ONNX Runtime 1.16.2 arguments
disable_rotary_embeddings: bool = False

def __post_init__(self):
def deprecate_renamed_attribute(old_name, new_name, mapping_func=None):
if getattr(self, old_name, None) is not None:
Expand Down Expand Up @@ -801,6 +806,7 @@ class Box:
"use_raw_attention_mask": "use_raw_attention_mask",
"enable_gemm_fast_gelu_fusion": "enable_gemm_fast_gelu",
"use_multi_head_attention": "use_multi_head_attention",
"disable_rotary_embeddings": "disable_rotary_embeddings",
}
for attr_name, fusion_attr_name in attribute_map.items():
setattr(args, fusion_attr_name, getattr(self, attr_name))
Expand Down

0 comments on commit 41e9c20

Please sign in to comment.