diff --git a/ai_edge_torch/generative/layers/builder.py b/ai_edge_torch/generative/layers/builder.py index b55aa662..6cc73291 100644 --- a/ai_edge_torch/generative/layers/builder.py +++ b/ai_edge_torch/generative/layers/builder.py @@ -74,6 +74,7 @@ def build_norm(dim: int, config: cfg.NormalizationConfig): dim, eps=config.epsilon, zero_centered_gamma=config.zero_centered, + enable_hlfb=config.enable_hlfb, ) elif config.type == cfg.NormalizationType.LAYER_NORM: return normalization.LayerNorm(dim, config.epsilon, config.enable_hlfb) diff --git a/ai_edge_torch/generative/layers/normalization.py b/ai_edge_torch/generative/layers/normalization.py index e250e7be..275088c6 100644 --- a/ai_edge_torch/generative/layers/normalization.py +++ b/ai_edge_torch/generative/layers/normalization.py @@ -23,15 +23,24 @@ # Implementation for RMSNorm from: https://arxiv.org/abs/1910.07467 class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-6, zero_centered_gamma=False): + def __init__( + self, + dim: int, + eps: float = 1e-6, + zero_centered_gamma=False, + enable_hlfb: bool = False, + ): """Initialize the RMSNorm layer. Args: dim (int): dimension of the input tensor. eps (float): A small float value to ensure numerical stability (default: 1e-6). + zero_centered_gamma (bool): Whether or not gamma has an offset. + enable_hlfb (bool): use HLFB in the op. """ super().__init__() + self.enable_hlfb = enable_hlfb self.eps = eps self.weight = torch.nn.Parameter(torch.ones(dim)) self.zero_centered_gamma = zero_centered_gamma @@ -56,12 +65,20 @@ def forward(self, x): Returns: torch.Tensor: output tensor after applying RMSNorm. """ - output = self._norm(x.float()).type_as(x) if self.zero_centered_gamma: - return output * (1 + self.weight) + w = 1 + self.weight else: - return output * self.weight + w = self.weight + if self.enable_hlfb: + return rms_norm_with_hlfb( + x, + w, + self.eps, + ) + else: + output = self._norm(x.float()).type_as(x) + return output * w class GroupNorm(torch.nn.Module): @@ -194,6 +211,37 @@ def group_norm_with_hlfb( return y +def rms_norm_with_hlfb( + x: torch.Tensor, + w: torch.Tensor, + eps: float, +): + """RMS Normalization with high-level function boundary enabled. + + Args: + x (torch.Tensor): Input tensor for RMS Normalization, with BCHW shape. + w (torch.Tensor): The learned parameter tensor for normalization. + eps (float): A small float value to ensure numerical stability. + + Returns: + The output tensor of RMS Normalization. + """ + builder = StableHLOCompositeBuilder( + name="odml.rms_norm", attr={"epsilon": eps} + ) + + x, w = builder.mark_inputs(x, w) + + def _norm(x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + output = _norm(x.float()).type_as(x) + out = output * w + + out = builder.mark_outputs(out) + return out + + def layer_norm_with_hlfb( x: torch.Tensor, normalized_shape: list[int],