-
Notifications
You must be signed in to change notification settings - Fork 452
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[xla:cpu] Add missing files from #16438
Add missing files and changes from a PR adding matmul reordering support to oneDNN for aarch64 CPU: #16438 Also add a missing indirect convolution patch from a TF PR: tensorflow/tensorflow#62852 PiperOrigin-RevId: 705268797
- Loading branch information
1 parent
0060947
commit 4b68998
Showing
5 changed files
with
276 additions
and
0 deletions.
There are no files selected for viewing
31 changes: 31 additions & 0 deletions
31
third_party/tsl/third_party/mkl_dnn/onednn_acl_add_bf16_platform_support_check.patch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
/* Copyright 2024 The OpenXLA Authors. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
diff --git a/src/cpu/platform.cpp b/src/cpu/platform.cpp | ||
index 65b887ea21..eabdb827bd 100644 | ||
--- a/src/cpu/platform.cpp | ||
+++ b/src/cpu/platform.cpp | ||
@@ -117,6 +117,8 @@ bool has_data_type_support(data_type_t data_type) { | ||
#if defined(USE_CBLAS) && defined(BLAS_HAS_SBGEMM) && defined(__MMA__) | ||
return true; | ||
#endif | ||
+#elif DNNL_AARCH64_USE_ACL | ||
+ return arm_compute::CPUInfo::get().has_bf16(); | ||
#else | ||
return false; | ||
#endif | ||
-- | ||
2.34.1 | ||
|
44 changes: 44 additions & 0 deletions
44
third_party/tsl/third_party/mkl_dnn/onednn_acl_add_sbgemm_matmul_primitive_definition.patch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
/* Copyright 2024 The OpenXLA Authors. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
diff --git a/src/cpu/aarch64/matmul/acl_matmul.hpp b/src/cpu/aarch64/matmul/acl_matmul.hpp | ||
index ab13efb9b2..ec261e156d 100644 | ||
--- a/src/cpu/aarch64/matmul/acl_matmul.hpp | ||
+++ b/src/cpu/aarch64/matmul/acl_matmul.hpp | ||
@@ -78,11 +78,21 @@ struct acl_matmul_t : public primitive_t { | ||
= utils::everyone_is(data_type::f16, src_md()->data_type, | ||
weights_md()->data_type, dst_md()->data_type) | ||
&& platform::has_data_type_support(data_type::f16); | ||
+ const bool is_fp32_bf16_ok | ||
+ = (utils::everyone_is(data_type::f32, src_md()->data_type, | ||
+ dst_md()->data_type, desc()->accum_data_type) | ||
+ && platform::has_data_type_support(data_type::f32) | ||
+ && utils::everyone_is( | ||
+ data_type::bf16, weights_md()->data_type) | ||
+ && platform::has_data_type_support( | ||
+ data_type::bf16)); | ||
+ | ||
const bool is_weights_md_format_ok | ||
= utils::one_of(weights_format_kind_received, | ||
format_kind::any, format_kind::blocked); | ||
bool ok = is_dense_data() | ||
- && utils::one_of(true, is_fp32_ok, is_fp16_ok) | ||
+ && utils::one_of( | ||
+ true, is_fp32_ok, is_fp16_ok, is_fp32_bf16_ok) | ||
&& !has_zero_dim_memory() && is_weights_md_format_ok | ||
&& set_default_formats() | ||
&& attr()->has_default_values( | ||
-- | ||
2.34.1 |
100 changes: 100 additions & 0 deletions
100
...tsl/third_party/mkl_dnn/onednn_acl_allow_blocked_weight_format_for_matmul_primitive.patch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
/* Copyright 2024 The OpenXLA Authors. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
diff --git a/src/cpu/aarch64/matmul/acl_matmul.hpp b/src/cpu/aarch64/matmul/acl_matmul.hpp | ||
index 451cc78d52..ab13efb9b2 100644 | ||
--- a/src/cpu/aarch64/matmul/acl_matmul.hpp | ||
+++ b/src/cpu/aarch64/matmul/acl_matmul.hpp | ||
@@ -67,6 +67,8 @@ struct acl_matmul_t : public primitive_t { | ||
|
||
status_t init(engine_t *engine) { | ||
using smask_t = primitive_attr_t::skip_mask_t; | ||
+ const format_kind_t weights_format_kind_received | ||
+ = weights_md_.format_kind; | ||
const bool is_fp32_ok | ||
= utils::everyone_is(data_type::f32, src_md()->data_type, | ||
weights_md()->data_type, dst_md()->data_type, | ||
@@ -76,18 +78,20 @@ struct acl_matmul_t : public primitive_t { | ||
= utils::everyone_is(data_type::f16, src_md()->data_type, | ||
weights_md()->data_type, dst_md()->data_type) | ||
&& platform::has_data_type_support(data_type::f16); | ||
+ const bool is_weights_md_format_ok | ||
+ = utils::one_of(weights_format_kind_received, | ||
+ format_kind::any, format_kind::blocked); | ||
bool ok = is_dense_data() | ||
&& utils::one_of(true, is_fp32_ok, is_fp16_ok) | ||
- && !has_zero_dim_memory() | ||
- && weights_md_.format_kind == format_kind::any | ||
+ && !has_zero_dim_memory() && is_weights_md_format_ok | ||
&& set_default_formats() | ||
&& attr()->has_default_values( | ||
smask_t::oscale | smask_t::post_ops) | ||
&& attr_oscale_ok() && !has_runtime_dims_or_strides(); | ||
if (!ok) return status::unimplemented; | ||
|
||
- CHECK(acl_matmul_utils::init_conf_matmul( | ||
- amp_, src_md_, weights_md_, dst_md_, *desc(), *attr())); | ||
+ CHECK(acl_matmul_utils::init_conf_matmul(amp_, src_md_, weights_md_, | ||
+ dst_md_, *desc(), *attr(), weights_format_kind_received)); | ||
|
||
arm_compute::ActivationLayerInfo act_info; | ||
CHECK(post_ops.init(engine, attr_.post_ops_, dst_md_, act_info)); | ||
diff --git a/src/cpu/aarch64/matmul/acl_matmul_utils.cpp b/src/cpu/aarch64/matmul/acl_matmul_utils.cpp | ||
index a314d96384..027f915a8a 100644 | ||
--- a/src/cpu/aarch64/matmul/acl_matmul_utils.cpp | ||
+++ b/src/cpu/aarch64/matmul/acl_matmul_utils.cpp | ||
@@ -27,7 +27,8 @@ namespace acl_matmul_utils { | ||
|
||
status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, | ||
memory_desc_t &wei_md, memory_desc_t &dst_md, const matmul_desc_t &md, | ||
- const primitive_attr_t &attr) { | ||
+ const primitive_attr_t &attr, | ||
+ format_kind_t weights_format_kind_received) { | ||
|
||
const memory_desc_wrapper src_d(&src_md); | ||
const memory_desc_wrapper wei_d(&wei_md); | ||
@@ -128,9 +129,16 @@ status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, | ||
for (dim_t i = K_dim - 1; i >= 0; --i) | ||
batch_dims.push_back(i); | ||
|
||
+ const memory_desc_t weights_md_received = wei_md; | ||
acl_utils::reorder_to_weight_format(amp.wei_tensor_info, wei_md, | ||
expected_weight_format, K_dim, N_dim, {}, batch_dims); | ||
|
||
+ ACL_CHECK_SUPPORT((weights_format_kind_received == format_kind::blocked) | ||
+ && !(dnnl_memory_desc_equal(&weights_md_received, &wei_md)), | ||
+ "specified blocked format not supported by ACL, use " | ||
+ "format_kind_t::any to find a supported blocked format for " | ||
+ "your platform"); | ||
+ | ||
return status::success; | ||
} | ||
|
||
diff --git a/src/cpu/aarch64/matmul/acl_matmul_utils.hpp b/src/cpu/aarch64/matmul/acl_matmul_utils.hpp | ||
index 67bb2e78eb..5ba4241abc 100644 | ||
--- a/src/cpu/aarch64/matmul/acl_matmul_utils.hpp | ||
+++ b/src/cpu/aarch64/matmul/acl_matmul_utils.hpp | ||
@@ -52,7 +52,8 @@ namespace acl_matmul_utils { | ||
|
||
status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, | ||
memory_desc_t &wei_md, memory_desc_t &dst_md, const matmul_desc_t &md, | ||
- const primitive_attr_t &attr); | ||
+ const primitive_attr_t &attr, | ||
+ format_kind_t weights_format_kind_received); | ||
|
||
} // namespace acl_matmul_utils | ||
|
||
-- | ||
2.34.1 |
96 changes: 96 additions & 0 deletions
96
third_party/tsl/third_party/mkl_dnn/onednn_acl_fix_segfault_during_postop_execute.patch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
/* Copyright 2024 The OpenXLA Authors. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
diff --git a/src/cpu/aarch64/acl_post_ops.cpp b/src/cpu/aarch64/acl_post_ops.cpp | ||
index ea4bb200ec..3eb53b81bd 100644 | ||
--- a/src/cpu/aarch64/acl_post_ops.cpp | ||
+++ b/src/cpu/aarch64/acl_post_ops.cpp | ||
@@ -24,7 +24,7 @@ namespace aarch64 { | ||
|
||
status_t acl_post_ops_t::execute(const exec_ctx_t &ctx, void *src_orig) const { | ||
|
||
- int post_op_index = 0; | ||
+ int post_op_index = post_op_start_index_; | ||
|
||
// As these are post ops, this src will also be our dst. If we have a sum | ||
// post op, the src/dst will start off in a temporary, then change to | ||
diff --git a/src/cpu/aarch64/acl_post_ops.hpp b/src/cpu/aarch64/acl_post_ops.hpp | ||
index 7b59ad71d3..ceaa95b73a 100644 | ||
--- a/src/cpu/aarch64/acl_post_ops.hpp | ||
+++ b/src/cpu/aarch64/acl_post_ops.hpp | ||
@@ -32,7 +32,9 @@ struct acl_post_ops_t { | ||
// init the acl_post_ops_t. Note that this function modifies the passed in | ||
// post ops by setting the preferred memory formats | ||
status_t init(engine_t *engine, post_ops_t &post_ops, | ||
- const memory_desc_t &dst_md) { | ||
+ const memory_desc_t &dst_md, int post_op_start_index = 0) { | ||
+ | ||
+ post_op_start_index_ = post_op_start_index; | ||
|
||
CHECK(post_ops.set_default_formats(&dst_md)); | ||
dst_data_type = dst_md.data_type; | ||
@@ -41,7 +43,7 @@ struct acl_post_ops_t { | ||
sum_index = -1; | ||
post_op_primitives = {}; | ||
|
||
- for (int i = 0; i < post_ops.len(); i++) { | ||
+ for (int i = post_op_start_index; i < post_ops.len(); i++) { | ||
auto &po = post_ops.entry_[i]; | ||
|
||
if (po.is_sum()) { | ||
@@ -135,7 +137,8 @@ struct acl_post_ops_t { | ||
// formats | ||
status_t init(engine_t *engine, post_ops_t &base_post_ops, | ||
const memory_desc_t &dst_md, | ||
- arm_compute::ActivationLayerInfo &act_info_to_fuse) { | ||
+ arm_compute::ActivationLayerInfo &act_info_to_fuse, | ||
+ int post_op_start_index = 0) { | ||
|
||
CHECK(base_post_ops.set_default_formats(&dst_md)); | ||
dst_data_type = dst_md.data_type; | ||
@@ -149,18 +152,11 @@ struct acl_post_ops_t { | ||
"eltwise post op scale must be 1 (no scale)"); | ||
CHECK(acl_utils::convert_to_acl_act(first_po, act_info_to_fuse)); | ||
|
||
- // Copy all but the first, because it has been fused | ||
- post_ops_t post_ops; | ||
- for (int idx = 1; idx < base_post_ops.len(); ++idx) { | ||
- // Construct empty entry then copy, so that we can check for failure | ||
- post_ops.entry_.emplace_back(); | ||
- post_ops.entry_.back().copy_from(base_post_ops.entry_[idx]); | ||
- } | ||
- return init(engine, post_ops, dst_md); | ||
- | ||
+ // post_op_start_index + 1 to skip the fused eltwise | ||
+ return init(engine, base_post_ops, dst_md, post_op_start_index + 1); | ||
} else { | ||
// Nothing to fuse, just copy all post ops | ||
- return init(engine, base_post_ops, dst_md); | ||
+ return init(engine, base_post_ops, dst_md, post_op_start_index); | ||
} | ||
} | ||
|
||
@@ -179,6 +175,9 @@ struct acl_post_ops_t { | ||
private: | ||
// Index of the sum post op if there is one, < 0 means no sum | ||
int sum_index = -1; | ||
+ // Index of the first post op this primitive executes. This is typically the | ||
+ // number of post ops which were fused. | ||
+ int post_op_start_index_ = 0; | ||
data_type_t dst_data_type; | ||
// Vector of primitives used to execute the post ops. They are constructed | ||
// in init to be either acl_binary_t (for sum, add, sub, div, mul, min and | ||
-- | ||
2.34.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters