BatchMatMul can handle cases where ndims != 4 and quantization
PiperOrigin-RevId: 677146942
alankelly authored and copybara-github committed Sep 21, 2024
1 parent 7090099 commit 6082bf7
Showing 4 changed files with 89 additions and 23 deletions.
12 changes: 6 additions & 6 deletions WORKSPACE
@@ -94,10 +94,10 @@ http_archive(
 
 http_archive(
     name = "cpuinfo",
-    sha256 = "ea028ced757dbc3309518ae7038ed625b02d58190078a5801d30e7b28f8b9e9c",
-    strip_prefix = "cpuinfo-ca678952a9a8eaa6de112d154e8e104b22f9ab3f",
+    sha256 = "2bf2b62eb86e2d2eaf862d0b9683a6c467a4d69fb2f7f1dc47c799809148608f",
+    strip_prefix = "cpuinfo-fa1c679da8d19e1d87f20175ae1ec10995cd3dd3",
     urls = [
-        "https://github.com/pytorch/cpuinfo/archive/ca678952a9a8eaa6de112d154e8e104b22f9ab3f.zip"
+        "https://github.com/pytorch/cpuinfo/archive/fa1c679da8d19e1d87f20175ae1ec10995cd3dd3.zip",
     ],
 )
 
@@ -115,9 +115,9 @@ http_archive(
 http_archive(
     name = "XNNPACK",
     # `curl -L <url> | shasum -a 256`
-    sha256 = "0e5d5c16686beff813e3946b26ca412f28acaf611228d20728ffb6479264fe19",
-    strip_prefix = "XNNPACK-9ddeb74f9f6866174d61888947e4aa9ffe963b1b",
-    url = "https://github.com/google/XNNPACK/archive/9ddeb74f9f6866174d61888947e4aa9ffe963b1b.zip",
+    sha256 = "08489dff917a8009bf2187995fc8e0a33a2207eef466e400302bbf3ef40e4811",
+    strip_prefix = "XNNPACK-3014fb625c73f3b1ce1f6d3e45f1e216f9cb7105",
+    url = "https://github.com/google/XNNPACK/archive/3014fb625c73f3b1ce1f6d3e45f1e216f9cb7105.zip",
 )
 
 # TODO: This is an indirect dependency. We should factor it out.
91 changes: 75 additions & 16 deletions mediapipe/tasks/cc/genai/inference/utils/xnn_utils/graph_builder.cc
@@ -249,6 +249,29 @@ absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::IntermediateTensor(
   return t;
 }
 
+absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::ExpandDims(
+    std::shared_ptr<Tensor> input, Tensor::DimsType new_axes) {
+  Tensor::DimsType output_dims = input->dims;
+
+  // Compute output shape.
+  for (size_t dim_idx = 0; dim_idx < new_axes.size(); ++dim_idx) {
+    output_dims.insert(output_dims.begin() + new_axes[dim_idx], 1);
+  }
+
+  MP_ASSIGN_OR_RETURN(auto output, IntermediateTensor(std::move(output_dims),
+                                                      "expand_dims_output"));
+  build_steps_.push_back(
+      [input, output, new_axes](xnn_subgraph_t subgraph) -> absl::Status {
+        RET_CHECK_EQ(xnn_status_success,
+                     xnn_define_static_expand_dims(
+                         subgraph, new_axes.size(), new_axes.data(),
+                         input->tensor_id(subgraph),
+                         output->tensor_id(subgraph), /*flags=*/0));
+        return absl::OkStatus();
+      });
+  return output;
+}
+
 absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Reshape(
     std::shared_ptr<Tensor> input, Tensor::DimsType new_dims) {
   size_t output_axis_dynamic = new_dims.size();
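Note on the shape computation in ExpandDims: the axes are inserted one at a time, so each entry in new_axes indexes the shape as already expanded by the earlier insertions. A minimal standalone sketch of that loop (plain std::vector, illustrative only, not part of the change):

#include <cstddef>
#include <vector>

// Mirrors the ExpandDims shape loop above: size-1 axes are inserted
// sequentially, so later indices refer to the partially expanded shape.
std::vector<size_t> ExpandedShape(std::vector<size_t> dims,
                                  const std::vector<size_t>& new_axes) {
  for (size_t axis : new_axes) {
    dims.insert(dims.begin() + axis, 1);
  }
  return dims;
}

// ExpandedShape({4, 5}, {0})    -> {1, 4, 5}
// ExpandedShape({4, 5}, {0, 3}) -> {1, 4, 5, 1}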
@@ -837,31 +860,67 @@ absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::BatchMatMul(
   const auto& rhs_dim = weight->dims;
 
   // [B, N, T, S] . [B, N', H, S]
-  RET_CHECK_EQ(lhs_dim.size(), 4);
-  RET_CHECK_EQ(rhs_dim.size(), 4);
+  RET_CHECK_GE(lhs_dim.size(), 3);
+  RET_CHECK_GE(rhs_dim.size(), 3);
   uint32_t flags = 0;
-  const size_t N = std::max(lhs_dim[1], rhs_dim[1]);
-  const size_t T = lhs_dim[2];
+  const size_t N =
+      std::max(lhs_dim[lhs_dim.size() - 3], rhs_dim[rhs_dim.size() - 3]);
+  const size_t T = lhs_dim[lhs_dim.size() - 2];
   size_t H;
-  if (params.transpose) {
+  if (!params.transpose) {
     RET_CHECK_EQ(lhs_dim.back(), rhs_dim.back());
     flags = XNN_FLAG_TRANSPOSE_B;
-    H = rhs_dim[2];
+    H = rhs_dim[rhs_dim.size() - 2];
   } else {
     RET_CHECK_EQ(lhs_dim.back(), rhs_dim[rhs_dim.size() - 2]);
-    H = rhs_dim[3];
+    H = rhs_dim[rhs_dim.size() - 1];
   }
 
+  size_t batch_size = lhs_dim.size() == 3 ? 1 : lhs_dim[0];
   NewWeight(weight);
-  MP_ASSIGN_OR_RETURN(auto output, IntermediateTensor({lhs_dim[0], N, T, H},
-                                                      "batch_mat_mul_output"));
+  std::vector<size_t> dims(std::max(lhs_dim.size(), rhs_dim.size()));
+  dims[dims.size() - 1] = H;
+  dims[dims.size() - 2] = T;
+  dims[dims.size() - 3] = N;
+  if (dims.size() > 3) {
+    dims[0] = batch_size;
+  }
+  MP_ASSIGN_OR_RETURN(auto output,
+                      IntermediateTensor(dims, "batch_mat_mul_output"));
 
-  build_steps_.push_back([input, output, weight,
-                          flags](xnn_subgraph_t subgraph) -> absl::Status {
-    RET_CHECK_EQ(xnn_status_success, xnn_define_batch_matrix_multiply(
-                                         subgraph, input->tensor_id(subgraph),
-                                         weight->tensor_id(subgraph),
-                                         output->tensor_id(subgraph), flags));
+  std::shared_ptr<Tensor> qd_input;
+  bool use_dynamic_quantization = false;
+  if (runtime_configs_->use_dynamic_quantization.has_value()) {
+    use_dynamic_quantization =
+        runtime_configs_->use_dynamic_quantization.value();
+  } else if (weight->datatype == xnn_datatype_qcint8 ||
+             weight->datatype == xnn_datatype_qcint4) {
+    use_dynamic_quantization = true;
+  }
+  VLOG(3) << "use_dynamic_quantization: " << use_dynamic_quantization;
+  if (use_dynamic_quantization) {
+    MP_ASSIGN_OR_RETURN(
+        qd_input, IntermediateTensor({input->dims.begin(), input->dims.end()},
+                                     xnn_datatype_qdint8));
+  }
+  build_steps_.push_back([input, output, weight, flags,
+                          qd_input](xnn_subgraph_t subgraph) -> absl::Status {
+    if (qd_input) {
+      RET_CHECK_EQ(
+          xnn_status_success,
+          xnn_define_convert(subgraph, input->tensor_id(subgraph),
+                             qd_input->tensor_id(subgraph), /*flags=*/0));
+      RET_CHECK_EQ(
+          xnn_status_success,
+          xnn_define_batch_matrix_multiply(
+              subgraph, qd_input->tensor_id(subgraph),
+              weight->tensor_id(subgraph), output->tensor_id(subgraph), flags));
+    } else {
+      RET_CHECK_EQ(xnn_status_success, xnn_define_batch_matrix_multiply(
+                                           subgraph, input->tensor_id(subgraph),
+                                           weight->tensor_id(subgraph),
+                                           output->tensor_id(subgraph), flags));
+    }
 
     return absl::OkStatus();
   });
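The shape logic above is the core of the ndims != 4 generalization: N, T, and H are now read relative to the end of each shape, so rank-3 and rank-4 operands share one code path, and the leading batch dimension is only filled in for rank > 3. A sketch of just that computation (function and variable names are illustrative; H comes from the transpose branch above):

#include <algorithm>
#include <cstddef>
#include <vector>

// Output shape for lhs [..., N, T, S] times rhs [..., N', H, S] (with
// XNN_FLAG_TRANSPOSE_B) or [..., N', S, H]; indices count from the back.
std::vector<size_t> BatchMatMulShape(const std::vector<size_t>& lhs,
                                     const std::vector<size_t>& rhs,
                                     size_t H) {
  const size_t N = std::max(lhs[lhs.size() - 3], rhs[rhs.size() - 3]);
  const size_t T = lhs[lhs.size() - 2];
  std::vector<size_t> dims(std::max(lhs.size(), rhs.size()));
  dims[dims.size() - 1] = H;
  dims[dims.size() - 2] = T;
  dims[dims.size() - 3] = N;
  if (dims.size() > 3) {
    dims[0] = lhs.size() == 3 ? 1 : lhs[0];  // batch dimension
  }
  return dims;
}

// BatchMatMulShape({8, 16, 32}, {8, 64, 32}, /*H=*/64)       -> {8, 16, 64}
// BatchMatMulShape({2, 8, 16, 32}, {2, 8, 64, 32}, /*H=*/64) -> {2, 8, 16, 64}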
@@ -978,7 +1037,7 @@ absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::QKVAttention(
     Tensor::DimsType reshape_hint) {
   RET_CHECK_EQ(query->dims.size(), 4);
   RET_CHECK_EQ(key_or_value->dims.size(), 4);
-  FullConnParams params{.transpose = true};
+  FullConnParams params{.transpose = false};
   return BatchMatMul(query, key_or_value, params);
 }
 
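Two remarks on the graph_builder.cc changes. First, the QKVAttention flip to transpose = false cancels against the inverted branch condition in BatchMatMul: transpose = false now selects the XNN_FLAG_TRANSPOSE_B path, so the attention computation is unchanged. Second, the new dynamic-quantization path converts the f32 input to a dynamically quantized qdint8 tensor (xnn_define_convert computes the quantization parameters at runtime) before the batch matrix multiply against the channelwise-quantized qcint8/qcint4 weight. The enable decision is a tri-state; a minimal sketch of it (illustrative names, with std::optional standing in for the runtime-config field):

#include <optional>

// An explicit runtime override wins; otherwise dynamic quantization is
// enabled whenever the weight is channelwise-quantized int8 or int4.
bool UseDynamicQuantization(std::optional<bool> runtime_override,
                            bool weight_is_qcint8_or_qcint4) {
  if (runtime_override.has_value()) return *runtime_override;
  return weight_is_qcint8_or_qcint4;
}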
3 changes: 3 additions & 0 deletions mediapipe/tasks/cc/genai/inference/utils/xnn_utils/graph_builder.h
@@ -303,6 +303,9 @@
       std::shared_ptr<Tensor> beta = nullptr);
 
  protected:
+  absl::StatusOr<std::shared_ptr<Tensor>> ExpandDims(
+      std::shared_ptr<Tensor> input, Tensor::DimsType new_axes);
+
   absl::StatusOr<std::shared_ptr<Tensor>> IntermediateTensor(
       Tensor::DimsType dims, absl::string_view tag = "");
   absl::StatusOr<std::shared_ptr<Tensor>> IntermediateTensor(
6 changes: 5 additions & 1 deletion mediapipe/tasks/cc/genai/inference/utils/xnn_utils/tensor.cc
@@ -16,6 +16,7 @@
 
 #include <fcntl.h>
 
+#include <algorithm>
 #include <cmath>
 #include <cstddef>
 #include <cstdint>
@@ -329,11 +330,14 @@ absl::Status Tensor::LoadFromBuffer(const void* buffer) {
 absl::Status Tensor::LoadFromVec(const std::vector<float>& data,
                                  bool exact_match) {
   AllocateBufferIfNeeded();
+  size_t load_size = data.size();
   if (exact_match) {
     RET_CHECK_EQ(ElementSize(num_elements), data.size() * sizeof(float));
+  } else {
+    load_size = std::min(data.size(), num_elements);
   }
 
-  memcpy(Data(), data.data(), data.size() * sizeof(float));
+  memcpy(Data(), data.data(), load_size * sizeof(float));
 
   return absl::OkStatus();
 }
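The LoadFromVec change clamps the copy length when exact_match is false: previously memcpy always copied data.size() floats, which could overrun the tensor's buffer when the source vector held more than num_elements entries (the new <algorithm> include supplies std::min). A sketch of the size rule (illustrative helper, not part of the change):

#include <algorithm>
#include <cstddef>

// With exact_match the sizes are already RET_CHECKed equal; otherwise copy
// only as many elements as both the source vector and the tensor can hold.
size_t LoadSizeElements(size_t data_size, size_t num_elements,
                        bool exact_match) {
  return exact_match ? data_size : std::min(data_size, num_elements);
}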
