Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 702346276
  • Loading branch information
alankelly authored and copybara-github committed Dec 3, 2024
1 parent b746d22 commit c1433ba
Showing 1 changed file with 11 additions and 27 deletions.
38 changes: 11 additions & 27 deletions mediapipe/tasks/cc/genai/inference/utils/xnn_utils/graph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1187,33 +1187,17 @@ absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Clamp(

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Gelu(
std::shared_ptr<Tensor> input) {
// x^2
MP_ASSIGN_OR_RETURN(auto sqr_out, Square(input));

// 0.044715 * x^2
MP_ASSIGN_OR_RETURN(auto sqr_4471, ElementMul(sqr_out, 0.044715));

// 1 + 0.044715 * x^2
MP_ASSIGN_OR_RETURN(auto sqr_4471_1, ElementAdd(sqr_4471, 1.0f));

// x + 0.044715 * x^3
MP_ASSIGN_OR_RETURN(auto x_cube_4471, ElementMul(sqr_4471_1, input));

constexpr float sqrt_2_over_pi = 0.7978845608;
MP_ASSIGN_OR_RETURN(auto sqrt_2_over_pi_x_cube_4471,
ElementMul(x_cube_4471, sqrt_2_over_pi));

// tanh(x + 0.044715 * x^3)
MP_ASSIGN_OR_RETURN(auto tanh_x_cube_4471, Tanh(sqrt_2_over_pi_x_cube_4471));

// 1 + tanh(x + 0.044715 * x^3)
MP_ASSIGN_OR_RETURN(auto tanh_x_cube_4471_1,
ElementAdd(tanh_x_cube_4471, 1.0f));

// 0.5 * (1 + [tanh(x + 0.044715 * x^3)])
MP_ASSIGN_OR_RETURN(auto cdf, ElementMul(tanh_x_cube_4471_1, 0.5));

return ElementMul(input, cdf);
MP_ASSIGN_OR_RETURN(auto output,
IntermediateTensor(input->dims, "gelu_output"));
build_steps_.push_back(
[output, input](xnn_subgraph_t subgraph) -> absl::Status {
RET_CHECK_EQ(xnn_status_success,
xnn_define_gelu(subgraph, input->tensor_id(subgraph),
output->tensor_id(subgraph),
/*flags=*/0));
return absl::Status();
});
return output;
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Sigmoid(
Expand Down

0 comments on commit c1433ba

Please sign in to comment.