Fused solver for Fwd Convolution with Residual add, Bias add and then activation function #2517

Merged Dec 20, 2023: 24 commits from amber/fused-res-add-solver into develop.

The diff below shows changes from 8 of the 24 commits.

Commits:
896aae1  WIP: fused solver for residual add (amberhassaan, Nov 8, 2023)
0c6ded7  Merge branch 'develop' into amber/fused-res-add-solver (junliume, Nov 10, 2023)
44443b5  WIP: updating the solver to use the correct CK kernel (amberhassaan, Nov 28, 2023)
dd06505  basic code structure for new fusion solver completed and compiles (amberhassaan, Dec 3, 2023)
be6f4e5  Merge remote-tracking branch 'origin/develop' into amber/fused-res-ad… (amberhassaan, Dec 3, 2023)
7b043d7  formatting (amberhassaan, Dec 3, 2023)
bf88d8a  Merge branch 'develop' into amber/fused-res-add-solver (amberhassaan, Dec 13, 2023)
3539871  add support for tensor scale add op (amberhassaan, Dec 13, 2023)
7545e7f  WIP: added gtest for the fused solver (amberhassaan, Dec 13, 2023)
8c49752  fix formatting (amberhassaan, Dec 13, 2023)
c64b7e5  bump new ck commit hash (iq136boy, Dec 14, 2023)
fea7680  modified fusion APIs and invoke parameters (iq136boy, Dec 14, 2023)
cfa8fa1  fix bugs in new fusion solver (iq136boy, Dec 14, 2023)
fa79a8a  fix bugs and errors in gtest (iq136boy, Dec 15, 2023)
59dc084  address comments (iq136boy, Dec 15, 2023)
488dae1  Merge remote-tracking branch 'origin/develop' into amber/fused-res-ad… (amberhassaan, Dec 16, 2023)
221636b  WIP: fused solver is running. Debugging assert failure (amberhassaan, Dec 17, 2023)
1154f9f  bug fixes. Code works (amberhassaan, Dec 18, 2023)
f5b1667  Merge remote-tracking branch 'origin/develop' into amber/fused-res-ad… (amberhassaan, Dec 18, 2023)
e726221  formatting fixes (amberhassaan, Dec 18, 2023)
669e6f6  fix tidy error (amberhassaan, Dec 18, 2023)
4aa97df  disabling broken test (amberhassaan, Dec 18, 2023)
dc942ee  added namespace around gtest (amberhassaan, Dec 18, 2023)
e5786d4  Merge branch 'develop' into amber/fused-res-add-solver (amberhassaan, Dec 19, 2023)
1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -188,6 +188,7 @@ set( MIOpen_Source
     solver/conv_bin_winoRxS_fused.cpp
     solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp
     solver/conv_ck_igemm_fwd_bias_activ_fused.cpp
+    solver/conv_ck_igemm_fwd_bias_res_add_activ_fused.cpp
     solver/conv_direct_naive_conv.cpp
     solver/conv_direct_naive_conv_bwd.cpp
     solver/conv_direct_naive_conv_fwd.cpp
58 changes: 46 additions & 12 deletions src/fusion.cpp
@@ -80,6 +80,9 @@ miopenStatus_t ConvBiasActivFusion(Handle& handle,
     assert(workspaceSizeInBytes == 0);
     std::ignore = workspace;
     std::ignore = workspaceSizeInBytes;
+    /// \todo: add workspace support in fusion
+
+    /*
     if(alpha1 != nullptr)
     {
         const auto falpha1 = *(static_cast<const float*>(alpha1));
@@ -92,29 +95,44 @@ miopenStatus_t ConvBiasActivFusion(Handle& handle,
         if(falpha2 != 1.0f)
             MIOPEN_THROW(miopenStatusNotImplemented, "alpha2 can only be 1.0");
     }
     if(z != nullptr || zDesc.GetSize() != 0)
         MIOPEN_THROW(miopenStatusNotImplemented, "The addition of z vector is not yet supported");
+    */
+
+    float falpha1 = alpha1 ? *(static_cast<const float*>(alpha1)) : 1.0f;
+    float falpha2 = alpha2 ? *(static_cast<const float*>(alpha2)) : 1.0f;
+
+    // if(z != nullptr || zDesc.GetSize() != 0)
+    //     MIOPEN_THROW(miopenStatusNotImplemented, "The addition of z vector is not yet supported");
     FusionPlanDescriptor fusePlanDesc{miopenVerticalFusion, xDesc};
     OperatorArgs fusionArgs;
-    auto convoOp = std::make_shared<ConvForwardOpDescriptor>(conv_desc, wDesc);
+    auto convOp  = std::make_shared<ConvForwardOpDescriptor>(conv_desc, wDesc);
+    auto zOp     = std::make_shared<TensorScaleAddOpDescriptor>(zDesc);
     auto biasOp  = std::make_shared<BiasFusionOpDescriptor>(biasDesc);
     auto activOp = std::make_shared<ActivFwdFusionOpDescriptor>(activationDesc.GetMode());
-    MIOPEN_CHECK(fusePlanDesc.AddOp(convoOp));
+
+    if(activationDesc.GetMode() != miopenActivationRELU)
+    {
+        MIOPEN_THROW(miopenStatusNotImplemented,
+                     "only Activation Mode == miopenActivationRELU is supported");
+    }
+
+    MIOPEN_CHECK(fusePlanDesc.AddOp(convOp));
     MIOPEN_CHECK(fusePlanDesc.SetConvAlgo(algo));
+    MIOPEN_CHECK(fusePlanDesc.AddOp(zOp));
     MIOPEN_CHECK(fusePlanDesc.AddOp(biasOp));
     MIOPEN_CHECK(fusePlanDesc.AddOp(activOp));
 
     MIOPEN_CHECK(fusePlanDesc.Compile(handle));
-    float alpha = static_cast<float>(1.0);
-    float beta  = static_cast<float>(0);
+    float alpha = 1.0f;
+    float beta  = 0.0f;
     float activ_alpha = activationDesc.GetAlpha();
     float activ_beta  = activationDesc.GetBeta();
     float activ_gamma = activationDesc.GetGamma();
 
     // Set the Args
-    MIOPEN_CHECK(convoOp->SetArgs(fusionArgs, &alpha, &beta, w));
-    MIOPEN_CHECK(activOp->SetArgs(fusionArgs, &alpha, &beta, activ_alpha, activ_beta, activ_gamma));
+    MIOPEN_CHECK(convOp->SetArgs(fusionArgs, falpha1, beta, w));
+    MIOPEN_CHECK(zOp->SetArgs(fusionArgs, falpha2, z));
     MIOPEN_CHECK(biasOp->SetArgs(fusionArgs, &alpha, &beta, bias));
+    MIOPEN_CHECK(activOp->SetArgs(fusionArgs, &alpha, &beta, activ_alpha, activ_beta, activ_gamma));
     MIOPEN_CHECK(fusePlanDesc.Execute(handle, xDesc, x, yDesc, y, fusionArgs));
     return miopenStatusSuccess;
 }
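For reference, the plan built above expresses the cuDNN-style fused operation y = activation(alpha1 * conv(x, w) + alpha2 * z + bias), with the activation restricted to ReLU for now. A minimal scalar sketch of those semantics (illustrative only, not MIOpen code):

```cpp
#include <algorithm>
#include <cstdio>

// Reference semantics only: per output element,
// y = ReLU(alpha1 * conv_result + alpha2 * z + bias), matching the op order
// convOp -> zOp (TensorScaleAdd) -> biasOp -> activOp assembled above.
float FusedElement(float conv_result, float z, float bias, float alpha1, float alpha2)
{
    const float acc = alpha1 * conv_result + alpha2 * z + bias;
    return std::max(acc, 0.0f); // miopenActivationRELU
}

int main()
{
    // alpha1 = 1.0, alpha2 = 0.5: halve the residual before adding it.
    std::printf("%.2f\n", FusedElement(2.0f, 4.0f, 1.0f, 1.0f, 0.5f)); // 5.00
}
```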
@@ -513,11 +531,11 @@ miopenStatus_t ConvForwardOpDescriptor::GetOutputDesc(TensorDescriptor& output_desc) const
 }
 
 miopenStatus_t ConvForwardOpDescriptor::SetArgs(OperatorArgs& args,
-                                                const void* /*alpha*/,
-                                                const void* /*beta*/,
+                                                float alpha,
+                                                float beta,
                                                 ConstData_t w)
 {
-    auto op_args = std::make_unique<fusion::ConvolutionOpInvokeParam>(w);
+    auto op_args = std::make_unique<fusion::ConvolutionOpInvokeParam>(alpha, beta, w);
     args.SetArg(GetIdx(), std::move(op_args));
     return miopenStatusSuccess;
 }
@@ -672,6 +690,21 @@ miopenStatus_t BiasFusionOpDescriptor::SetArgs(OperatorArgs& args,
     return miopenStatusSuccess;
 }
 
+miopenStatus_t TensorScaleAddOpDescriptor::GetOutputDesc(TensorDescriptor& output_desc) const
+{
+    output_desc = this->tensor_desc;
+    return miopenStatusSuccess;
+}
+
+miopenStatus_t TensorScaleAddOpDescriptor::SetArgs(OperatorArgs& args,
+                                                   float alpha,
+                                                   ConstData_t tensor_ptr)
+{
+    auto op_args = std::make_unique<fusion::TensorScaleAddOpInvokeParam>(alpha, tensor_ptr);
+    args.SetArg(GetIdx(), std::move(op_args));
+    return miopenStatusSuccess;
+}
+
 std::string FusionPlanDescriptor::GetAlgorithmName(const Handle& /*handle*/)
 {
     if(conv_fwd_algo)
@@ -698,7 +731,8 @@ static auto GetFusedDirectSolvers()

 static auto GetFusedIGemmSolvers()
 {
-    return solver::SolverContainer<solver::fusion::ConvCKIgemmFwdBiasActivFused>{};
+    return solver::SolverContainer<solver::fusion::ConvCKIgemmFwdBiasActivFused,
+                                   solver::fusion::ConvCKIgemmFwdBiasResAddActivFused>{};
 }
 
 static auto GetFusedWinogradSolvers()
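Taken together, the fusion.cpp changes give the residual-add op the same lifecycle as the existing fused ops: construct a descriptor, add it to the plan in computation order, supply per-launch arguments, then execute. A condensed recap of that sequence from ConvBiasActivFusion above (not standalone code; names and calls exactly as in the diff, with error paths elided):

```cpp
// Condensed from the ConvBiasActivFusion changes above.
FusionPlanDescriptor fusePlanDesc{miopenVerticalFusion, xDesc};
OperatorArgs fusionArgs;

auto convOp  = std::make_shared<ConvForwardOpDescriptor>(conv_desc, wDesc);
auto zOp     = std::make_shared<TensorScaleAddOpDescriptor>(zDesc); // residual z
auto biasOp  = std::make_shared<BiasFusionOpDescriptor>(biasDesc);
auto activOp = std::make_shared<ActivFwdFusionOpDescriptor>(activationDesc.GetMode());

// Plan structure: op order defines the fused computation.
MIOPEN_CHECK(fusePlanDesc.AddOp(convOp));
MIOPEN_CHECK(fusePlanDesc.SetConvAlgo(algo));
MIOPEN_CHECK(fusePlanDesc.AddOp(zOp));
MIOPEN_CHECK(fusePlanDesc.AddOp(biasOp));
MIOPEN_CHECK(fusePlanDesc.AddOp(activOp));
MIOPEN_CHECK(fusePlanDesc.Compile(handle));

// Per-launch arguments, then execution.
MIOPEN_CHECK(convOp->SetArgs(fusionArgs, falpha1, beta, w)); // alpha1 scales conv
MIOPEN_CHECK(zOp->SetArgs(fusionArgs, falpha2, z));          // alpha2 scales z
MIOPEN_CHECK(biasOp->SetArgs(fusionArgs, &alpha, &beta, bias));
MIOPEN_CHECK(activOp->SetArgs(fusionArgs, &alpha, &beta, activ_alpha, activ_beta, activ_gamma));
MIOPEN_CHECK(fusePlanDesc.Execute(handle, xDesc, x, yDesc, y, fusionArgs));
```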
13 changes: 12 additions & 1 deletion src/include/miopen/fusion.hpp
@@ -81,6 +81,17 @@ struct BiasFusionOpDescriptor : FusionOpDescriptor
     TensorDescriptor base_desc;
 };
 
+struct TensorScaleAddOpDescriptor : public FusionOpDescriptor
+{
+    TensorScaleAddOpDescriptor(const TensorDescriptor& desc) : tensor_desc(desc) {}
+    miopenStatus_t GetOutputDesc(TensorDescriptor& output_desc) const override;
+    miopenStatus_t GetNetworkConfig(std::ostringstream& network_config) override;
+    miopenStatus_t SetArgs(OperatorArgs& args, float alpha, ConstData_t tensor_ptr);
+    miopenFusionOp_t kind() const override { return miopenFusionOpTensorScaleAdd; };
+    TensorDescriptor tensor_desc;
+};
+
 struct ActivFwdFusionOpDescriptor : FusionOpDescriptor
 {
     ActivFwdFusionOpDescriptor(miopenActivationMode_t mode) : activMode(mode) {}
@@ -214,7 +225,7 @@ struct ConvForwardOpDescriptor : FusionOpDescriptor
           kernel_info_valid(false),
           conv_compiler_options(""){};
     miopenStatus_t GetOutputDesc(TensorDescriptor& output_desc) const override;
-    miopenStatus_t SetArgs(OperatorArgs& args, const void* alpha, const void* beta, ConstData_t w);
+    miopenStatus_t SetArgs(OperatorArgs& args, float alpha, float beta, ConstData_t w);
     miopenStatus_t GetNetworkConfig(std::ostringstream& network_config) override;
     bool isASMApplicable(Handle& handle);
     miopenFusionOp_t kind() const override { return miopenFusionOpConvForward; };
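TensorScaleAddOpDescriptor::GetOutputDesc (implemented in fusion.cpp above) simply echoes the z tensor's descriptor: like the bias and activation ops, the residual add is an elementwise, shape-preserving op, so output descriptors can be chained through the plan. A toy model of that chaining, with deliberately simplified, hypothetical types:

```cpp
#include <iostream>
#include <vector>

// Hypothetical stand-in for TensorDescriptor: just an NCHW shape.
using Shape = std::vector<int>;

struct OpSketch
{
    virtual ~OpSketch() = default;
    virtual Shape GetOutputDesc(const Shape& in) const = 0;
};

// Elementwise ops (scale-add, bias, activation) preserve the input shape.
struct ElementwiseOpSketch : OpSketch
{
    Shape GetOutputDesc(const Shape& in) const override { return in; }
};

int main()
{
    Shape x{1, 64, 56, 56};
    ElementwiseOpSketch zOp, biasOp, activOp;
    Shape out = activOp.GetOutputDesc(biasOp.GetOutputDesc(zOp.GetOutputDesc(x)));
    std::cout << "C = " << out[1] << '\n'; // shape flows through unchanged
}
```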
14 changes: 13 additions & 1 deletion src/include/miopen/fusion/fusion_invoke_params.hpp
@@ -40,7 +40,12 @@ struct FusionOpInvokeParamBase

 struct ConvolutionOpInvokeParam : FusionOpInvokeParamBase
 {
-    ConvolutionOpInvokeParam(ConstData_t w) : weights(w) {}
+    ConvolutionOpInvokeParam(float _alpha, float _beta, ConstData_t w)
+        : alpha(_alpha), beta(_beta), weights(w)
+    {
+    }
+    float alpha = 1.0f; // scales the new convolution result
+    float beta  = 0.0f; // scales the existing value of the output tensor
     ConstData_t weights = nullptr;
 };

@@ -50,6 +55,13 @@ struct BiasOpInvokeParam : FusionOpInvokeParamBase
     ConstData_t bdata = nullptr;
 };
 
+struct TensorScaleAddOpInvokeParam : public FusionOpInvokeParamBase
+{
+    TensorScaleAddOpInvokeParam(float a, ConstData_t tp) : alpha(a), tensor_ptr(tp) {}
+    float alpha            = 1.0f;
+    ConstData_t tensor_ptr = nullptr;
+};
+
 struct ActivationOpInvokeParam : FusionOpInvokeParamBase
 {
     ActivationOpInvokeParam(double alpha, double beta, double gamma)
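The descriptors above carry the compile-time structure of the plan, while these invoke params carry the per-launch scalars and pointers. They are stored type-erased behind FusionOpInvokeParamBase and recovered by the op's invoker. A minimal self-contained model of that pattern (assumed here for illustration; MIOpen's OperatorArgs differs in detail):

```cpp
#include <cstdio>
#include <memory>
#include <vector>

// Toy model of the FusionOpInvokeParamBase pattern (illustrative only).
struct ParamBase
{
    virtual ~ParamBase() = default;
};

struct TensorScaleAddParam : ParamBase
{
    TensorScaleAddParam(float a, const void* tp) : alpha(a), tensor_ptr(tp) {}
    float alpha            = 1.0f;
    const void* tensor_ptr = nullptr;
};

int main()
{
    // One type-erased argument slot per fused op, filled by SetArgs.
    std::vector<std::unique_ptr<ParamBase>> args;
    float z_data[4] = {0.f, 1.f, 2.f, 3.f};
    args.push_back(std::make_unique<TensorScaleAddParam>(0.5f, z_data));

    // The op's invoker later recovers the concrete type for its slot.
    if(auto* p = dynamic_cast<TensorScaleAddParam*>(args[0].get()))
        std::printf("alpha = %.2f, z = %p\n", p->alpha, p->tensor_ptr);
}
```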
67 changes: 67 additions & 0 deletions src/include/miopen/fusion/solvers.hpp
@@ -253,6 +253,73 @@ struct ConvCKIgemmFwdBiasActivFused final
     bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const;
 };
 
+struct PerfConfigConvCKIgemmFwdBiasResAddActivFused
+    : PerfConfigBase<PerfConfigConvCKIgemmFwdBiasResAddActivFused>
+{
+    int index;
+    std::string kernel_id;
+    std::vector<std::string> valid_kernels;
+    PerfConfigConvCKIgemmFwdBiasResAddActivFused(int idx, std::string kernl_id)
+        : index(idx), kernel_id(kernl_id)
+    {
+    }
+    PerfConfigConvCKIgemmFwdBiasResAddActivFused()
+        : PerfConfigConvCKIgemmFwdBiasResAddActivFused(0, "")
+    {
+    }
+    PerfConfigConvCKIgemmFwdBiasResAddActivFused(bool)
+        : PerfConfigConvCKIgemmFwdBiasResAddActivFused(0, "")
+    {
+    }
+    void HeuristicInit(const FusionDescription& fdesc_problem);
+    bool SetNextValue(const FusionDescription& fdesc_problem);
+    bool IsValidValue() const;
+    bool IsValid(const FusionContext&, const FusionDescription& fdesc_problem) const;
+
+    template <typename Self, typename F>
+    static void Visit(Self&& s, F f)
+    {
+        f(s.kernel_id, "kernel_id");
+    }
+    bool operator==(const PerfConfigConvCKIgemmFwdBiasResAddActivFused& other) const;
+
+private:
+    template <typename DataType, typename AccumDataType = DataType>
+    void Init(const miopen::conv::ProblemDescription&);
+    template <typename DataType, typename AccumDataType = DataType>
+    bool CheckIsSupportCKArgs(const miopen::conv::ProblemDescription&) const;
+};
+
+struct ConvCKIgemmFwdBiasResAddActivFused final
+    : FusionTunableSolver<PerfConfigConvCKIgemmFwdBiasResAddActivFused>
+{
+    const std::string& SolverDbId() const override
+    {
+        return GetSolverDbId<ConvCKIgemmFwdBiasResAddActivFused>();
+    }
+
+    PerfConfigConvCKIgemmFwdBiasResAddActivFused
+    GetDefaultPerformanceConfig(const FusionContext& ctx,
+                                const FusionDescription& fdesc_problem) const override;
+    bool IsValidPerformanceConfig(
+        const FusionContext& ctx,
+        const FusionDescription& fdesc_problem,
+        const PerfConfigConvCKIgemmFwdBiasResAddActivFused& config) const override;
+    PerfConfigConvCKIgemmFwdBiasResAddActivFused
+    Search(const FusionContext& ctx,
+           const FusionDescription& fdesc_problem,
+           const AnyInvokeParams& invoke_ctx) const override;
+    bool IsApplicable(const FusionContext& ctx,
+                      const FusionDescription& fdesc_problem) const override;
+    ConvSolution
+    GetSolution(const FusionContext& ctx,
+                const FusionDescription& fdesc_problem,
+                const PerfConfigConvCKIgemmFwdBiasResAddActivFused& config) const override;
+
+private:
+    template <typename DataType, typename AccumDataType = DataType>
+    bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const;
+};
+
 struct ConvBinWinogradRxSFused final : FusionSolverBase
 {
     const std::string& SolverDbId() const override
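The perf-config follows the established pattern of the CK fused solvers: HeuristicInit enumerates the CK kernel instances valid for the problem and picks the first, SetNextValue steps index through valid_kernels during tuning, and GetSolution builds whichever kernel_id won. A self-contained toy version of that enumeration (kernel ids made up):

```cpp
#include <iostream>
#include <string>
#include <vector>

// Toy model of the PerfConfig pattern above: an index walks a
// pre-enumerated list of valid CK kernels, and tuning advances through
// the list, timing each candidate.
struct PerfConfigSketch
{
    int index = 0;
    std::string kernel_id;
    std::vector<std::string> valid_kernels;

    void HeuristicInit()
    {
        // In the real solver this list comes from querying CK for
        // instances that support the fused problem; here it is canned.
        valid_kernels = {"ck_instance_0", "ck_instance_1", "ck_instance_2"};
        index         = 0;
        kernel_id     = valid_kernels[0];
    }
    bool SetNextValue()
    {
        if(index + 1 >= static_cast<int>(valid_kernels.size()))
            return false; // search space exhausted
        kernel_id = valid_kernels[++index];
        return true;
    }
};

int main()
{
    PerfConfigSketch cfg;
    cfg.HeuristicInit();
    do
    {
        std::cout << "time candidate: " << cfg.kernel_id << '\n';
    } while(cfg.SetNextValue());
}
```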
1 change: 1 addition & 0 deletions src/include/miopen/fusion_ops.hpp
@@ -44,6 +44,7 @@ enum miopenFusionOp_t
     miopenFusionOpBatchNormFwdTrain = 4,
     miopenFusionOpBatchNormBwdTrain = 5,
     miopenFusionOpActivBackward = 6,
+    miopenFusionOpTensorScaleAdd = 7,
 };
 
 enum MDGraph_op_t
6 changes: 6 additions & 0 deletions src/ocl/fusionopbiasbnactivocl.cpp
@@ -52,6 +52,12 @@ miopenStatus_t BiasFusionOpDescriptor::GetNetworkConfig(std::ostringstream& network_config)
     return miopenStatusSuccess;
 }
 
+miopenStatus_t TensorScaleAddOpDescriptor::GetNetworkConfig(std::ostringstream& network_config)
+{
+    network_config << "tensorScaleAdd";
+    return miopenStatusSuccess;
+}
+
 miopenStatus_t ActivFwdFusionOpDescriptor::GetNetworkConfig(std::ostringstream& network_config)
 {
     network_config << "ActivFwd" << std::to_string(activMode);
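Each op's GetNetworkConfig fragment is appended to the plan's network-config string, which MIOpen uses to key compiled-kernel caching and tuning records, so the new op needs its own distinct token. A toy illustration (the "tensorScaleAdd" and "ActivFwd" tokens are from the diff above; the other tokens are placeholders):

```cpp
#include <iostream>
#include <sstream>
#include <string>

int main()
{
    // Illustrative only: tokens accumulate in op order to form one key.
    std::ostringstream network_config;
    network_config << "convfwd";        // placeholder token for the conv op
    network_config << "tensorScaleAdd"; // the new op's token (see diff above)
    network_config << "bias";           // placeholder token for the bias op
    network_config << "ActivFwd" << std::to_string(3); // activation mode id
    std::cout << network_config.str() << '\n';
}
```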
5 changes: 5 additions & 0 deletions src/solver.cpp
@@ -585,6 +585,11 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry)
              Primitive::Fusion,
              fusion::ConvCKIgemmFwdBiasActivFused{}.SolverDbId(),
              miopenConvolutionAlgoImplicitGEMM);
+    Register(registry,
+             ++id,
+             Primitive::Fusion,
+             fusion::ConvCKIgemmFwdBiasResAddActivFused{}.SolverDbId(),
+             miopenConvolutionAlgoImplicitGEMM);
     Register(registry, ++id, Primitive::Pooling, pooling::PoolingForwardNaive{}.SolverDbId());
     RegisterWithSolver(registry,
                        ++id,

[Review thread on the added Register call]
Contributor: New solvers must be added to the end of the list.
Author: Done.
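The review exchange above is about solver-ID stability: each Register call consumes ++id, so a solver's numeric ID is its position in the registration sequence, and anything persisted against those IDs would be corrupted by renumbering. Hence new solvers go at the end. A toy model of the constraint (hypothetical simplified registrar):

```cpp
#include <iostream>
#include <string>
#include <vector>

// Toy registrar: ids are assigned by registration order, as ++id does above.
std::vector<std::string> RegisterAll(const std::vector<std::string>& solvers)
{
    return solvers; // a solver's id is simply its position in the list
}

int main()
{
    auto v1 = RegisterAll({"winograd", "igemm_fused"});
    auto v2 = RegisterAll({"winograd", "igemm_fused", "igemm_res_add_fused"});
    // Appending keeps every existing id stable:
    std::cout << (v1[1] == v2[1] ? "stable" : "broken") << '\n'; // "stable"
    // Inserting mid-list instead would shift "igemm_fused" to a new id and
    // break anything (e.g. perf-db entries) keyed on the old number.
}
```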