Skip to content

Commit

Permalink
Reduce LayerNorm composite op to GroupNorm composite op.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 677874810
  • Loading branch information
ai-edge-bot authored and copybara-github committed Sep 23, 2024
1 parent 197fb9f commit 0233ab9
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions ai_edge_torch/generative/layers/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,16 @@ def group_norm_with_hlfb(
"""
x = torch.permute(x, (0, 2, 3, 1))

# TODO: b/366544750 - Change "reduction_axes" field as an array, rather than
# int32 when the bug is fixed.
builder = StableHLOCompositeBuilder(
name="odml.group_norm", attr={"num_groups": num_groups, "eps": eps}
name="odml.group_norm",
attr={
"num_groups": num_groups,
"eps": eps,
"reduction_axes": 3,
"channel_axis": 3,
},
)
x, w, b = builder.mark_inputs(x, w, b)
x = torch.permute(x, (0, 3, 1, 2))
Expand All @@ -206,7 +214,7 @@ def layer_norm_with_hlfb(
"""Layer Normalization with high-level function boundary enabled.
Args:
x (torch.Tensor): Input tensor for Layer Normalization.
x (torch.Tensor): Input tensor for Layer Normalization, with BCHW shape.
w (torch.Tensor): The weight tensor for the normalization.
b (torch.Tensor): The bias tensor for the normalization.
eps (float): A small float value to ensure numerical stability.
Expand All @@ -216,7 +224,10 @@ def layer_norm_with_hlfb(
Returns:
The output tensor of Layer Normalization.
"""
builder = StableHLOCompositeBuilder(name="odml.layer_norm", attr={"eps": eps})
builder = StableHLOCompositeBuilder(
name="odml.group_norm",
attr={"num_groups": 1, "eps": eps, "channel_axis": 1},
)
x, w, b = builder.mark_inputs(x, w, b)
if use_input_shape:
normalized_shape = x.shape
Expand Down

0 comments on commit 0233ab9

Please sign in to comment.