Skip to content

Commit

Permalink
RmsNorm has an optional scale parameter.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 677183671
  • Loading branch information
alankelly authored and copybara-github committed Sep 21, 2024
1 parent 6082bf7 commit da988ae
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#include <cstdlib>
#include <fstream>
#include <functional>
#include <iostream>
#include <memory>
#include <numeric>
#include <optional>
Expand Down Expand Up @@ -644,10 +643,14 @@ absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::RmsNorm(
// div_out = input / rms
MP_ASSIGN_OR_RETURN(auto div_out, ElementDiv(input, clamped_rms));

// div_out * (1 + scale) = div_out + div_out * scale
MP_ASSIGN_OR_RETURN(auto normed_div_out, ElementMul(div_out, scale));
if (scale) {
// div_out * (1 + scale) = div_out + div_out * scale
MP_ASSIGN_OR_RETURN(auto normed_div_out, ElementMul(div_out, scale));

return ElementAdd(div_out, normed_div_out);
return ElementAdd(div_out, normed_div_out);
} else {
return div_out;
}
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::ElementAdd(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,11 @@ class XnnGraphBuilder {

absl::StatusOr<std::shared_ptr<Tensor>> Rms(std::shared_ptr<Tensor> input);

// Root Mean Square normalization
// out = input / rms(input) * (1 + scale)
// if scale is absent, scale is considered to be zero.
absl::StatusOr<std::shared_ptr<Tensor>> RmsNorm(
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> scale);
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> scale = nullptr);

absl::StatusOr<std::shared_ptr<Tensor>> Reshape(std::shared_ptr<Tensor> input,
Tensor::DimsType new_dims);
Expand Down

0 comments on commit da988ae

Please sign in to comment.