diff --git a/optimum/exporters/ipex/cache_utils.py b/optimum/exporters/ipex/cache_utils.py
index dec1e8189..e1f6aa19b 100755
--- a/optimum/exporters/ipex/cache_utils.py
+++ b/optimum/exporters/ipex/cache_utils.py
@@ -44,7 +44,7 @@ def __init__(
         self.batch_size = batch_size
         # Used in `generate` to keep tally of how many tokens the cache has seen
         self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device)
-        self.block_size = 16
+        self.block_size = 64
         self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size
         self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
             batch_size, -1
diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py
index 7ece93a9d..28e6720ee 100755
--- a/optimum/exporters/ipex/modeling_utils.py
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -614,22 +614,6 @@ def forward(
         if past_len == 0:
             # prefill, remove padding
             seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
-            # varlen_attention(
-            #     query.contiguous() if query.device.type == "xpu" else query,
-            #     key.contiguous() if key.device.type == "xpu" else key,
-            #     value.contiguous() if value.device.type == "xpu" else value,
-            #     attn_output,
-            #     seq_len_tensor,
-            #     seq_len_tensor,
-            #     input_lens.max(),
-            #     input_lens.max(),
-            #     0.0,
-            #     1.0 / math.sqrt(self.head_dim),
-            #     False,
-            #     True,
-            #     False,
-            #     None,
-            # )
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
                 query,
@@ -734,9 +718,16 @@ class _IPEXGPT2Attention(_IPEXAttention):
     def __init__(self, module, config) -> None:
         self.num_key_value_heads = config.num_key_value_heads
         super().__init__(module, config)
+        _setattr_from_module(self, module)
+        self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1])
+        self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t())
+        self.c_attn_linear.bias = self.c_attn.bias
+        self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
+        self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
+        self.c_proj_linear.bias = self.c_proj.bias
 
     def qkv_gemm(self, hidden_states):
-        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=-1)
+        query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1)
         query = query.view(-1, self.num_heads, self.head_dim)
         key = key.view(-1, self.num_heads, self.head_dim)
         value = value.view(-1, self.num_heads, self.head_dim)
@@ -748,7 +739,6 @@ def rope(self, query, key, *args, **kwargs):
     def postprocess_attention_output(self, attn_output):
         attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1])
         attn_output = self.c_proj(attn_output)
-        attn_output = self.resid_dropout(attn_output)
         return attn_output
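
The `num_blocks` expression in `cache_utils.py` is a ceiling division: each sequence reserves enough fixed-size blocks to cover `max_cache_len`, and the pool is then multiplied out over the batch. A minimal sketch of the arithmetic, with illustrative values (2000/64/4 are arbitrary, not defaults from the PR):

```python
# Ceiling division via floor division plus a remainder check:
# (n // b) + (n % b != 0) == ceil(n / b), since the bool adds 0 or 1.
max_cache_len, block_size, batch_size = 2000, 64, 4

num_blocks_per_seq = max_cache_len // block_size + (max_cache_len % block_size != 0)
assert num_blocks_per_seq == 32  # ceil(2000 / 64) = 32, not 31

# The pool is sized so every sequence in the batch can fill its cache.
num_blocks = num_blocks_per_seq * batch_size  # 128 blocks total
```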
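
The new `c_attn_linear`/`c_proj_linear` layers exist because GPT-2's `Conv1D` stores its weight as `(in_features, out_features)`, i.e. transposed relative to `nn.Linear`'s `(out_features, in_features)`; transposing the weight once at init lets the hot path run through a plain linear layer instead. A minimal sketch of why the transpose preserves the output, using `transformers`' real `Conv1D` helper and gpt2-base shapes (768 in, 2304 out) as example values:

```python
import torch
from torch import nn
from transformers.pytorch_utils import Conv1D

# Conv1D(nf, nx): weight has shape (nx, nf) = (in, out), bias has shape (nf,).
conv = Conv1D(nf=2304, nx=768)

# Mirror the diff: build a Linear from the Conv1D's shapes and transpose the weight.
linear = nn.Linear(conv.weight.shape[0], conv.weight.shape[1])  # Linear(768, 2304)
linear.weight = nn.Parameter(conv.weight.t())  # (2304, 768), nn.Linear's layout
linear.bias = conv.bias

# Both modules compute x @ W + b, so the outputs should match.
x = torch.randn(4, 768)
assert torch.allclose(conv(x), linear(x), atol=1e-6)
```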