From 709e28990dd85ac339a4175140554900ba3802ab Mon Sep 17 00:00:00 2001 From: amberhassaan Date: Thu, 14 Dec 2023 16:43:14 -0500 Subject: [PATCH] [HotFix] 3D Group Conv Backward data and weight update. Failure noticed when pads and strides are not 1 (#2560) --- src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp | 5 +++-- src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp | 5 +++-- test/gtest/group_conv3d_bwd.hpp | 5 ++++- test/gtest/group_conv3d_wrw.hpp | 5 ++++- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp index 1a04669b33..fadf4c194e 100644 --- a/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp @@ -86,8 +86,9 @@ struct CKArgs Do = ProblemInterpreter::GetOutputDepthDo(problem); Z = ProblemInterpreter::GetFilterDepthZ(problem); - input = {G, N, C, Di, Hi, Wi}; - output = {G, N, K, Do, Ho, Wo}; + // On a backward pass, out is in and in is out and this is silly + output = {G, N, C, Di, Hi, Wi}; + input = {G, N, K, Do, Ho, Wo}; weight = {G, K, C, Z, Y, X}; // miopen strides to CK strides diff --git a/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp index 543bb45592..deb8eb14d3 100644 --- a/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp @@ -84,8 +84,9 @@ struct CKArgs Do = ProblemInterpreter::GetOutputDepthDo(problem); Z = ProblemInterpreter::GetFilterDepthZ(problem); - input = {G, N, C, Di, Hi, Wi}; - output = {G, N, K, Do, Ho, Wo}; + // On a backward pass, out is in and in is out and this is silly + output = {G, N, C, Di, Hi, Wi}; + input = {G, N, K, Do, Ho, Wo}; weight = {G, K, C, Z, Y, X}; // miopen strides to CK strides diff --git a/test/gtest/group_conv3d_bwd.hpp b/test/gtest/group_conv3d_bwd.hpp index 5653a18138..c59b48c500 100644 --- a/test/gtest/group_conv3d_bwd.hpp +++ b/test/gtest/group_conv3d_bwd.hpp @@ -29,7 +29,10 @@ std::vector ConvTestConfigs() { // g n c d h w k z y x pad_x pad_y pad_z stri_x stri_y stri_z dia_x dia_y dia_z - return {{1, 128, 64, 14, 28, 28, 64, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution}, + return {{1, 1, 4, 14, 28, 28, 4, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution}, + {1, 1, 1, 1, 4, 4, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution}, + {1, 1, 1, 8, 8, 8, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 1, 1, 1, miopenConvolution}, + {1, 1, 1, 8, 8, 8, 1, 2, 2, 2, 0, 0, 0, 2, 2, 2, 1, 1, 1, miopenConvolution}, {1, 64, 32, 28, 28, 28, 32, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution}, {32, 128, 32, 28, 28, 28, 32, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution}, {16, 128, 16, 28, 28, 28, 16, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution}, diff --git a/test/gtest/group_conv3d_wrw.hpp b/test/gtest/group_conv3d_wrw.hpp index a0e504a000..647004e45a 100644 --- a/test/gtest/group_conv3d_wrw.hpp +++ b/test/gtest/group_conv3d_wrw.hpp @@ -29,7 +29,10 @@ std::vector ConvTestConfigs() { // g n c d h w k z y x pad_x pad_y pad_z stri_x stri_y stri_z dia_x dia_y dia_z - return {{1, 128, 64, 14, 28, 28, 64, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution}, + return {{1, 1, 4, 14, 28, 28, 4, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution}, + {1, 1, 1, 1, 4, 4, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution}, + {1, 1, 1, 8, 8, 8, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 1, 1, 1, miopenConvolution}, + {1, 1, 1, 8, 8, 8, 1, 2, 2, 2, 0, 0, 0, 2, 2, 2, 1, 1, 1, miopenConvolution}, {1, 64, 32, 28, 28, 28, 32, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution}, {2, 128, 32, 28, 28, 28, 32, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution}, {32, 128, 32, 28, 28, 28, 32, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution},