optimize gpt2 by using linear instead of conv1D
Signed-off-by: jiqing-feng <[email protected]>
jiqing-feng committed Dec 12, 2024
1 parent 6d21075 commit b792875
Showing 2 changed files with 9 additions and 19 deletions.
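
Background for the change (not part of the commit): transformers' GPT-2 uses `Conv1D`, which stores its weight as (in_features, out_features) and computes `x @ W + b`, so an `nn.Linear` built from the transposed weight produces identical outputs while dispatching to the standard linear/GEMM kernels. A minimal standalone sketch of the equivalence that the new `__init__` code below relies on; the shapes mirror GPT-2 base's `c_attn` and are illustrative only:

# Sketch only (not in this commit): Conv1D -> nn.Linear conversion,
# mirroring what _IPEXGPT2Attention.__init__ does with c_attn below.
import torch
from torch import nn
from transformers.pytorch_utils import Conv1D

hidden, fused_qkv = 768, 3 * 768            # GPT-2 base: c_attn projects 768 -> 2304
conv = Conv1D(fused_qkv, hidden)            # weight stored as (768, 2304); forward is x @ W + b

linear = nn.Linear(hidden, fused_qkv)
linear.weight = nn.Parameter(conv.weight.t())   # nn.Linear expects (out_features, in_features)
linear.bias = conv.bias

x = torch.randn(4, hidden)
print(torch.allclose(conv(x), linear(x), atol=1e-5))   # True: same math, Linear-friendly layout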
2 changes: 1 addition & 1 deletion optimum/exporters/ipex/cache_utils.py
@@ -44,7 +44,7 @@ def __init__(
         self.batch_size = batch_size
         # Used in `generate` to keep tally of how many tokens the cache has seen
         self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device)
-        self.block_size = 16
+        self.block_size = 64
         self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size
         self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
             batch_size, -1
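
For context (not part of the diff): `num_blocks` above is a ceiling division of `max_cache_len` by `block_size`, repeated per batch entry, so quadrupling the block size cuts the number of block-table entries by four for the same cache capacity. A small illustrative calculation, with assumed `max_cache_len` and `batch_size` values:

# Illustrative only; max_cache_len=2048 and batch_size=4 are assumed values.
def num_blocks(max_cache_len: int, block_size: int, batch_size: int) -> int:
    # ceil(max_cache_len / block_size) blocks per sequence, one row per batch entry
    return (max_cache_len // block_size + (max_cache_len % block_size != 0)) * batch_size

print(num_blocks(2048, 16, 4))   # 512 blocks of 16 tokens (old value)
print(num_blocks(2048, 64, 4))   # 128 blocks of 64 tokens (new value), same total capacity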
26 changes: 8 additions & 18 deletions optimum/exporters/ipex/modeling_utils.py
@@ -614,22 +614,6 @@ def forward(
         if past_len == 0:
             # prefill, remove padding
             seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
-            # varlen_attention(
-            #     query.contiguous() if query.device.type == "xpu" else query,
-            #     key.contiguous() if key.device.type == "xpu" else key,
-            #     value.contiguous() if value.device.type == "xpu" else value,
-            #     attn_output,
-            #     seq_len_tensor,
-            #     seq_len_tensor,
-            #     input_lens.max(),
-            #     input_lens.max(),
-            #     0.0,
-            #     1.0 / math.sqrt(self.head_dim),
-            #     False,
-            #     True,
-            #     False,
-            #     None,
-            # )
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
                 query,
@@ -734,9 +718,16 @@ class _IPEXGPT2Attention(_IPEXAttention):
     def __init__(self, module, config) -> None:
         self.num_key_value_heads = config.num_key_value_heads
         super().__init__(module, config)
+        _setattr_from_module(self, module)
+        self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1])
+        self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t())
+        self.c_attn_linear.bias = self.c_attn.bias
+        self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
+        self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
+        self.c_proj_linear.bias = self.c_proj.bias

     def qkv_gemm(self, hidden_states):
-        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=-1)
+        query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1)
         query = query.view(-1, self.num_heads, self.head_dim)
         key = key.view(-1, self.num_heads, self.head_dim)
         value = value.view(-1, self.num_heads, self.head_dim)
@@ -748,7 +739,6 @@ def rope(self, query, key, *args, **kwargs):
     def postprocess_attention_output(self, attn_output):
         attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1])
         attn_output = self.c_proj(attn_output)
-        attn_output = self.resid_dropout(attn_output)
         return attn_output
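
As a side note on the `qkv_gemm` hunk above (not part of the commit): the fused `c_attn_linear` output has width `3 * hidden_size`, and GPT-2's `split_size` equals `hidden_size`, so the `.split(...)` call yields query, key and value before they are reshaped per head. A standalone shape sketch, using GPT-2 base dimensions chosen here for illustration:

# Shape sketch only; 12 heads of 64 dims and 5 tokens are illustrative values.
import torch

num_heads, head_dim = 12, 64
hidden = num_heads * head_dim                    # 768
tokens = 5                                       # flattened batch * sequence length

qkv = torch.randn(tokens, 3 * hidden)            # stands in for c_attn_linear(hidden_states)
query, key, value = qkv.split(hidden, dim=-1)    # split_size == hidden_size
query = query.view(-1, num_heads, head_dim)      # (5, 12, 64)
key = key.view(-1, num_heads, head_dim)
value = value.view(-1, num_heads, head_dim)
print(query.shape, key.shape, value.shape)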


