diff --git a/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/graph_builder.cc b/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/graph_builder.cc index 71f8737ebe..384b1d9dcd 100644 --- a/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/graph_builder.cc +++ b/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/graph_builder.cc @@ -21,7 +21,6 @@ #include #include #include -#include #include #include #include @@ -644,10 +643,14 @@ absl::StatusOr> XnnGraphBuilder::RmsNorm( // div_out = input / rms MP_ASSIGN_OR_RETURN(auto div_out, ElementDiv(input, clamped_rms)); - // div_out * (1 + scale) = div_out + div_out * scale - MP_ASSIGN_OR_RETURN(auto normed_div_out, ElementMul(div_out, scale)); + if (scale) { + // div_out * (1 + scale) = div_out + div_out * scale + MP_ASSIGN_OR_RETURN(auto normed_div_out, ElementMul(div_out, scale)); - return ElementAdd(div_out, normed_div_out); + return ElementAdd(div_out, normed_div_out); + } else { + return div_out; + } } absl::StatusOr> XnnGraphBuilder::ElementAdd( diff --git a/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/graph_builder.h b/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/graph_builder.h index efc60807f3..a7f1e0950f 100644 --- a/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/graph_builder.h +++ b/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/graph_builder.h @@ -160,8 +160,11 @@ class XnnGraphBuilder { absl::StatusOr> Rms(std::shared_ptr input); + // Root Mean Square normalization + // out = input / rms(input) * (1 + scale) + // if scale is absent, scale is considered to be zero. absl::StatusOr> RmsNorm( - std::shared_ptr input, std::shared_ptr scale); + std::shared_ptr input, std::shared_ptr scale = nullptr); absl::StatusOr> Reshape(std::shared_ptr input, Tensor::DimsType new_dims);