diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 28e6720ee..fdc7ea86b 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -662,9 +662,9 @@ def __init__(self, module, config) -> None: if use_bias: concat_bias = torch.concat(bias_list, 0).contiguous() self.concat_linear.bias = nn.Parameter(concat_bias) - self.q_slice = self.q_proj.out_features - self.k_slice = self.q_slice + self.k_proj.out_features - self.v_slice = self.k_slice + self.v_proj.out_features + self.q_slice = self.q_proj.weight.shape[0] + self.k_slice = self.q_slice + self.k_proj.weight.shape[0] + self.v_slice = self.k_slice + self.v_proj.weight.shape[0] if self.module_device.type == "cpu": if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]: self.mha_linear_add = LinearAdd(module.o_proj)