[xla:cpu] Add missing files from #16438
Add missing files and changes from a PR that adds matmul weight-reordering support to oneDNN for aarch64 CPUs:
#16438

Also add a missing indirect convolution patch from a TF PR:
tensorflow/tensorflow#62852

PiperOrigin-RevId: 705268797
penpornk authored and Google-ML-Automation committed Dec 11, 2024
1 parent 0060947 commit 4b68998
Showing 5 changed files with 276 additions and 0 deletions.
@@ -0,0 +1,31 @@
/* Copyright 2024 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

diff --git a/src/cpu/platform.cpp b/src/cpu/platform.cpp
index 65b887ea21..eabdb827bd 100644
--- a/src/cpu/platform.cpp
+++ b/src/cpu/platform.cpp
@@ -117,6 +117,8 @@ bool has_data_type_support(data_type_t data_type) {
#if defined(USE_CBLAS) && defined(BLAS_HAS_SBGEMM) && defined(__MMA__)
return true;
#endif
+#elif DNNL_AARCH64_USE_ACL
+ return arm_compute::CPUInfo::get().has_bf16();
#else
return false;
#endif
--
2.34.1
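This first patch (presumably onednn_acl_add_bf16_platform_support_check.patch from the workspace2.bzl list below) makes the bf16 case of has_data_type_support() defer to Compute Library's runtime CPU detection on aarch64 instead of falling through to false. A minimal standalone sketch of the same probe, assuming the usual Compute Library header layout:

    #include "arm_compute/core/CPP/CPPTypes.h"

    // Mirrors the platform.cpp hunk above: query the process-wide CPUInfo
    // singleton, which reports true only on cores with the BF16 extension
    // (e.g. Neoverse V1/N2).
    bool cpu_supports_bf16() {
        return arm_compute::CPUInfo::get().has_bf16();
    }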

@@ -0,0 +1,44 @@
/* Copyright 2024 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

diff --git a/src/cpu/aarch64/matmul/acl_matmul.hpp b/src/cpu/aarch64/matmul/acl_matmul.hpp
index ab13efb9b2..ec261e156d 100644
--- a/src/cpu/aarch64/matmul/acl_matmul.hpp
+++ b/src/cpu/aarch64/matmul/acl_matmul.hpp
@@ -78,11 +78,21 @@ struct acl_matmul_t : public primitive_t {
= utils::everyone_is(data_type::f16, src_md()->data_type,
weights_md()->data_type, dst_md()->data_type)
&& platform::has_data_type_support(data_type::f16);
+ const bool is_fp32_bf16_ok
+ = (utils::everyone_is(data_type::f32, src_md()->data_type,
+ dst_md()->data_type, desc()->accum_data_type)
+ && platform::has_data_type_support(data_type::f32)
+ && utils::everyone_is(
+ data_type::bf16, weights_md()->data_type)
+ && platform::has_data_type_support(
+ data_type::bf16));
+
const bool is_weights_md_format_ok
= utils::one_of(weights_format_kind_received,
format_kind::any, format_kind::blocked);
bool ok = is_dense_data()
- && utils::one_of(true, is_fp32_ok, is_fp16_ok)
+ && utils::one_of(
+ true, is_fp32_ok, is_fp16_ok, is_fp32_bf16_ok)
&& !has_zero_dim_memory() && is_weights_md_format_ok
&& set_default_formats()
&& attr()->has_default_values(
--
2.34.1
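This patch (presumably onednn_acl_add_sbgemm_matmul_primitive_definition.patch) extends acl_matmul_t to accept f32 source/destination tensors with bf16 weights, gated on bf16 hardware support. A hedged user-side sketch of a matmul that can now dispatch to this path, assuming the oneDNN 3.x C++ API; dimensions and format tags are illustrative:

    #include "dnnl.hpp"

    dnnl::matmul::primitive_desc make_f32_bf16_matmul(const dnnl::engine &eng,
            dnnl::memory::dim M, dnnl::memory::dim K, dnnl::memory::dim N) {
        using dt = dnnl::memory::data_type;
        using tag = dnnl::memory::format_tag;
        dnnl::memory::desc src_md({M, K}, dt::f32, tag::ab);
        // format_tag::any lets the implementation choose (and reorder to)
        // the weight layout ACL prefers, per is_weights_md_format_ok above.
        dnnl::memory::desc wei_md({K, N}, dt::bf16, tag::any);
        dnnl::memory::desc dst_md({M, N}, dt::f32, tag::ab);
        return dnnl::matmul::primitive_desc(eng, src_md, wei_md, dst_md);
    }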
@@ -0,0 +1,100 @@
/* Copyright 2024 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

diff --git a/src/cpu/aarch64/matmul/acl_matmul.hpp b/src/cpu/aarch64/matmul/acl_matmul.hpp
index 451cc78d52..ab13efb9b2 100644
--- a/src/cpu/aarch64/matmul/acl_matmul.hpp
+++ b/src/cpu/aarch64/matmul/acl_matmul.hpp
@@ -67,6 +67,8 @@ struct acl_matmul_t : public primitive_t {

status_t init(engine_t *engine) {
using smask_t = primitive_attr_t::skip_mask_t;
+ const format_kind_t weights_format_kind_received
+ = weights_md_.format_kind;
const bool is_fp32_ok
= utils::everyone_is(data_type::f32, src_md()->data_type,
weights_md()->data_type, dst_md()->data_type,
@@ -76,18 +78,20 @@ struct acl_matmul_t : public primitive_t {
= utils::everyone_is(data_type::f16, src_md()->data_type,
weights_md()->data_type, dst_md()->data_type)
&& platform::has_data_type_support(data_type::f16);
+ const bool is_weights_md_format_ok
+ = utils::one_of(weights_format_kind_received,
+ format_kind::any, format_kind::blocked);
bool ok = is_dense_data()
&& utils::one_of(true, is_fp32_ok, is_fp16_ok)
- && !has_zero_dim_memory()
- && weights_md_.format_kind == format_kind::any
+ && !has_zero_dim_memory() && is_weights_md_format_ok
&& set_default_formats()
&& attr()->has_default_values(
smask_t::oscale | smask_t::post_ops)
&& attr_oscale_ok() && !has_runtime_dims_or_strides();
if (!ok) return status::unimplemented;

- CHECK(acl_matmul_utils::init_conf_matmul(
- amp_, src_md_, weights_md_, dst_md_, *desc(), *attr()));
+ CHECK(acl_matmul_utils::init_conf_matmul(amp_, src_md_, weights_md_,
+ dst_md_, *desc(), *attr(), weights_format_kind_received));

arm_compute::ActivationLayerInfo act_info;
CHECK(post_ops.init(engine, attr_.post_ops_, dst_md_, act_info));
diff --git a/src/cpu/aarch64/matmul/acl_matmul_utils.cpp b/src/cpu/aarch64/matmul/acl_matmul_utils.cpp
index a314d96384..027f915a8a 100644
--- a/src/cpu/aarch64/matmul/acl_matmul_utils.cpp
+++ b/src/cpu/aarch64/matmul/acl_matmul_utils.cpp
@@ -27,7 +27,8 @@ namespace acl_matmul_utils {

status_t init_conf_matmul(acl_matmul_conf_t &amp, memory_desc_t &src_md,
memory_desc_t &wei_md, memory_desc_t &dst_md, const matmul_desc_t &md,
- const primitive_attr_t &attr) {
+ const primitive_attr_t &attr,
+ format_kind_t weights_format_kind_received) {

const memory_desc_wrapper src_d(&src_md);
const memory_desc_wrapper wei_d(&wei_md);
@@ -128,9 +129,16 @@ status_t init_conf_matmul(acl_matmul_conf_t &amp, memory_desc_t &src_md,
for (dim_t i = K_dim - 1; i >= 0; --i)
batch_dims.push_back(i);

+ const memory_desc_t weights_md_received = wei_md;
acl_utils::reorder_to_weight_format(amp.wei_tensor_info, wei_md,
expected_weight_format, K_dim, N_dim, {}, batch_dims);

+ ACL_CHECK_SUPPORT((weights_format_kind_received == format_kind::blocked)
+ && !(dnnl_memory_desc_equal(&weights_md_received, &wei_md)),
+ "specified blocked format not supported by ACL, use "
+ "format_kind_t::any to find a supported blocked format for "
+ "your platform");
+
return status::success;
}

diff --git a/src/cpu/aarch64/matmul/acl_matmul_utils.hpp b/src/cpu/aarch64/matmul/acl_matmul_utils.hpp
index 67bb2e78eb..5ba4241abc 100644
--- a/src/cpu/aarch64/matmul/acl_matmul_utils.hpp
+++ b/src/cpu/aarch64/matmul/acl_matmul_utils.hpp
@@ -52,7 +52,8 @@ namespace acl_matmul_utils {

status_t init_conf_matmul(acl_matmul_conf_t &amp, memory_desc_t &src_md,
memory_desc_t &wei_md, memory_desc_t &dst_md, const matmul_desc_t &md,
- const primitive_attr_t &attr);
+ const primitive_attr_t &attr,
+ format_kind_t weights_format_kind_received);

} // namespace acl_matmul_utils

--
2.34.1
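This patch (presumably onednn_acl_allow_blocked_weight_format_for_matmul_primitive.patch) lets a caller hand acl_matmul_t an already-blocked weights descriptor, but only when it is exactly the layout reorder_to_weight_format would have chosen; any other blocked format is rejected with the ACL_CHECK_SUPPORT message above. The portable pattern is still to request format_kind::any and reorder the plain weights afterwards; a hedged sketch, assuming the oneDNN 3.x C++ API:

    // Reorder user weights into whatever layout the primitive selected.
    dnnl::memory prepare_weights(const dnnl::engine &eng, dnnl::stream &strm,
            const dnnl::matmul::primitive_desc &pd, dnnl::memory &user_wei) {
        if (pd.weights_desc() == user_wei.get_desc()) return user_wei;
        dnnl::memory blocked_wei(pd.weights_desc(), eng);
        dnnl::reorder(user_wei, blocked_wei)
                .execute(strm, user_wei, blocked_wei);
        strm.wait(); // ensure the reorder finishes before the weights are used
        return blocked_wei;
    }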
@@ -0,0 +1,96 @@
/* Copyright 2024 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

diff --git a/src/cpu/aarch64/acl_post_ops.cpp b/src/cpu/aarch64/acl_post_ops.cpp
index ea4bb200ec..3eb53b81bd 100644
--- a/src/cpu/aarch64/acl_post_ops.cpp
+++ b/src/cpu/aarch64/acl_post_ops.cpp
@@ -24,7 +24,7 @@ namespace aarch64 {

status_t acl_post_ops_t::execute(const exec_ctx_t &ctx, void *src_orig) const {

- int post_op_index = 0;
+ int post_op_index = post_op_start_index_;

// As these are post ops, this src will also be our dst. If we have a sum
// post op, the src/dst will start off in a temporary, then change to
diff --git a/src/cpu/aarch64/acl_post_ops.hpp b/src/cpu/aarch64/acl_post_ops.hpp
index 7b59ad71d3..ceaa95b73a 100644
--- a/src/cpu/aarch64/acl_post_ops.hpp
+++ b/src/cpu/aarch64/acl_post_ops.hpp
@@ -32,7 +32,9 @@ struct acl_post_ops_t {
// init the acl_post_ops_t. Note that this function modifies the passed in
// post ops by setting the preferred memory formats
status_t init(engine_t *engine, post_ops_t &post_ops,
- const memory_desc_t &dst_md) {
+ const memory_desc_t &dst_md, int post_op_start_index = 0) {
+
+ post_op_start_index_ = post_op_start_index;

CHECK(post_ops.set_default_formats(&dst_md));
dst_data_type = dst_md.data_type;
@@ -41,7 +43,7 @@ struct acl_post_ops_t {
sum_index = -1;
post_op_primitives = {};

- for (int i = 0; i < post_ops.len(); i++) {
+ for (int i = post_op_start_index; i < post_ops.len(); i++) {
auto &po = post_ops.entry_[i];

if (po.is_sum()) {
@@ -135,7 +137,8 @@ struct acl_post_ops_t {
// formats
status_t init(engine_t *engine, post_ops_t &base_post_ops,
const memory_desc_t &dst_md,
- arm_compute::ActivationLayerInfo &act_info_to_fuse) {
+ arm_compute::ActivationLayerInfo &act_info_to_fuse,
+ int post_op_start_index = 0) {

CHECK(base_post_ops.set_default_formats(&dst_md));
dst_data_type = dst_md.data_type;
@@ -149,18 +152,11 @@ struct acl_post_ops_t {
"eltwise post op scale must be 1 (no scale)");
CHECK(acl_utils::convert_to_acl_act(first_po, act_info_to_fuse));

- // Copy all but the first, because it has been fused
- post_ops_t post_ops;
- for (int idx = 1; idx < base_post_ops.len(); ++idx) {
- // Construct empty entry then copy, so that we can check for failure
- post_ops.entry_.emplace_back();
- post_ops.entry_.back().copy_from(base_post_ops.entry_[idx]);
- }
- return init(engine, post_ops, dst_md);
-
+ // post_op_start_index + 1 to skip the fused eltwise
+ return init(engine, base_post_ops, dst_md, post_op_start_index + 1);
} else {
// Nothing to fuse, just copy all post ops
- return init(engine, base_post_ops, dst_md);
+ return init(engine, base_post_ops, dst_md, post_op_start_index);
}
}

@@ -179,6 +175,9 @@ struct acl_post_ops_t {
private:
// Index of the sum post op if there is one, < 0 means no sum
int sum_index = -1;
+ // Index of the first post op this primitive executes. This is typically the
+ // number of post ops which were fused.
+ int post_op_start_index_ = 0;
data_type_t dst_data_type;
// Vector of primitives used to execute the post ops. They are constructed
// in init to be either acl_binary_t (for sum, add, sub, div, mul, min and
--
2.34.1
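This patch (presumably onednn_acl_fix_segfault_during_postop_execute.patch) removes the truncated local copy of the post-op list that the eltwise-fusing init() built; references into that copy could apparently dangle once init() returned, hence the segfault. Instead, the caller's post_ops_t is kept intact and execution simply starts at post_op_start_index_, skipping whatever was fused. A hedged illustration of a chain that exercises both paths, assuming the oneDNN 3.x C++ API (earlier releases take an extra scale argument in append_eltwise):

    dnnl::primitive_attr make_attr_with_post_ops() {
        dnnl::post_ops po;
        // Entry 0: eltwise, fused into the ACL kernel via act_info_to_fuse.
        po.append_eltwise(dnnl::algorithm::eltwise_relu, 0.f, 0.f);
        // Entry 1 onward: run by acl_post_ops_t::execute, which now starts
        // at post_op_start_index_ = 1 rather than iterating a copied list.
        po.append_sum(1.f);
        dnnl::primitive_attr attr;
        attr.set_post_ops(po);
        return attr;
    }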
5 changes: 5 additions & 0 deletions third_party/tsl/workspace2.bzl
@@ -163,6 +163,11 @@ def _tf_repositories():
"//third_party/mkl_dnn:onednn_acl_thread_local_scheduler.patch",
"//third_party/mkl_dnn:onednn_acl_fp32_bf16_reorder.patch",
"//third_party/mkl_dnn:onednn_acl_bf16_capability_detection_for_ubuntu20.04.patch",
"//third_party/mkl_dnn:onednn_acl_indirect_conv.patch",
"//third_party/mkl_dnn:onednn_acl_allow_blocked_weight_format_for_matmul_primitive.patch",
"//third_party/mkl_dnn:onednn_acl_fix_segfault_during_postop_execute.patch",
"//third_party/mkl_dnn:onednn_acl_add_bf16_platform_support_check.patch",
"//third_party/mkl_dnn:onednn_acl_add_sbgemm_matmul_primitive_definition.patch",
],
sha256 = "2f76b407ef8893cca71340f88cd800019a1f14f8ac1bbdbb89a84be1370b52e3",
strip_prefix = "oneDNN-3.2.1",
