Skip to content

Commit

Permalink
Align GroupNorm's epsilon name with RFC.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 678309043
  • Loading branch information
ai-edge-bot authored and copybara-github committed Sep 24, 2024
1 parent c9973d2 commit 1a2df0c
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions ai_edge_torch/generative/layers/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def group_norm_with_hlfb(
name="odml.group_norm",
attr={
"num_groups": num_groups,
"eps": eps,
"epsilon": eps,
"reduction_axes": 3,
"channel_axis": 3,
},
Expand Down Expand Up @@ -226,7 +226,7 @@ def layer_norm_with_hlfb(
"""
builder = StableHLOCompositeBuilder(
name="odml.group_norm",
attr={"num_groups": 1, "eps": eps, "channel_axis": 1},
attr={"num_groups": 1, "epsilon": eps, "channel_axis": 1},
)
x, w, b = builder.mark_inputs(x, w, b)
if use_input_shape:
Expand Down

0 comments on commit 1a2df0c

Please sign in to comment.