
Fix missing group norm in converted TF Lite models. (#102)
* Fix missing group norm in converted TF Lite models.

* update
yichunk authored Jul 24, 2024
1 parent 53fa236 commit b178852
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions ai_edge_torch/generative/layers/unet/blocks_2d.py
@@ -125,10 +125,9 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
     """
     residual = input_tensor
     B, C, H, W = input_tensor.shape
-    x = input_tensor
     if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
-      x = self.norm(x)
-      x = input_tensor.view(B, C, H * W)
+      x = self.norm(input_tensor)
+      x = x.view(B, C, H * W)
       x = x.transpose(-1, -2)
     else:
       x = input_tensor.view(B, C, H * W)
@@ -181,10 +180,9 @@ def forward(
     """
     residual = input_tensor
     B, C, H, W = input_tensor.shape
-    x = input_tensor
     if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
-      x = self.norm(x)
-      x = input_tensor.view(B, C, H * W)
+      x = self.norm(input_tensor)
+      x = x.view(B, C, H * W)
       x = x.transpose(-1, -2)
     else:
       x = input_tensor.view(B, C, H * W)
@@ -222,10 +220,9 @@ def __init__(
   def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
     residual = input_tensor
     B, C, H, W = input_tensor.shape
-    x = input_tensor
     if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
-      x = self.norm(x)
-      x = input_tensor.view(B, C, H * W)
+      x = self.norm(input_tensor)
+      x = x.view(B, C, H * W)
       x = x.transpose(-1, -2)
     else:
       x = input_tensor.view(B, C, H * W)
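Note on the fix: in all three blocks the bug was the same. The group-normalized tensor was computed and then immediately overwritten by a reshape of the raw input, so the normalization never reached the attention computation (and was dropped from the converted TF Lite graph). The sketch below is a minimal, self-contained illustration of the corrected pattern, not the actual ai_edge_torch AttentionBlock2D; the class name, channel count, and num_groups are illustrative assumptions, and the config dispatch and attention call are omitted.

# Minimal sketch of the fixed norm-then-reshape pattern (illustrative only).
import torch
import torch.nn as nn


class GroupNormThenFlatten(nn.Module):
  """Applies group norm in (B, C, H, W), then flattens to (B, H*W, C)."""

  def __init__(self, channels: int, num_groups: int = 4):
    super().__init__()
    self.norm = nn.GroupNorm(num_groups, channels)

  def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
    B, C, H, W = input_tensor.shape
    # Before the fix, the normalized tensor was discarded:
    #   x = self.norm(x)
    #   x = input_tensor.view(B, C, H * W)   # overwrites the norm output
    # After the fix, the reshape is applied to the norm's output.
    x = self.norm(input_tensor)
    x = x.view(B, C, H * W)
    x = x.transpose(-1, -2)  # (B, H*W, C), ready for attention
    return x


# Quick check that the normalized values flow through the reshape.
t = torch.randn(1, 8, 4, 4)
print(GroupNormThenFlatten(8)(t).shape)  # torch.Size([1, 16, 8])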
