Skip to content

Commit

Permalink
Fix unit-test failure.
Browse files Browse the repository at this point in the history
Signed-off-by: hyunback <[email protected]>
  • Loading branch information
hyunback committed Dec 1, 2024
1 parent 5788c73 commit ef968a6
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -980,7 +980,7 @@ void reorder_inputs::run(program& p, reorder_factory& rf) {
// correctly and we need to do it manually
#ifdef ENABLE_ONEDNN_FOR_GPU
for (auto& node : p.get_processing_order()) {
if (node->is_type<gemm>() && node->get_preferred_impl_type() == impl_types::onednn) {
if ((node->is_type<gemm>() || node->is_type<fully_connected>()) && node->get_preferred_impl_type() == impl_types::onednn) {
for (const auto& fused_prim : node->get_fused_primitives()) {
if (fused_prim.is_type<eltwise>() &&
one_of(fused_prim.typed_desc<eltwise>()->mode, {eltwise_mode::sum, eltwise_mode::sub, eltwise_mode::prod})) {
Expand Down
7 changes: 4 additions & 3 deletions src/plugins/intel_gpu/src/graph/program_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1549,14 +1549,15 @@ void program_node::create_onednn_primitive_attributes(
if (prim->input_size == in_pshape.size()) {
ones_to_add = std::max(out_pshape.size(), static_cast<size_t>(rank)) - in_pshape.size();
} else {
if (in_pshape.size() > prim->input_size)
ones_to_add = in_pshape.size() - prim->input_size;
if (prim->input_size == 3)
cldnn::onednn::combine_bf_with_first_spatial_dim(in);
ones_to_add = std::max(in_pshape.size(), prim->input_size) - std::min(in_pshape.size(), prim->input_size);
}
if (ones_to_add > 0) {
layout new_layout = in;
ov::PartialShape new_input_pshape;
auto last = in_pshape.begin() + in_pshape.size();
if (prim->input_size != in_pshape.size())
if (in_pshape.size() > prim->input_size)
last -= ones_to_add;
std::vector<ov::Dimension> dims(in_pshape.begin(), last);
new_input_pshape = ov::PartialShape(dims);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,13 @@ class FullyConnectedFusingTestOneDNN : public BaseFusingTest<fully_connected_tes
#define CASE_FC_FP32_1 { 1, 3 }, { 1, 4 }, { 4, 3 }, data_types::f32, format::bfyx, data_types::f32, format::oiyx, data_types::f32, format::bfyx
#define CASE_FC_FP32_2 { 2, 3 }, { 2, 4 }, { 4, 3 }, data_types::f32, format::yxfb, data_types::f32, format::oiyx, data_types::f32, format::bfyx
#define CASE_FC_FP32_3 { 2, 32 }, { 2, 16 }, { 16, 32 }, data_types::f32, format::bfyx, data_types::i8, format::oiyx, data_types::f32, format::bfyx
#define CASE_FC_FP32_3D_1 { 5, 3, 3 }, { 5, 3, 5 }, { 5, 3, 1 }, data_types::f32, format::bfyx, data_types::f32, format::os_iyx_osv16, data_types::f32, format::bfyx
#define CASE_FC_FP32_3D_2 { 2, 1, 1 }, { 2, 1, 32 }, { 32, 1, 1 }, data_types::f32, format::bfyx, data_types::f32, format::os_iyx_osv16, data_types::f32, format::bfyx
#define CASE_FC_FP32_3D_3 { 2, 32, 32 }, { 2, 32, 16 }, { 16, 32, 1 }, data_types::f32, format::bfyx, data_types::f32, format::os_iyx_osv16, data_types::f32, format::bfyx
#define CASE_FC_FP32_3D_1 { 5, 3, 3 }, { 5, 3, 5 }, { 5, 3, 1 }, data_types::f32, format::bfyx, data_types::f32, format::oiyx, data_types::f32, format::bfyx
#define CASE_FC_FP32_3D_2 { 2, 1, 1 }, { 2, 1, 32 }, { 32, 1, 1 }, data_types::f32, format::bfyx, data_types::f32, format::oiyx, data_types::f32, format::bfyx
#define CASE_FC_FP32_3D_3 { 2, 32, 32 }, { 2, 32, 16 }, { 16, 32, 1 }, data_types::f32, format::bfyx, data_types::f32, format::oiyx, data_types::f32, format::bfyx

#define DYN_CASE_FC_FP32_3D_1 { 5, 3, 3 }, { 5, 3, 5 }, { 5, 3 }, data_types::f32, format::bfyx, data_types::f32, format::os_iyx_osv16, data_types::f32, format::bfyx
#define DYN_CASE_FC_FP32_3D_2 { 2, 1, 1 }, { 2, 1, 32 }, { 32, 1 }, data_types::f32, format::bfyx, data_types::f32, format::os_iyx_osv16, data_types::f32, format::bfyx
#define DYN_CASE_FC_FP32_3D_3 { 2, 32, 32 }, { 2, 32, 16 }, { 16, 32 }, data_types::f32, format::bfyx, data_types::f32, format::os_iyx_osv16, data_types::f32, format::bfyx
#define DYN_CASE_FC_FP32_3D_1 { 5, 3, 3 }, { 5, 3, 5 }, { 5, 3 }, data_types::f32, format::bfyx, data_types::f32, format::oiyx, data_types::f32, format::bfyx
#define DYN_CASE_FC_FP32_3D_2 { 2, 1, 1 }, { 2, 1, 32 }, { 32, 1 }, data_types::f32, format::bfyx, data_types::f32, format::oiyx, data_types::f32, format::bfyx
#define DYN_CASE_FC_FP32_3D_3 { 2, 32, 32 }, { 2, 32, 16 }, { 16, 32 }, data_types::f32, format::bfyx, data_types::f32, format::oiyx, data_types::f32, format::bfyx

#define CASE_FC_U8S8_1 { 1, 3 }, { 1, 4 }, { 4, 3 }, data_types::u8, format::bfyx, data_types::i8, format::oiyx, data_types::f32, format::bfyx
#define CASE_FC_U8S8_2 { 2, 3 }, { 2, 4 }, { 4, 3 }, data_types::u8, format::b_fs_yx_fsv4, data_types::i8, format::oiyx, data_types::f32, format::bfyx
Expand Down

0 comments on commit ef968a6

Please sign in to comment.