Remove 5D tensor reshape in attention layer implementation. (#57)
* Remove 5D tensor reshape in attention layer implementation.

* formatting
yichunk authored Jun 14, 2024
1 parent 5b31b82 commit 0318b3e
Showing 2 changed files with 15 additions and 11 deletions.
ai_edge_torch/generative/layers/attention.py (25 changes: 14 additions & 11 deletions)
@@ -189,20 +189,23 @@ def forward(
 
     # Assemble into a number of query groups to support MHA, MQA and GQA.
     q_per_kv = self.config.num_heads // self.config.num_query_groups
-    total_qkv = q_per_kv + 2  # Each group has >=1 queries, 1 key, and 1 value.
+    # Each group has >=1 queries, 1 key, and 1 value.
     if self.config.qkv_transpose_before_split:
-      qkv = qkv.view(
-          B, T, total_qkv, self.config.num_query_groups, self.head_dim
-      )  # (B, T, total_qkv, num_query_groups, head_dim)
-      qkv_axis = -3
+      qkv = qkv.view(B, T, -1, self.head_dim)
+      q, k, v = qkv.split(
+          (
+              q_per_kv * self.config.num_query_groups,
+              self.config.num_query_groups,
+              self.config.num_query_groups,
+          ),
+          dim=-2,
+      )
     else:
-      qkv = qkv.view(
-          B, T, self.config.num_query_groups, total_qkv, self.head_dim
-      )  # (B, T, num_query_groups, total_qkv, head_dim)
-      qkv_axis = -2
+      qkv = qkv.view(B, T, self.config.num_query_groups, -1)
+      q, k, v = qkv.split(
+          (q_per_kv * self.head_dim, self.head_dim, self.head_dim), dim=-1
+      )
 
-    # Split batched computation into three.
-    q, k, v = qkv.split((q_per_kv, 1, 1), dim=qkv_axis)
     q = q.reshape(B, T, -1, self.head_dim)
     k = k.reshape(B, T, -1, self.head_dim)
     v = v.reshape(B, T, -1, self.head_dim)
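To make the equivalence easier to see, here is a minimal standalone sketch (not part of the repository; the batch size, sequence length, and GQA configuration below are made up for illustration) that checks the new 4D view-and-split against the removed 5D-reshape path for the qkv_transpose_before_split branch:

import torch

# Toy dimensions chosen for illustration only.
B, T, head_dim = 2, 4, 8
num_heads, num_query_groups = 8, 2   # GQA: 4 query heads per kv group
q_per_kv = num_heads // num_query_groups
total_qkv = q_per_kv + 2             # >=1 queries, 1 key, and 1 value per group

qkv = torch.randn(B, T, total_qkv * num_query_groups * head_dim)

# Removed path: 5D view, split along the qkv axis, then flatten back to 4D.
old = qkv.view(B, T, total_qkv, num_query_groups, head_dim)
q0, k0, v0 = old.split((q_per_kv, 1, 1), dim=-3)
q0, k0, v0 = (t.reshape(B, T, -1, head_dim) for t in (q0, k0, v0))

# New path: stay in 4D and split along the fused (qkv x group) axis.
new = qkv.view(B, T, -1, head_dim)
q1, k1, v1 = new.split(
    (q_per_kv * num_query_groups, num_query_groups, num_query_groups), dim=-2
)

assert torch.equal(q0, q1) and torch.equal(k0, k1) and torch.equal(v0, v1)
print("4D split matches the removed 5D split")

The split sizes along the fused axis pick out the same contiguous blocks that the 5D view separated into a dedicated qkv dimension, so no 5D intermediate tensor is needed.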
ai_edge_torch/generative/layers/unet/blocks_2d.py (1 change: 1 addition & 0 deletions)
@@ -134,6 +134,7 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
     x = input_tensor.view(B, C, H * W)
     x = x.transpose(-1, -2)
     x = self.norm(x)
+    x = x.contiguous()  # Prevent BATCH_MATMUL op in converted tflite.
     x = self.attention(x)
     x = x.transpose(-1, -2)
     x = x.view(B, C, H, W)
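The added line sits between the transpose and the attention call. As a minimal standalone illustration (toy shapes, not from the repository): transpose returns a strided, non-contiguous view, and .contiguous() materializes it into a plain row-major copy; per the comment in the diff, this is what keeps a BATCH_MATMUL op out of the converted tflite model.

import torch

B, C, H, W = 1, 16, 8, 8      # toy shapes for illustration
x = torch.randn(B, C, H, W)
x = x.view(B, C, H * W)       # (B, C, H*W)
x = x.transpose(-1, -2)       # (B, H*W, C): a strided view, no data copy
print(x.is_contiguous())      # False
x = x.contiguous()            # copy into standard row-major layout
print(x.is_contiguous())      # True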
