Skip to content

Commit

Permalink
[HotFix] 3D Group Conv Backward data and weight update. Failure notic…
Browse files Browse the repository at this point in the history
…ed when pads and strides are not 1 (#2560)
  • Loading branch information
amberhassaan authored Dec 14, 2023
1 parent ffe4f8d commit 709e289
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 6 deletions.
5 changes: 3 additions & 2 deletions src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,9 @@ struct CKArgs
Do = ProblemInterpreter::GetOutputDepthDo(problem);
Z = ProblemInterpreter::GetFilterDepthZ(problem);

input = {G, N, C, Di, Hi, Wi};
output = {G, N, K, Do, Ho, Wo};
// On a backward pass, out is in and in is out and this is silly
output = {G, N, C, Di, Hi, Wi};
input = {G, N, K, Do, Ho, Wo};
weight = {G, K, C, Z, Y, X};

// miopen strides to CK strides
Expand Down
5 changes: 3 additions & 2 deletions src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,9 @@ struct CKArgs
Do = ProblemInterpreter::GetOutputDepthDo(problem);
Z = ProblemInterpreter::GetFilterDepthZ(problem);

input = {G, N, C, Di, Hi, Wi};
output = {G, N, K, Do, Ho, Wo};
// On a backward pass, out is in and in is out and this is silly
output = {G, N, C, Di, Hi, Wi};
input = {G, N, K, Do, Ho, Wo};
weight = {G, K, C, Z, Y, X};

// miopen strides to CK strides
Expand Down
5 changes: 4 additions & 1 deletion test/gtest/group_conv3d_bwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@

std::vector<Conv3DTestCase> ConvTestConfigs()
{ // g n c d h w k z y x pad_x pad_y pad_z stri_x stri_y stri_z dia_x dia_y dia_z
return {{1, 128, 64, 14, 28, 28, 64, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution},
return {{1, 1, 4, 14, 28, 28, 4, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution},
{1, 1, 1, 1, 4, 4, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution},
{1, 1, 1, 8, 8, 8, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 1, 1, 1, miopenConvolution},
{1, 1, 1, 8, 8, 8, 1, 2, 2, 2, 0, 0, 0, 2, 2, 2, 1, 1, 1, miopenConvolution},
{1, 64, 32, 28, 28, 28, 32, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution},
{32, 128, 32, 28, 28, 28, 32, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution},
{16, 128, 16, 28, 28, 28, 16, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution},
Expand Down
5 changes: 4 additions & 1 deletion test/gtest/group_conv3d_wrw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@

std::vector<Conv3DTestCase> ConvTestConfigs()
{ // g n c d h w k z y x pad_x pad_y pad_z stri_x stri_y stri_z dia_x dia_y dia_z
return {{1, 128, 64, 14, 28, 28, 64, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution},
return {{1, 1, 4, 14, 28, 28, 4, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution},
{1, 1, 1, 1, 4, 4, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution},
{1, 1, 1, 8, 8, 8, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 1, 1, 1, miopenConvolution},
{1, 1, 1, 8, 8, 8, 1, 2, 2, 2, 0, 0, 0, 2, 2, 2, 1, 1, 1, miopenConvolution},
{1, 64, 32, 28, 28, 28, 32, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution},
{2, 128, 32, 28, 28, 28, 32, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution},
{32, 128, 32, 28, 28, 28, 32, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution},
Expand Down

0 comments on commit 709e289

Please sign in to comment.