diff --git a/WORKSPACE b/WORKSPACE
index 284da91005..7b29e56f17 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -94,10 +94,10 @@ http_archive(
 
 http_archive(
     name = "cpuinfo",
-    sha256 = "ea028ced757dbc3309518ae7038ed625b02d58190078a5801d30e7b28f8b9e9c",
-    strip_prefix = "cpuinfo-ca678952a9a8eaa6de112d154e8e104b22f9ab3f",
+    sha256 = "2bf2b62eb86e2d2eaf862d0b9683a6c467a4d69fb2f7f1dc47c799809148608f",
+    strip_prefix = "cpuinfo-fa1c679da8d19e1d87f20175ae1ec10995cd3dd3",
     urls = [
-        "https://github.com/pytorch/cpuinfo/archive/ca678952a9a8eaa6de112d154e8e104b22f9ab3f.zip"
+        "https://github.com/pytorch/cpuinfo/archive/fa1c679da8d19e1d87f20175ae1ec10995cd3dd3.zip",
     ],
 )
 
@@ -115,9 +115,9 @@ http_archive(
 http_archive(
     name = "XNNPACK",
     # `curl -L | shasum -a 256`
-    sha256 = "0e5d5c16686beff813e3946b26ca412f28acaf611228d20728ffb6479264fe19",
-    strip_prefix = "XNNPACK-9ddeb74f9f6866174d61888947e4aa9ffe963b1b",
-    url = "https://github.com/google/XNNPACK/archive/9ddeb74f9f6866174d61888947e4aa9ffe963b1b.zip",
+    sha256 = "08489dff917a8009bf2187995fc8e0a33a2207eef466e400302bbf3ef40e4811",
+    strip_prefix = "XNNPACK-3014fb625c73f3b1ce1f6d3e45f1e216f9cb7105",
+    url = "https://github.com/google/XNNPACK/archive/3014fb625c73f3b1ce1f6d3e45f1e216f9cb7105.zip",
 )
 
 # TODO: This is an indirect dependency. We should factor it out.
diff --git a/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/graph_builder.cc b/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/graph_builder.cc
index 6cb9721877..71f8737ebe 100644
--- a/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/graph_builder.cc
+++ b/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/graph_builder.cc
@@ -249,6 +249,29 @@ absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::IntermediateTensor(
   return t;
 }
 
+absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::ExpandDims(
+    std::shared_ptr<Tensor> input, Tensor::DimsType new_axes) {
+  Tensor::DimsType output_dims = input->dims;
+
+  // Compute output shape.
+  for (size_t dim_idx = 0; dim_idx < new_axes.size(); ++dim_idx) {
+    output_dims.insert(output_dims.begin() + new_axes[dim_idx], 1);
+  }
+
+  MP_ASSIGN_OR_RETURN(auto output, IntermediateTensor(std::move(output_dims),
+                                                      "expand_dims_output"));
+  build_steps_.push_back(
+      [input, output, new_axes](xnn_subgraph_t subgraph) -> absl::Status {
+        RET_CHECK_EQ(xnn_status_success,
+                     xnn_define_static_expand_dims(
+                         subgraph, new_axes.size(), new_axes.data(),
+                         input->tensor_id(subgraph),
+                         output->tensor_id(subgraph), /*flags=*/0));
+        return absl::OkStatus();
+      });
+  return output;
+}
+
 absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Reshape(
     std::shared_ptr<Tensor> input, Tensor::DimsType new_dims) {
   size_t output_axis_dynamic = new_dims.size();
@@ -837,31 +860,67 @@ absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::BatchMatMul(
   const auto& rhs_dim = weight->dims;
 
   // [B, N, T, S] . [B, N', H, S]
-  RET_CHECK_EQ(lhs_dim.size(), 4);
-  RET_CHECK_EQ(rhs_dim.size(), 4);
+  RET_CHECK_GE(lhs_dim.size(), 3);
+  RET_CHECK_GE(rhs_dim.size(), 3);
   uint32_t flags = 0;
-  const size_t N = std::max(lhs_dim[1], rhs_dim[1]);
-  const size_t T = lhs_dim[2];
+  const size_t N =
+      std::max(lhs_dim[lhs_dim.size() - 3], rhs_dim[rhs_dim.size() - 3]);
+  const size_t T = lhs_dim[lhs_dim.size() - 2];
   size_t H;
-  if (params.transpose) {
+  if (!params.transpose) {
     RET_CHECK_EQ(lhs_dim.back(), rhs_dim.back());
     flags = XNN_FLAG_TRANSPOSE_B;
-    H = rhs_dim[2];
+    H = rhs_dim[rhs_dim.size() - 2];
   } else {
     RET_CHECK_EQ(lhs_dim.back(), rhs_dim[rhs_dim.size() - 2]);
-    H = rhs_dim[3];
+    H = rhs_dim[rhs_dim.size() - 1];
   }
+  size_t batch_size = lhs_dim.size() == 3 ? 1 : lhs_dim[0];
 
   NewWeight(weight);
 
-  MP_ASSIGN_OR_RETURN(auto output, IntermediateTensor({lhs_dim[0], N, T, H},
-                                                      "batch_mat_mul_output"));
+  std::vector<size_t> dims(std::max(lhs_dim.size(), rhs_dim.size()));
+  dims[dims.size() - 1] = H;
+  dims[dims.size() - 2] = T;
+  dims[dims.size() - 3] = N;
+  if (dims.size() > 3) {
+    dims[0] = batch_size;
+  }
+  MP_ASSIGN_OR_RETURN(auto output,
+                      IntermediateTensor(dims, "batch_mat_mul_output"));
 
-  build_steps_.push_back([input, output, weight,
-                          flags](xnn_subgraph_t subgraph) -> absl::Status {
-    RET_CHECK_EQ(xnn_status_success, xnn_define_batch_matrix_multiply(
-                                         subgraph, input->tensor_id(subgraph),
-                                         weight->tensor_id(subgraph),
-                                         output->tensor_id(subgraph), flags));
+  std::shared_ptr<Tensor> qd_input;
+  bool use_dynamic_quantization = false;
+  if (runtime_configs_->use_dynamic_quantization.has_value()) {
+    use_dynamic_quantization =
+        runtime_configs_->use_dynamic_quantization.value();
+  } else if (weight->datatype == xnn_datatype_qcint8 ||
+             weight->datatype == xnn_datatype_qcint4) {
+    use_dynamic_quantization = true;
+  }
+  VLOG(3) << "use_dynamic_quantization: " << use_dynamic_quantization;
+  if (use_dynamic_quantization) {
+    MP_ASSIGN_OR_RETURN(
+        qd_input, IntermediateTensor({input->dims.begin(), input->dims.end()},
+                                     xnn_datatype_qdint8));
+  }
+  build_steps_.push_back([input, output, weight, flags,
+                          qd_input](xnn_subgraph_t subgraph) -> absl::Status {
+    if (qd_input) {
+      RET_CHECK_EQ(
+          xnn_status_success,
+          xnn_define_convert(subgraph, input->tensor_id(subgraph),
+                             qd_input->tensor_id(subgraph), /*flags=*/0));
+      RET_CHECK_EQ(
+          xnn_status_success,
+          xnn_define_batch_matrix_multiply(
+              subgraph, qd_input->tensor_id(subgraph),
+              weight->tensor_id(subgraph), output->tensor_id(subgraph), flags));
+    } else {
+      RET_CHECK_EQ(xnn_status_success,
+                   xnn_define_batch_matrix_multiply(
+                       subgraph, input->tensor_id(subgraph),
+                       weight->tensor_id(subgraph),
+                       output->tensor_id(subgraph), flags));
+    }
     return absl::OkStatus();
   });
 
@@ -978,7 +1037,7 @@ absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::QKVAttention(
     Tensor::DimsType reshape_hint) {
   RET_CHECK_EQ(query->dims.size(), 4);
   RET_CHECK_EQ(key_or_value->dims.size(), 4);
-  FullConnParams params{.transpose = true};
+  FullConnParams params{.transpose = false};
   return BatchMatMul(query, key_or_value, params);
 }
 
diff --git a/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/graph_builder.h b/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/graph_builder.h
index 14d34eb93e..efc60807f3 100644
--- a/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/graph_builder.h
+++ b/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/graph_builder.h
@@ -303,6 +303,9 @@ class XnnGraphBuilder {
       std::shared_ptr<Tensor> beta = nullptr);
 
  protected:
+  absl::StatusOr<std::shared_ptr<Tensor>> ExpandDims(
+      std::shared_ptr<Tensor> input, Tensor::DimsType new_axes);
+
   absl::StatusOr<std::shared_ptr<Tensor>> IntermediateTensor(
       Tensor::DimsType dims, absl::string_view tag = "");
   absl::StatusOr<std::shared_ptr<Tensor>> IntermediateTensor(
diff --git a/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/xnn_tensor.cc b/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/xnn_tensor.cc
index ef4906205a..e0233b9198 100644
--- a/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/xnn_tensor.cc
+++ b/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/xnn_tensor.cc
@@ -16,6 +16,7 @@
 
 #include
 
+#include <algorithm>
 #include
 #include
 #include
@@ -329,11 +330,14 @@ absl::Status Tensor::LoadFromBuffer(const void* buffer) {
 
 absl::Status Tensor::LoadFromVec(const std::vector<float>& data,
                                  bool exact_match) {
   AllocateBufferIfNeeded();
+  size_t load_size = data.size();
   if (exact_match) {
     RET_CHECK_EQ(ElementSize(num_elements), data.size() * sizeof(float));
+  } else {
+    load_size = std::min(data.size(), num_elements);
   }
-  memcpy(Data(), data.data(), data.size() * sizeof(float));
+  memcpy(Data(), data.data(), load_size * sizeof(float));
 
   return absl::OkStatus();
 }
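Notes on the changes above, with small standalone sketches. Helper names in the sketches are hypothetical and not part of the patch.

The new ExpandDims helper inserts a size-1 axis at each position listed in new_axes, applying the insertions in order, so each index is interpreted against the partially expanded shape. A minimal sketch of the shape computation it performs before delegating to xnn_define_static_expand_dims:

#include <cstddef>
#include <vector>

// Mirrors the output_dims loop in XnnGraphBuilder::ExpandDims.
std::vector<size_t> ExpandedShape(std::vector<size_t> dims,
                                  const std::vector<size_t>& new_axes) {
  for (size_t axis : new_axes) {
    dims.insert(dims.begin() + axis, /*value=*/1);
  }
  return dims;
}

// ExpandedShape({4, 7}, {0, 3}) returns {1, 4, 7, 1}.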
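BatchMatMul previously required rank-4 operands; it now accepts rank >= 3 and reads N, T, and H from the trailing three axes, with the leading batch axis defaulting to 1 for rank-3 inputs. Note also that the patch inverts how BatchMatMul interprets FullConnParams::transpose, and QKVAttention compensates by passing transpose = false so it stays on the XNN_FLAG_TRANSPOSE_B path. A sketch of the output-shape rule, assuming both operands have rank 3 or 4:

#include <algorithm>
#include <cstddef>
#include <vector>

// lhs is [(B,) N, T, S]; rhs is [(B,) N', H, S] when transpose is false
// (the XNN_FLAG_TRANSPOSE_B path) or [(B,) N', S, H] when it is true.
std::vector<size_t> BatchMatMulOutputShape(const std::vector<size_t>& lhs_dim,
                                           const std::vector<size_t>& rhs_dim,
                                           bool transpose) {
  const size_t N =
      std::max(lhs_dim[lhs_dim.size() - 3], rhs_dim[rhs_dim.size() - 3]);
  const size_t T = lhs_dim[lhs_dim.size() - 2];
  const size_t H = transpose ? rhs_dim[rhs_dim.size() - 1]
                             : rhs_dim[rhs_dim.size() - 2];
  std::vector<size_t> dims(std::max(lhs_dim.size(), rhs_dim.size()));
  dims[dims.size() - 1] = H;
  dims[dims.size() - 2] = T;
  dims[dims.size() - 3] = N;
  if (dims.size() > 3) dims[0] = lhs_dim.size() == 3 ? 1 : lhs_dim[0];
  return dims;
}

// BatchMatMulOutputShape({2, 8, 16, 64}, {2, 8, 128, 64}, false)
//   returns {2, 8, 16, 128}.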
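The dynamic-quantization path is new: when the weights are per-channel int8 or int4 (xnn_datatype_qcint8 / xnn_datatype_qcint4), or when the runtime config explicitly requests it, activations are first converted to a dynamically quantized qdint8 tensor via xnn_define_convert, and the batch matmul consumes that tensor instead. A sketch of just the selection rule (the helper is hypothetical; the enums are real XNNPACK datatypes from xnnpack.h):

#include <optional>

#include "xnnpack.h"  // xnn_datatype

bool UseDynamicQuantization(std::optional<bool> runtime_override,
                            xnn_datatype weight_type) {
  // An explicit runtime setting wins; otherwise per-channel quantized
  // weights opt in to the convert + batch-matmul path.
  if (runtime_override.has_value()) return *runtime_override;
  return weight_type == xnn_datatype_qcint8 ||
         weight_type == xnn_datatype_qcint4;
}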
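In Tensor::LoadFromVec, the non-exact-match path used to memcpy data.size() * sizeof(float) bytes unconditionally, which could overrun the tensor's buffer when the vector is longer than the tensor; it now clamps the copy to min(data.size(), num_elements). A standalone sketch of the new behavior (names hypothetical):

#include <algorithm>
#include <cstddef>
#include <cstring>
#include <vector>

void LoadClamped(float* dst, size_t dst_num_elements,
                 const std::vector<float>& data) {
  // Copy whichever count is smaller so neither buffer is overrun.
  const size_t load_size = std::min(data.size(), dst_num_elements);
  std::memcpy(dst, data.data(), load_size * sizeof(float));
}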