Skip to content

Commit

Permalink
Add HLFB to RMS Norm implementation
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 699245464
  • Loading branch information
talumbau authored and copybara-github committed Nov 22, 2024
1 parent b488a72 commit e3d0d3a
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 4 deletions.
1 change: 1 addition & 0 deletions ai_edge_torch/generative/layers/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
dim,
eps=config.epsilon,
zero_centered_gamma=config.zero_centered,
enable_hlfb=config.enable_hlfb,
)
elif config.type == cfg.NormalizationType.LAYER_NORM:
return normalization.LayerNorm(dim, config.epsilon, config.enable_hlfb)
Expand Down
56 changes: 52 additions & 4 deletions ai_edge_torch/generative/layers/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,24 @@
# Implementation for RMSNorm from: https://arxiv.org/abs/1910.07467
class RMSNorm(torch.nn.Module):

def __init__(self, dim: int, eps: float = 1e-6, zero_centered_gamma=False):
def __init__(
self,
dim: int,
eps: float = 1e-6,
zero_centered_gamma=False,
enable_hlfb: bool = False,
):
"""Initialize the RMSNorm layer.
Args:
dim (int): dimension of the input tensor.
eps (float): A small float value to ensure numerical stability (default:
1e-6).
zero_centered_gamma (bool): Whether or not gamma has an offset.
enable_hlfb (bool): use HLFB in the op.
"""
super().__init__()
self.enable_hlfb = enable_hlfb
self.eps = eps
self.weight = torch.nn.Parameter(torch.ones(dim))
self.zero_centered_gamma = zero_centered_gamma
Expand All @@ -56,12 +65,20 @@ def forward(self, x):
Returns:
torch.Tensor: output tensor after applying RMSNorm.
"""
output = self._norm(x.float()).type_as(x)
if self.zero_centered_gamma:
return output * (1 + self.weight)
w = 1 + self.weight
else:
return output * self.weight
w = self.weight

if self.enable_hlfb:
return rms_norm_with_hlfb(
x,
w,
self.eps,
)
else:
output = self._norm(x.float()).type_as(x)
return output * w

class GroupNorm(torch.nn.Module):

Expand Down Expand Up @@ -194,6 +211,37 @@ def group_norm_with_hlfb(
return y


def rms_norm_with_hlfb(
    x: torch.Tensor,
    w: torch.Tensor,
    eps: float,
):
  """Apply RMS normalization inside a high-level function boundary (HLFB).

  The computation is wrapped in a StableHLO composite named "odml.rms_norm",
  with `eps` recorded as the composite's "epsilon" attribute.

  Args:
    x (torch.Tensor): Input tensor. The root-mean-square is taken over its
      last dimension.
    w (torch.Tensor): Scale tensor multiplied into the normalized input.
    eps (float): A small float value added to the mean square for numerical
      stability.

  Returns:
    The output tensor of RMS Normalization.
  """
  composite = StableHLOCompositeBuilder(
      name="odml.rms_norm", attr={"epsilon": eps}
  )

  # Everything between mark_inputs and mark_outputs forms the composite body.
  x, w = composite.mark_inputs(x, w)

  # Normalize in float32, then cast back to the dtype of the incoming tensor
  # before the scale is applied.
  x_fp32 = x.float()
  mean_square = x_fp32.pow(2).mean(-1, keepdim=True)
  normalized = (x_fp32 * torch.rsqrt(mean_square + eps)).type_as(x)

  return composite.mark_outputs(normalized * w)


def layer_norm_with_hlfb(
x: torch.Tensor,
normalized_shape: list[int],
Expand Down

0 comments on commit e3d0d3a

Please sign in to comment.