From 623281646b1a3b3cbd9c437e9e85aaddca1ba193 Mon Sep 17 00:00:00 2001 From: Aleksandr Pertovsky Date: Thu, 24 Jun 2021 06:59:41 +0300 Subject: [PATCH 1/2] [CPU] DeformableConvolution-8 --- .../nodes/mkldnn_def_conv_node.cpp | 470 ++++++++++++------ .../nodes/mkldnn_def_conv_node.h | 16 +- 2 files changed, 341 insertions(+), 145 deletions(-) diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_def_conv_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_def_conv_node.cpp index a2fae182a52f70..9e1b2fe3c555d3 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_def_conv_node.cpp +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_def_conv_node.cpp @@ -45,6 +45,8 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); if (jcp_.with_bias) mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]); + if (jcp_.with_modulation) + mov(reg_modulation, ptr[this->param1 + GET_OFF(modulation)]); mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); mov(reg_input_buffer, ptr[this->param1 + GET_OFF(buf)]); mov(reg_oh_pos, ptr[param1 + GET_OFF(oh_pos)]); @@ -71,6 +73,7 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ reg64_t reg_def_off = r9; reg64_t reg_kernel = r10; reg64_t reg_bias = r11; + reg64_t reg_modulation = rcx; reg64_t reg_output = r12; reg64_t reg_oh_pos = r13; reg64_t aux_reg_bias = rsi; @@ -82,13 +85,13 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ reg64_t reg_ic_iter = rbx; reg64_t reg_oc_work = reg_ic_iter; reg64_t aux_reg_def_off = reg_bias; - reg64_t reg_input_buffer = abi_not_param1; reg64_t aux_reg_input_buffer = r14; reg32_t reg_tmp_32 = r15d; reg64_t reg_tmp_64 = r15; reg64_t reg_table = rbp; + reg64_t reg_input_buffer = aux_reg_input; reg64_t aux_reg_kernel = reg_table; - reg64_t aux2_reg_kernel = r15; + reg64_t aux2_reg_kernel = reg_tmp_64; reg64_t aux2_reg_input_buffer = aux_reg_bias; reg64_t aux3_reg_input_buffer = reg_input; @@ -119,6 +122,9 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ add(reg_input, jcp_.ur_w * jcp_.stride_w * jcp_.ic * jcp_.typesize_in); add(reg_def_off, jcp_.ur_w * jcp_.typesize_off); + if (jcp_.with_modulation) { + add(reg_modulation, jcp_.ur_w * jcp_.typesize_modulation); + } add(reg_output, jcp_.ur_w * jcp_.oc * jcp_.typesize_out); add(reg_ow_pos, jcp_.ur_w); @@ -217,7 +223,8 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ Label exit; push(reg_oc_work); - push(aux_reg_bias); + if (jcp_.with_bias) + push(aux_reg_bias); mov(aux2_reg_kernel, aux_reg_kernel); mov(aux2_reg_input_buffer, reg_input_buffer); @@ -243,8 +250,8 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ apply_filter(ow_step, oc_blocks_step, oc_step, jcp_.ic % jcp_.ic_block); } } - - pop(aux_reg_bias); + if (jcp_.with_bias) + pop(aux_reg_bias); pop(reg_oc_work); } @@ -256,6 +263,9 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ mov(aux_reg_def_off, reg_def_off); mov(aux_reg_input, reg_input); mov(aux2_reg_input_buffer, aux_reg_input_buffer); + if (jcp_.with_modulation) { + push(reg_modulation); + } xor_(reg_dg_iter, reg_dg_iter); const int ic_per_def_group = jcp_.ic / jcp_.dg; @@ -271,10 +281,14 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ Label ic_loop_tail; Label ic_loop_zeros; Label loop_end; - Label h_sec_opt; - Label h_sec_opt_exit; - Label w_sec_opt; - Label w_sec_opt_exit; + Label v1_condition_end_main; + Label v2_condition_end_main; + Label v3_condition_end_main; + Label v4_condition_end_main; + Label v1_condition_end_tail; + Label v2_condition_end_tail; + Label v3_condition_end_tail; + Label v4_condition_end_tail; mov(aux2_reg_input, aux_reg_input); add(aux2_reg_input, (ow * jcp_.stride_w * jcp_.ic) * jcp_.typesize_in); @@ -287,45 +301,48 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ Xmm xmm_map_h = Xmm(2); Xmm xmm_ih_in = Xmm(4); Xmm xmm_ih_im = Xmm(1); - Xmm xmm_cur_height = xmm_ih_im; Xmm xmm_h_low = xmm_ih_in; - Xmm xmm_h_high = xmm_cur_height; + Xmm xmm_h_high = xmm_ih_im; Xmm xmm_lh = xmm_map_h; Xmm xmm_hh = Xmm(3); Xmm xmm_map_w = Xmm(6); Xmm xmm_iw_in = Xmm(8); Xmm xmm_iw_im = Xmm(5); - Xmm xmm_cur_width = xmm_iw_im; Xmm xmm_w_low = xmm_iw_in; - Xmm xmm_w_high = xmm_cur_width; + Xmm xmm_w_high = xmm_iw_im; Xmm xmm_lw = xmm_map_w; Xmm xmm_hw = Xmm(7); - Xmm xmm_v1_off = Xmm(9); - Xmm xmm_v2_off = Xmm(10); - Xmm xmm_v3_off = Xmm(11); - Xmm xmm_v4_off = Xmm(12); + Xmm xmm_v1_off = xmm_lh; + Xmm xmm_v2_off = xmm_hh; + Xmm xmm_v3_off = xmm_lw; + Xmm xmm_v4_off = xmm_hw; + + Xmm xmm_cur_height = Xmm(13); + Xmm xmm_cur_width = Xmm(14); - Xmm xmm_w1 = xmm_h_low; - Xmm xmm_w2 = xmm_h_high; - Xmm xmm_w3 = xmm_w_low; - Xmm xmm_w4 = xmm_w_high; + Xmm xmm_w1 = Xmm(9); + Xmm xmm_w2 = Xmm(10); + Xmm xmm_w3 = Xmm(11); + Xmm xmm_w4 = Xmm(12); - Xmm xmm_v1 = xmm_lh; - Xmm xmm_v2 = xmm_hh; - Xmm xmm_v3 = xmm_lw; - Xmm xmm_v4 = xmm_hw; + Xmm xmm_v1 = xmm_v1_off; + Xmm xmm_v2 = xmm_v2_off; + Xmm xmm_v3 = xmm_v3_off; + Xmm xmm_v4 = xmm_v4_off; - Vmm vmm_w1 = Vmm(xmm_h_low.getIdx()); - Vmm vmm_w2 = Vmm(xmm_h_high.getIdx()); - Vmm vmm_w3 = Vmm(xmm_w_low.getIdx()); - Vmm vmm_w4 = Vmm(xmm_w_high.getIdx()); + Vmm vmm_w1 = Vmm(xmm_w1.getIdx()); + Vmm vmm_w2 = Vmm(xmm_w2.getIdx()); + Vmm vmm_w3 = Vmm(xmm_w3.getIdx()); + Vmm vmm_w4 = Vmm(xmm_w4.getIdx()); - Vmm vmm_v1 = Vmm(xmm_lh.getIdx()); - Vmm vmm_v2 = Vmm(xmm_hh.getIdx()); - Vmm vmm_v3 = Vmm(xmm_lw.getIdx()); - Vmm vmm_v4 = Vmm(xmm_hw.getIdx()); + Vmm vmm_v1 = Vmm(xmm_v1_off.getIdx()); + Vmm vmm_v2 = Vmm(xmm_v2_off.getIdx()); + Vmm vmm_v3 = Vmm(xmm_v3_off.getIdx()); + Vmm vmm_v4 = Vmm(xmm_v4_off.getIdx()); + + // condition check size_t def_off_h = ((2 * (kh * jcp_.kw + kw) + 0) * jcp_.oh * jcp_.ow) + ow; mov(reg_tmp_32, ptr[aux_reg_def_off + def_off_h * jcp_.typesize_off]); @@ -356,6 +373,7 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ size_t def_off_w = ((2 * (kh * jcp_.kw + kw) + 1) * jcp_.oh * jcp_.ow) + ow; mov(reg_tmp_32, ptr[aux_reg_def_off + def_off_w * jcp_.typesize_off]); + movq(xmm_tmp, reg_tmp_64); mov(reg_tmp_32, float2int(static_cast((kw * (jcp_.dilate_w + 1))))); movq(xmm_map_w, reg_tmp_64); @@ -380,83 +398,53 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ cmp(reg_tmp_32, 0); je(init_with_zeros, T_NEAR); + // interpolation calculation movd(xmm_cur_height, table_val(3)); psubd(xmm_cur_height, xmm_ih_in); roundps(xmm_h_low, xmm_map_h, 1); cvtps2dq(xmm_h_low, xmm_h_low); + maxss(xmm_h_low, table_val(0)); - movups(xmm_tmp, xmm_cur_height); - pcmpgtd(xmm_tmp, xmm_h_low); - - movq(reg_tmp_64, xmm_tmp); - cmp(reg_tmp_32, 0); - jne(h_sec_opt, T_NEAR); - - movups(xmm_h_low, xmm_cur_height); - movups(xmm_h_high, xmm_h_low); - jmp(h_sec_opt_exit); - - L(h_sec_opt); - - movups(xmm_h_high, xmm_h_low); - paddd(xmm_h_high, table_val(5)); - - L(h_sec_opt_exit); - - cvtdq2ps(xmm_tmp, xmm_h_low); - subss(xmm_lh, xmm_tmp); - movss(xmm_hh, table_val(5)); - cvtdq2ps(xmm_hh, xmm_hh); - subss(xmm_hh, xmm_lh); - + if (jcp_.with_bi_pad) { + movdqu(xmm_h_high, xmm_h_low); + paddd(xmm_h_high, table_val(5)); + } else { + roundps(xmm_h_high, xmm_map_h, 2); + cvtps2dq(xmm_h_high, xmm_h_high); + minss(xmm_h_high, xmm_cur_height); + } movd(xmm_cur_width, table_val(4)); psubd(xmm_cur_width, xmm_iw_in); roundps(xmm_w_low, xmm_map_w, 1); cvtps2dq(xmm_w_low, xmm_w_low); + maxss(xmm_w_low, table_val(0)); - movups(xmm_tmp, xmm_cur_width); - pcmpgtd(xmm_tmp, xmm_w_low); - - movq(reg_tmp_64, xmm_tmp); - cmp(reg_tmp_32, 0); - jne(w_sec_opt, T_NEAR); - - movups(xmm_w_low, xmm_cur_width); - movups(xmm_w_high, xmm_w_low); - jmp(w_sec_opt_exit); - - L(w_sec_opt); - - movups(xmm_w_high, xmm_w_low); - paddd(xmm_w_high, table_val(5)); - - L(w_sec_opt_exit); + if (jcp_.with_bi_pad) { + movdqu(xmm_w_high, xmm_w_low); + paddd(xmm_w_high, table_val(5)); + } else { + roundps(xmm_w_high, xmm_map_w, 2); + cvtps2dq(xmm_w_high, xmm_w_high); + minss(xmm_w_high, xmm_cur_width); + } cvtdq2ps(xmm_tmp, xmm_w_low); subss(xmm_lw, xmm_tmp); + movss(xmm_hw, table_val(5)); cvtdq2ps(xmm_hw, xmm_hw); subss(xmm_hw, xmm_lw); + cvtdq2ps(xmm_tmp, xmm_h_low); + subss(xmm_lh, xmm_tmp); - movups(xmm_v1_off, table_val(2)); - cvtps2dq(xmm_v1_off, xmm_v1_off); - movups(xmm_v3_off, xmm_v1_off); - - pmulld(xmm_v1_off, xmm_h_low); - movups(xmm_v2_off, xmm_v1_off); - paddd(xmm_v1_off, xmm_w_low); - paddd(xmm_v2_off, xmm_w_high); - - pmulld(xmm_v3_off, xmm_h_high); - movups(xmm_v4_off, xmm_v3_off); - paddd(xmm_v3_off, xmm_w_low); - paddd(xmm_v4_off, xmm_w_high); - + movss(xmm_hh, table_val(5)); + cvtdq2ps(xmm_hh, xmm_hh); + subss(xmm_hh, xmm_lh); movss(xmm_w1, xmm_hh); mulss(xmm_w1, xmm_hw); @@ -487,29 +475,97 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ movq(reg_tmp_64, xmm_v1_off); imul(reg_tmp_64, reg_tmp_64, jcp_.ic * jcp_.typesize_in); add(reg_tmp_64, aux2_reg_input); + // w_low >= 0 + movups(xmm_tmp, xmm_w_low); + pcmpgtd(xmm_tmp, table_val(0)); + movq(reg_tmp_64, xmm_tmp); + cmp(reg_tmp_32, 0); + // jne(v1_condition_end_main, T_NEAR); + + // h_low >= 0 + movups(xmm_tmp, xmm_h_low); + pcmpgtd(xmm_tmp, table_val(0)); + movq(reg_tmp_64, xmm_tmp); + cmp(reg_tmp_32, 0); + // jne(v1_condition_end_main, T_NEAR); + uni_vmovups(vmm_v1, ptr[reg_tmp_64]); uni_vmulps(vmm_v1, vmm_v1, vmm_w1); + L(v1_condition_end_main); + pmovsxdq(xmm_v2_off, xmm_v2_off); movq(reg_tmp_64, xmm_v2_off); imul(reg_tmp_64, reg_tmp_64, jcp_.ic * jcp_.typesize_in); add(reg_tmp_64, aux2_reg_input); + + // w_high <= cur_width - 1 + movups(xmm_tmp, xmm_w_high); + psubd(xmm_tmp, table_val(0)); + pcmpgtd(xmm_tmp, table_val(4)); + movq(reg_tmp_64, xmm_tmp); + cmp(reg_tmp_32, 0); + // je(v2_condition_end_main, T_NEAR); + + // h_low >= 0 + movups(xmm_tmp, xmm_h_low); + pcmpgtd(xmm_tmp, table_val(0)); + movq(reg_tmp_64, xmm_tmp); + cmp(reg_tmp_32, 0); + // jne(v2_condition_end_main, T_NEAR); + uni_vmovups(vmm_v2, ptr[reg_tmp_64]); uni_vmulps(vmm_v2, vmm_v2, vmm_w2); + L(v2_condition_end_main); pmovsxdq(xmm_v3_off, xmm_v3_off); movq(reg_tmp_64, xmm_v3_off); imul(reg_tmp_64, reg_tmp_64, jcp_.ic * jcp_.typesize_in); add(reg_tmp_64, aux2_reg_input); + + // w_low >= 0 + movups(xmm_tmp, xmm_w_low); + pcmpgtd(xmm_tmp, table_val(0)); + movq(reg_tmp_64, xmm_tmp); + cmp(reg_tmp_32, 0); + // jne(v3_condition_end_main, T_NEAR); + + // h_high <= cur_height + movups(xmm_tmp, xmm_h_high); + psubd(xmm_tmp, table_val(0)); + pcmpgtd(xmm_tmp, table_val(3)); + movq(reg_tmp_64, xmm_tmp); + cmp(reg_tmp_32, 0); + // je(v3_condition_end_main, T_NEAR); + uni_vmovups(vmm_v3, ptr[reg_tmp_64]); uni_vmulps(vmm_v3, vmm_v3, vmm_w3); + L(v3_condition_end_main); pmovsxdq(xmm_v4_off, xmm_v4_off); movq(reg_tmp_64, xmm_v4_off); imul(reg_tmp_64, reg_tmp_64, jcp_.ic * jcp_.typesize_in); add(reg_tmp_64, aux2_reg_input); + + // w_high <= cur_width + movups(xmm_tmp, xmm_w_high); + psubd(xmm_tmp, table_val(0)); + pcmpgtd(xmm_tmp, table_val(3)); + movq(reg_tmp_64, xmm_tmp); + cmp(reg_tmp_32, 0); + // je(v4_condition_end_main, T_NEAR); + + // h_high <= cur_height + movups(xmm_tmp, xmm_h_high); + psubd(xmm_tmp, table_val(0)); + pcmpgtd(xmm_tmp, table_val(4)); + movq(reg_tmp_64, xmm_tmp); + cmp(reg_tmp_32, 0); + // je(v4_condition_end_main, T_NEAR); + uni_vmovups(vmm_v4, ptr[reg_tmp_64]); uni_vmulps(vmm_v4, vmm_v4, vmm_w4); + L(v4_condition_end_main); uni_vaddps(vmm_v1, vmm_v1, vmm_v2); uni_vaddps(vmm_v1, vmm_v1, vmm_v3); @@ -529,37 +585,123 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ size_t input_buffer_off = (size_t) kh * jcp_.kw * jcp_.ic + kw * jcp_.ic; + movss(xmm_v1, table_val(0)); + // w_low >= 0 + movq(reg_tmp_64, xmm_w_low); + cmp(reg_tmp_32, 0); + jl(v1_condition_end_tail, T_NEAR); + + // h_low >= 0 + movq(reg_tmp_64, xmm_h_low); + cmp(reg_tmp_32, 0); + jl(v1_condition_end_tail, T_NEAR); + + movups(xmm_v1_off, table_val(2)); + cvtps2dq(xmm_v1_off, xmm_v1_off); + pmulld(xmm_v1_off, xmm_h_low); + paddd(xmm_v1_off, xmm_w_low); pmovsxdq(xmm_v1_off, xmm_v1_off); + movq(reg_tmp_64, xmm_v1_off); imul(reg_tmp_64, reg_tmp_64, jcp_.ic * jcp_.typesize_in); add(reg_tmp_64, aux2_reg_input); movss(xmm_v1, ptr[reg_tmp_64]); mulss(xmm_v1, xmm_w1); - + L(v1_condition_end_tail); + + movss(xmm_v2, table_val(0)); + // w_high <= cur_width - 1 + movq(xmm_tmp, xmm_w_high); + pcmpgtd(xmm_tmp, xmm_cur_width); + movq(reg_tmp_64, xmm_tmp); + cmp(reg_tmp_32, 0); + jne(v2_condition_end_tail, T_NEAR); + + // h_low >= 0 + movq(reg_tmp_64, xmm_h_low); + cmp(reg_tmp_32, 0); + jl(v2_condition_end_tail, T_NEAR); + + + movups(xmm_v2_off, table_val(2)); + cvtps2dq(xmm_v2_off, xmm_v2_off); + pmulld(xmm_v2_off, xmm_h_low); + paddd(xmm_v2_off, xmm_w_high); pmovsxdq(xmm_v2_off, xmm_v2_off); + movq(reg_tmp_64, xmm_v2_off); imul(reg_tmp_64, reg_tmp_64, jcp_.ic * jcp_.typesize_in); add(reg_tmp_64, aux2_reg_input); movss(xmm_v2, ptr[reg_tmp_64]); mulss(xmm_v2, xmm_w2); - + L(v2_condition_end_tail); + + movss(xmm_v3, table_val(0)); + // w_low >= 0 + movq(reg_tmp_64, xmm_w_low); + cmp(reg_tmp_32, 0); + jl(v3_condition_end_tail, T_NEAR); + + // h_high <= cur_height - 1 + movq(xmm_tmp, xmm_h_high); + pcmpgtd(xmm_tmp, xmm_cur_height); + movq(reg_tmp_64, xmm_tmp); + cmp(reg_tmp_32, 0); + jne(v3_condition_end_tail, T_NEAR); + + movups(xmm_v3_off, table_val(2)); + cvtps2dq(xmm_v3_off, xmm_v3_off); + pmulld(xmm_v3_off, xmm_h_high); + paddd(xmm_v3_off, xmm_w_low); pmovsxdq(xmm_v3_off, xmm_v3_off); + movq(reg_tmp_64, xmm_v3_off); imul(reg_tmp_64, reg_tmp_64, jcp_.ic * jcp_.typesize_in); add(reg_tmp_64, aux2_reg_input); movss(xmm_v3, ptr[reg_tmp_64]); mulss(xmm_v3, xmm_w3); + L(v3_condition_end_tail); + + movss(xmm_v4, table_val(0)); + // w_high <= cur_width - 1 + movq(xmm_tmp, xmm_w_high); + pcmpgtd(xmm_tmp, xmm_cur_width); + movq(reg_tmp_64, xmm_tmp); + cmp(reg_tmp_32, 0); + jne(v4_condition_end_tail, T_NEAR); + + // h_high <= cur_height - 1 + movq(xmm_tmp, xmm_h_high); + pcmpgtd(xmm_tmp, xmm_cur_height); + movq(reg_tmp_64, xmm_tmp); + cmp(reg_tmp_32, 0); + jne(v4_condition_end_tail, T_NEAR); + + + movups(xmm_v4_off, table_val(2)); + cvtps2dq(xmm_v4_off, xmm_v4_off); + pmulld(xmm_v4_off, xmm_h_high); + paddd(xmm_v4_off, xmm_w_high); pmovsxdq(xmm_v4_off, xmm_v4_off); + movq(reg_tmp_64, xmm_v4_off); imul(reg_tmp_64, reg_tmp_64, jcp_.ic * jcp_.typesize_in); add(reg_tmp_64, aux2_reg_input); + movss(xmm_v4, ptr[reg_tmp_64]); mulss(xmm_v4, xmm_w4); + L(v4_condition_end_tail); addss(xmm_v1, xmm_v2); addss(xmm_v1, xmm_v3); addss(xmm_v1, xmm_v4); + + if (jcp_.with_modulation) { + size_t modulation_offset = ((kh * jcp_.kw + kw) * jcp_.oh * jcp_.ow) + ow; + mulss(xmm_v1, ptr[reg_modulation + modulation_offset * jcp_.typesize_modulation]); + } + movss(ptr[aux3_reg_input_buffer + input_buffer_off * jcp_.typesize_in], xmm_v1); add(aux2_reg_input, jcp_.typesize_in); @@ -593,13 +735,18 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ } add(aux_reg_def_off, 2 * jcp_.kh * jcp_.kw * jcp_.oh * jcp_.ow * jcp_.typesize_off); + if (jcp_.with_modulation) { + add(reg_modulation, jcp_.kh * jcp_.kw * jcp_.oh * jcp_.ow * jcp_.typesize_modulation); + } add(aux_reg_input, ic_per_def_group * jcp_.typesize_in); add(aux2_reg_input_buffer, ic_per_def_group * jcp_.typesize_in); inc(reg_dg_iter); jmp(dg_loop, T_NEAR); } - L(dg_loop_end); + if (jcp_.with_modulation) { + pop(reg_modulation); + } } void store_output(int ow_step, int oc_blocks_step, int oc_step) { @@ -679,22 +826,27 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ mov(aux_reg_input_buffer, reg_input_buffer); push(reg_output); - push(reg_bias); + if (jcp_.with_bias) + push(reg_bias); push(reg_input); push(reg_kernel); + push(reg_input_buffer); interpolate_input(ow_step); + pop(reg_input_buffer); pop(reg_kernel); pop(reg_input); - pop(reg_bias); + if (jcp_.with_bias) + pop(reg_bias); pop(reg_output); push(reg_ow_pos); mov(aux_reg_kernel, reg_kernel); mov(aux_reg_output, reg_output); - mov(aux_reg_bias, reg_bias); + if (jcp_.with_bias) + mov(aux_reg_bias, reg_bias); mov(reg_oc_work, jcp_.oc); @@ -707,7 +859,8 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ add(aux_reg_kernel, jcp_.nb_oc_blocking * jcp_.nb_ic * jcp_.kh * jcp_.kw * jcp_.ic_block * jcp_.oc_block * jcp_.typesize_in); add(aux_reg_output, jcp_.nb_oc_blocking * jcp_.oc_block * jcp_.typesize_out); - add(aux_reg_bias, jcp_.nb_oc_blocking * jcp_.oc_block * jcp_.typesize_bia); + if (jcp_.with_bias) + add(aux_reg_bias, jcp_.nb_oc_blocking * jcp_.oc_block * jcp_.typesize_bia); sub(reg_oc_work, jcp_.nb_oc_blocking * jcp_.oc_block); jmp(oc_unrolled_loop, T_NEAR); @@ -722,7 +875,8 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ add(aux_reg_kernel, jcp_.nb_ic * jcp_.kh * jcp_.kw * jcp_.ic_block * jcp_.oc_block * jcp_.typesize_in); add(aux_reg_output, jcp_.oc_block * jcp_.typesize_out); - add(aux_reg_bias, jcp_.oc_block * jcp_.typesize_bia); + if (jcp_.with_bias) + add(aux_reg_bias, jcp_.oc_block * jcp_.typesize_bia); sub(reg_oc_work, jcp_.oc_block); jmp(oc_main_loop, T_NEAR); @@ -741,9 +895,9 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ bool MKLDNNDeformableConvolutionNode::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { try { - const auto defConvNode = ngraph::as_type_ptr(op); + const auto defConvNode = ngraph::as_type_ptr(op); if (!defConvNode) { - errorMessage = "Node is not an instance of DeformableConvolution form the operation set v1."; + errorMessage = "Node is not an instance of DeformableConvolution form the operation set v8."; return false; } } catch (...) { @@ -759,11 +913,11 @@ MKLDNNDeformableConvolutionNode::MKLDNNDeformableConvolutionNode(const std::shar if (!isSupportedOperation(op, errorMessage)) { IE_THROW(NotImplemented) << errorMessage; } - auto defConvNode = ngraph::as_type_ptr(op); + auto defConvNode = ngraph::as_type_ptr(op); group = defConvNode->get_group(); deformable_group = defConvNode->get_deformable_group(); - + with_bilinear_pad = defConvNode->get_use_bilinear_interpolation_padding(); auto& strides = defConvNode->get_strides(); for (int i = 0; i < strides.size(); i++) { stride.push_back(strides[i]); @@ -780,7 +934,7 @@ MKLDNNDeformableConvolutionNode::MKLDNNDeformableConvolutionNode(const std::shar void MKLDNNDeformableConvolutionNode::getSupportedDescriptors() { std::string errorPrefix = "DeformableConvolution layer with name '" + getName() + "' "; - if (getParentEdges().size() != 3) + if (getParentEdges().size() != 3 && getParentEdges().size() != 4) IE_THROW() << errorPrefix << "has incorrect number of input edges"; if (getChildEdges().empty()) IE_THROW() << errorPrefix << "has incorrect number of output edges"; @@ -806,15 +960,20 @@ void MKLDNNDeformableConvolutionNode::initSupportedPrimitiveDescriptors() { if (!supportedPrimitiveDescriptors.empty()) return; + size_t inputsNumber = getOriginalInputsNumber(); InferenceEngine::LayerConfig config; config.dynBatchSupport = false; - config.inConfs.resize(3); + config.inConfs.resize(inputsNumber); config.inConfs[0].constant = false; config.inConfs[0].inPlace = -1; config.inConfs[1].constant = false; config.inConfs[1].inPlace = -1; - config.inConfs[1].constant = false; - config.inConfs[1].inPlace = -1; + config.inConfs[2].constant = false; + config.inConfs[2].inPlace = -1; + if (inputsNumber > 3) { + config.inConfs[3].constant = false; + config.inConfs[3].inPlace = -1; + } config.outConfs.resize(1); config.outConfs[0].constant = false; @@ -841,6 +1000,9 @@ void MKLDNNDeformableConvolutionNode::initSupportedPrimitiveDescriptors() { config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), memory::data_type::f32, dataFormat); config.inConfs[1].desc = MKLDNNMemoryDesc(getParentEdgeAt(1)->getDims(), memory::data_type::f32, offFormat); config.inConfs[2].desc = MKLDNNMemoryDesc(getParentEdgeAt(2)->getDims(), memory::data_type::f32, weiFormat); + if (inputsNumber > 3) { + config.inConfs[3].desc = MKLDNNMemoryDesc(getParentEdgeAt(3)->getDims(), memory::data_type::f32, memory::format_tag::nchw); + } config.outConfs[0].desc = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), memory::data_type::f32, dataFormat); supportedPrimitiveDescriptors.push_back({config, impl_type, dataFormat}); } else { @@ -850,6 +1012,10 @@ void MKLDNNDeformableConvolutionNode::initSupportedPrimitiveDescriptors() { config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), memory::data_type::f32, memory::format_tag::nchw); config.inConfs[1].desc = MKLDNNMemoryDesc(getParentEdgeAt(1)->getDims(), memory::data_type::f32, memory::format_tag::nchw); config.inConfs[2].desc = MKLDNNMemoryDesc(getParentEdgeAt(2)->getDims(), memory::data_type::f32, memory::format_tag::oihw); + if (inputsNumber > 3) { + auto dims = getParentEdgeAt(3)->getDims(); + config.inConfs[3].desc = MKLDNNMemoryDesc(getParentEdgeAt(3)->getDims(), memory::data_type::f32, memory::format_tag::nchw); + } config.outConfs[0].desc = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), memory::data_type::f32, memory::format_tag::nchw); supportedPrimitiveDescriptors.push_back({config, impl_type, weiFormat}); } @@ -868,6 +1034,7 @@ void MKLDNNDeformableConvolutionNode::createPrimitive() { jcp.dg = deformable_group; jcp.ngroups = group; + jcp.mb = srcDims[0]; jcp.oc = dstDims[1] / jcp.ngroups; @@ -892,6 +1059,8 @@ void MKLDNNDeformableConvolutionNode::createPrimitive() { jcp.dilate_w = dilation[1]; jcp.with_bias = false; + jcp.with_bi_pad = with_bilinear_pad; + jcp.with_modulation = getParentEdges().size() > 3; const int simd_w = mayiuse(cpu::x64::avx512_common) ? 16 : 8; jcp.ic_block = simd_w; @@ -904,6 +1073,7 @@ void MKLDNNDeformableConvolutionNode::createPrimitive() { jcp.typesize_in = sizeof(float); jcp.typesize_off = sizeof(float); jcp.typesize_out = sizeof(float); + jcp.typesize_modulation = sizeof(float); jcp.ur_w = mayiuse(cpu::x64::avx512_common) ? 6 : 3; jcp.nb_oc_blocking = !mayiuse(cpu::x64::avx2) ? 2 : 4; @@ -924,9 +1094,9 @@ void MKLDNNDeformableConvolutionNode::createPrimitive() { void MKLDNNDeformableConvolutionNode::executeReference(const float* src, const float* offsets, const float* weights, float* dst, const std::vector& src_strides, const std::vector& off_strides, - const std::vector& wei_strides, const std::vector& dst_strides) { + const std::vector& wei_strides, const std::vector& dst_strides, + const float* modulation, const std::vector& modulation_strides) { const bool with_groups = jcp.ngroups > 1; - const int G = jcp.ngroups; const int MB = jcp.mb; const int OH = jcp.oh; @@ -950,8 +1120,9 @@ void MKLDNNDeformableConvolutionNode::executeReference(const float* src, const f const int DG = jcp.dg; - const int channel_per_deformable_group = IC * G / DG; + const int channel_per_deformable_group = (IC * G) / DG; + const bool with_bi_pad = jcp.with_bi_pad; auto ker = [=](int g, int mb, int oc, int oh, int ow) { float d = 0; const int h_in = oh * KSH - padT; @@ -961,54 +1132,54 @@ void MKLDNNDeformableConvolutionNode::executeReference(const float* src, const f const float *data_im_ptr = src + mb * src_strides[0] + (g * IC + ic) * src_strides[1] + h_in * src_strides[2] + w_in * src_strides[3]; const int deformable_group_index = ic / channel_per_deformable_group; const float *data_offset_ptr = offsets + mb * off_strides[0] + (deformable_group_index * 2 * KH * KW) * off_strides[1]; + const float *modulation_offset_ptr = nullptr; + if (modulation != nullptr) { + modulation_offset_ptr = modulation + mb * modulation_strides[0] + (deformable_group_index * KH * KW) * modulation_strides[1]; + } + for (int kh = 0; kh < KH; kh++) { for (int kw = 0; kw < KW; kw++) { const size_t data_offset_h_index = 2 * (kh * KW + kw) * off_strides[1] + oh * off_strides[2] + ow * off_strides[3]; const size_t data_offset_w_index = (2 * (kh * KW + kw) + 1) * off_strides[1] + oh * off_strides[2] + ow * off_strides[3]; const float offset_h = data_offset_ptr[data_offset_h_index]; const float offset_w = data_offset_ptr[data_offset_w_index]; - float val = 0.0f; - const float h_im = h_in + kh * (KDH + 1) + offset_h; - const float w_im = w_in + kw * (KDW + 1) + offset_w; + float map_h = kh * (KDH + 1) + offset_h; // kernel index with offset + float map_w = kw * (KDW + 1) + offset_w; // kernel index with offset + const float h_im = h_in + map_h; // absolute pixel index with offset + const float w_im = w_in + map_w; // absolute pixel index with offset if (h_im >= 0 && w_im >= 0 && h_im < IH && w_im < IW) { - float map_h = kh * (KDH + 1) + offset_h; - float map_w = kw * (KDW + 1) + offset_w; const int cur_height = IH - h_in; const int cur_width = IW - w_in; - int h_low = static_cast(floorf(map_h)); - int w_low = static_cast(floorf(map_w)); - int h_high; - int w_high; - if (h_low >= cur_height - 1) { - h_high = h_low = cur_height - 1; - map_h = static_cast(h_low); - } else { - h_high = h_low + 1; - } - - if (w_low >= cur_width - 1) { - w_high = w_low = cur_width - 1; - map_w = static_cast(w_low); - } else { - w_high = w_low + 1; - } + int h_low = std::max(static_cast(floorf(map_h)), 0); + int w_low = std::max(static_cast(floorf(map_w)), 0); + int h_high = with_bi_pad ? h_low + 1 : std::min(static_cast(ceilf(map_h)), cur_height - 1); + int w_high = with_bi_pad ? w_low + 1 : std::min(static_cast(ceilf(map_w)), cur_width - 1); float lh = map_h - h_low; float lw = map_w - w_low; float hh = 1 - lh, hw = 1 - lw; - float v1 = data_im_ptr[h_low * src_strides[2] + w_low * src_strides[3]]; - float v2 = data_im_ptr[h_low * src_strides[2] + w_high * src_strides[3]]; - float v3 = data_im_ptr[h_high * src_strides[2] + w_low * src_strides[3]]; - float v4 = data_im_ptr[h_high * src_strides[2] + w_high * src_strides[3]]; + float v1 = (w_low >= 0 && h_low >= 0) ? data_im_ptr[h_low * src_strides[2] + w_low * src_strides[3]] : 0.0f; + float v2 = (w_high < cur_width && h_low >= 0) ? data_im_ptr[h_low * src_strides[2] + w_high * src_strides[3]] : 0.0f; + float v3 = (w_low >= 0 && h_high < cur_height) ? data_im_ptr[h_high * src_strides[2] + w_low * src_strides[3]] : 0.0f; + float v4 = (w_high < cur_width && h_high < cur_height) ? data_im_ptr[h_high * src_strides[2] + w_high * src_strides[3]] : 0.0f; float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + float modulation_scalar = 1.0f; + + if (modulation_offset_ptr != nullptr) { + size_t modulation_index = (kh * KW + kw) * modulation_strides[1] + oh * modulation_strides[2] + ow * modulation_strides[3]; + modulation_scalar = modulation_offset_ptr[modulation_index]; + } + + const float weight = with_groups ? weights[g * wei_strides[0] + oc * wei_strides[1] + ic * wei_strides[2] + kh * wei_strides[3] + + kw * wei_strides[4]] + : weights[oc * wei_strides[0] + ic * wei_strides[1] + kh * wei_strides[2] + kw * wei_strides[3]]; + d += val * weight * modulation_scalar; } - d += val * (with_groups ? weights[g * wei_strides[0] + oc * wei_strides[1] + ic * wei_strides[2] + kh * wei_strides[3] + - kw * wei_strides[4]] - : weights[oc * wei_strides[0] + ic * wei_strides[1] + kh * wei_strides[2] + kw * wei_strides[3]]); } } } @@ -1024,7 +1195,8 @@ void MKLDNNDeformableConvolutionNode::executeReference(const float* src, const f void MKLDNNDeformableConvolutionNode::executeOptimized(const float* src, const float* offsets, const float* weights, float* dst, const std::vector& src_strides, const std::vector& off_strides, - const std::vector& dst_strides) { + const std::vector& dst_strides, const float* modulation, + const std::vector& modulation_strides) { size_t buffer_size = (size_t)jcp.nthr * jcp.ur_w * jcp.kh * jcp.kw * jcp.ic * jcp.typesize_in; std::vector input_buffer(buffer_size, 0); float* input_buffer_ptr = &input_buffer[0]; @@ -1040,6 +1212,11 @@ void MKLDNNDeformableConvolutionNode::executeOptimized(const float* src, const f par_conv.src = &src[n * src_strides[0] + _ic*jcp.ic_block * src_strides[1] + (oh * jcp.stride_h - jcp.t_pad) * src_strides[2] - jcp.l_pad * src_strides[3]]; par_conv.off = &offsets[n * off_strides[0] + oh * off_strides[2]]; + if (modulation != nullptr) { + par_conv.modulation = &modulation[n * modulation_strides[0] + oh * modulation_strides[2]]; + } else { + par_conv.modulation = nullptr; + } par_conv.filt = weights; par_conv.dst = &dst[n * dst_strides[0] + _oc*jcp.oc_block * dst_strides[1] + oh * dst_strides[2]]; @@ -1052,6 +1229,8 @@ void MKLDNNDeformableConvolutionNode::executeOptimized(const float* src, const f } void MKLDNNDeformableConvolutionNode::execute(mkldnn::stream strm) { + const size_t inputsNumber = getOriginalInputsNumber(); + auto &srcMemory0 = getParentEdgeAt(0)->getMemory(); auto &srcMemory1 = getParentEdgeAt(1)->getMemory(); auto &srcMemory2 = getParentEdgeAt(2)->getMemory(); @@ -1060,6 +1239,11 @@ void MKLDNNDeformableConvolutionNode::execute(mkldnn::stream strm) { const auto *src = reinterpret_cast(srcMemory0.GetPtr()); const auto *offsets = reinterpret_cast(srcMemory1.GetPtr()); const auto *weights = reinterpret_cast(srcMemory2.GetPtr()); + float* modulation = nullptr; + if (inputsNumber > 3) { + modulation = reinterpret_cast(getParentEdgeAt(3)->getMemory().GetPtr()); + } + float *dst = reinterpret_cast(dstMemory.GetPtr()); auto selectedPrimitiveDescriptor = getSelectedPrimitiveDescriptor(); @@ -1079,13 +1263,19 @@ void MKLDNNDeformableConvolutionNode::execute(mkldnn::stream strm) { dst_strides[dst_block_desc.getOrder()[i]] = dst_block_desc.getStrides()[i]; } + auto off_strides = config.inConfs[1].desc.getBlockingDesc().getStrides(); auto wei_strides = config.inConfs[2].desc.getBlockingDesc().getStrides(); + InferenceEngine::SizeVector modulation_strides; + if (inputsNumber > 3) { + modulation_strides = config.inConfs[3].desc.getBlockingDesc().getStrides(); + } + if (def_conv_kernel) { - executeOptimized(src, offsets, weights, dst, src_strides, off_strides, dst_strides); + executeOptimized(src, offsets, weights, dst, src_strides, off_strides, dst_strides, modulation, modulation_strides); } else { - executeReference(src, offsets, weights, dst, src_strides, off_strides, wei_strides, dst_strides); + executeReference(src, offsets, weights, dst, src_strides, off_strides, wei_strides, dst_strides, modulation, modulation_strides); } } diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_def_conv_node.h b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_def_conv_node.h index e74e49788ccda6..5ffaff1fb3a316 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_def_conv_node.h +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_def_conv_node.h @@ -22,8 +22,6 @@ struct jit_def_conv_params { int kd, kh, kw; int stride_d, stride_h, stride_w; int dilate_d, dilate_h, dilate_w; - bool with_bias; - bool with_sum; int nthr; int nb_ic, ic_block; int nb_oc, oc_block; @@ -32,13 +30,19 @@ struct jit_def_conv_params { int ur_w_tail; int typesize_in; int typesize_off; + int typesize_modulation; int typesize_bia; int typesize_out; + bool with_bias; + bool with_sum; + bool with_modulation; + bool with_bi_pad; }; struct jit_def_conv_call_args { const void *src; const void *off; + const void *modulation; const void *filt; const void *bias; const void *dst; @@ -80,6 +84,7 @@ class MKLDNNDeformableConvolutionNode : public MKLDNNNode { private: size_t group = 1; + bool with_bilinear_pad = false; std::vector stride = {}; std::vector dilation = {}; std::vector paddingL = {}; @@ -92,10 +97,11 @@ class MKLDNNDeformableConvolutionNode : public MKLDNNNode { void executeReference(const float* src, const float* offsets, const float* weights, float* dst, const std::vector& src_strides, const std::vector& off_strides, - const std::vector& wei_strides, const std::vector& dst_strides); + const std::vector& wei_strides, const std::vector& dst_strides, + const float* modulation = nullptr, const std::vector& modulation_strides = {}); void executeOptimized(const float* src, const float* offsets, const float* weights, float* dst, - const std::vector& src_strides, const std::vector& off_strides, - const std::vector& dst_strides); + const std::vector& src_strides, const std::vector& off_strides, const std::vector& dst_strides, + const float* modulation = nullptr, const std::vector& modulation_strides = {}); }; } // namespace MKLDNNPlugin From ece1d3791791e1bb6fb7fdf35194691825e2a68f Mon Sep 17 00:00:00 2001 From: Nikolay Shchegolev Date: Fri, 16 Jul 2021 16:03:02 +0300 Subject: [PATCH 2/2] Some fixes --- .../nodes/mkldnn_def_conv_node.cpp | 39 +++++++++++++------ .../deformable_convolution.cpp | 15 ++++--- 2 files changed, 38 insertions(+), 16 deletions(-) diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_def_conv_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_def_conv_node.cpp index 9e1b2fe3c555d3..157252681c25c7 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_def_conv_node.cpp +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_def_conv_node.cpp @@ -895,9 +895,10 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ bool MKLDNNDeformableConvolutionNode::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { try { - const auto defConvNode = ngraph::as_type_ptr(op); - if (!defConvNode) { - errorMessage = "Node is not an instance of DeformableConvolution form the operation set v8."; + if (!one_of(op->get_type_info(), + ngraph::op::v1::DeformableConvolution::type_info, + ngraph::op::v8::DeformableConvolution::type_info)) { + errorMessage = "Node is not an instance of DeformableConvolution form the operation set v1 or v8."; return false; } } catch (...) { @@ -913,22 +914,28 @@ MKLDNNDeformableConvolutionNode::MKLDNNDeformableConvolutionNode(const std::shar if (!isSupportedOperation(op, errorMessage)) { IE_THROW(NotImplemented) << errorMessage; } - auto defConvNode = ngraph::as_type_ptr(op); + auto defConvNodeBase = std::dynamic_pointer_cast(op); - group = defConvNode->get_group(); - deformable_group = defConvNode->get_deformable_group(); - with_bilinear_pad = defConvNode->get_use_bilinear_interpolation_padding(); - auto& strides = defConvNode->get_strides(); + group = defConvNodeBase->get_group(); + deformable_group = defConvNodeBase->get_deformable_group(); + auto& strides = defConvNodeBase->get_strides(); for (int i = 0; i < strides.size(); i++) { stride.push_back(strides[i]); } - auto& dilations = defConvNode->get_dilations(); + auto& dilations = defConvNodeBase->get_dilations(); for (int i = 1; i <= dilations.size(); i++) { dilation.push_back(dilations[dilations.size() - i] - 1); } - paddingL = defConvNode->get_pads_begin(); + paddingL = defConvNodeBase->get_pads_begin(); + + if (op->get_type_info() == ngraph::op::v8::DeformableConvolution::type_info) { + auto defConvNode = std::dynamic_pointer_cast(op); + with_bilinear_pad = defConvNode->get_bilinear_interpolation_pad(); + } else { + with_bilinear_pad = false; + } } void MKLDNNDeformableConvolutionNode::getSupportedDescriptors() { @@ -999,7 +1006,17 @@ void MKLDNNDeformableConvolutionNode::initSupportedPrimitiveDescriptors() { config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), memory::data_type::f32, dataFormat); config.inConfs[1].desc = MKLDNNMemoryDesc(getParentEdgeAt(1)->getDims(), memory::data_type::f32, offFormat); - config.inConfs[2].desc = MKLDNNMemoryDesc(getParentEdgeAt(2)->getDims(), memory::data_type::f32, weiFormat); + auto& wDims = getParentEdgeAt(2)->getDims(); + if (group > 1 && wDims.ndims() != 5) { + auto old_dims = wDims.ToSizeVector(); + auto new_dims = InferenceEngine::SizeVector({group, div_up(old_dims[0], group)}); + for (int i = 1; i < old_dims.size(); i++) { + new_dims.push_back(old_dims[i]); + } + config.inConfs[2].desc = MKLDNNMemoryDesc(MKLDNNDims(new_dims), memory::data_type::f32, weiFormat); + } else { + config.inConfs[2].desc = MKLDNNMemoryDesc(getParentEdgeAt(2)->getDims(), memory::data_type::f32, weiFormat); + } if (inputsNumber > 3) { config.inConfs[3].desc = MKLDNNMemoryDesc(getParentEdgeAt(3)->getDims(), memory::data_type::f32, memory::format_tag::nchw); } diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/deformable_convolution.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/deformable_convolution.cpp index 437d8737d65bf7..83f36819d800b9 100644 --- a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/deformable_convolution.cpp +++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/deformable_convolution.cpp @@ -103,13 +103,18 @@ const std::vector single_deform_groups = {3}; const auto deformableConv2DParams_SingleTestCase = ::testing::Combine( ::testing::ValuesIn(single_deform_vals), - ::testing::ValuesIn(single_kernel), ::testing::ValuesIn(strides), - ::testing::ValuesIn(padBegins), ::testing::ValuesIn(padEnds), - ::testing::ValuesIn(dilations), ::testing::ValuesIn(groups), - ::testing::ValuesIn(single_deform_groups), ::testing::ValuesIn(numOutChannels), + ::testing::ValuesIn(single_kernel), + ::testing::ValuesIn(strides), + ::testing::ValuesIn(padBegins), + ::testing::ValuesIn(padEnds), + ::testing::ValuesIn(dilations), + ::testing::ValuesIn(groups), + ::testing::ValuesIn(single_deform_groups), + ::testing::ValuesIn(numOutChannels), ::testing::Values(ngraph::op::PadType::EXPLICIT), ::testing::ValuesIn(with_bilinear_interpolation_pad), - ::testing::ValuesIn(with_modulated_scalar)); + ::testing::ValuesIn(with_modulated_scalar) +); INSTANTIATE_TEST_SUITE_P( smoke_DeformableConvolution2D_SingleTestCase, DeformableConvolutionLayerTest,