diff --git a/ai_edge_torch/generative/layers/attention.py b/ai_edge_torch/generative/layers/attention.py index 2456cb7a..4c434d48 100644 --- a/ai_edge_torch/generative/layers/attention.py +++ b/ai_edge_torch/generative/layers/attention.py @@ -189,20 +189,23 @@ def forward( # Assemble into a number of query groups to support MHA, MQA and GQA. q_per_kv = self.config.num_heads // self.config.num_query_groups - total_qkv = q_per_kv + 2 # Each group has >=1 queries, 1 key, and 1 value. + # Each group has >=1 queries, 1 key, and 1 value. if self.config.qkv_transpose_before_split: - qkv = qkv.view( - B, T, total_qkv, self.config.num_query_groups, self.head_dim - ) # (B, T, total_qkv, num_query_groups, head_dim) - qkv_axis = -3 + qkv = qkv.view(B, T, -1, self.head_dim) + q, k, v = qkv.split( + ( + q_per_kv * self.config.num_query_groups, + self.config.num_query_groups, + self.config.num_query_groups, + ), + dim=-2, + ) else: - qkv = qkv.view( - B, T, self.config.num_query_groups, total_qkv, self.head_dim - ) # (B, T, num_query_groups, total_qkv, head_dim) - qkv_axis = -2 + qkv = qkv.view(B, T, self.config.num_query_groups, -1) + q, k, v = qkv.split( + (q_per_kv * self.head_dim, self.head_dim, self.head_dim), dim=-1 + ) - # Split batched computation into three. - q, k, v = qkv.split((q_per_kv, 1, 1), dim=qkv_axis) q = q.reshape(B, T, -1, self.head_dim) k = k.reshape(B, T, -1, self.head_dim) v = v.reshape(B, T, -1, self.head_dim) diff --git a/ai_edge_torch/generative/layers/unet/blocks_2d.py b/ai_edge_torch/generative/layers/unet/blocks_2d.py index ff038e11..863c9064 100644 --- a/ai_edge_torch/generative/layers/unet/blocks_2d.py +++ b/ai_edge_torch/generative/layers/unet/blocks_2d.py @@ -134,6 +134,7 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: x = input_tensor.view(B, C, H * W) x = x.transpose(-1, -2) x = self.norm(x) + x = x.contiguous() # Prevent BATCH_MATMUL op in converted tflite. x = self.attention(x) x = x.transpose(-1, -2) x = x.view(B, C, H, W)