Skip to content

Commit

Permalink
Layernorm and groupnorm support to save mean and inverse std in forwa…
Browse files Browse the repository at this point in the history
…rd (#929)

* save mean and inverse std in normalization

* Save mean and inverse std in splitK

* Vector save mean and inv std

* Modify instance for save mean and std

* simplify the layernorm example

* Save mean and std in groupnorm example

* Save mean and inv std in ckProfiler and test

* Remove compute data type from base class

* Save mean and inv std in client example

* Add changelog

* clang format

* Fix compile error

* Refine naming

* Avoid error in bf16

* revert changelog
  • Loading branch information
rocking5566 authored Oct 18, 2023
1 parent 58338bb commit 3696fe1
Show file tree
Hide file tree
Showing 38 changed files with 1,394 additions and 545 deletions.
42 changes: 33 additions & 9 deletions client_example/05_layernorm/layernorm2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@

#include "ck/library/tensor_operation_instance/gpu/normalization.hpp"

using XDataType = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t;
using YDataType = ck::half_t;
using ComputeDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using XDataType = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t;
using YDataType = ck::half_t;
using SaveMeanInvStdDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

#define SAVE_MEAN_INV_STD

constexpr int Rank = 2;
constexpr int NumReduceDim = 1;
Expand Down Expand Up @@ -50,12 +52,16 @@ int main(int argc, char* argv[])
SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * N);
SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * N);
SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size);
#ifdef SAVE_MEAN_INV_STD
SimpleDeviceMem save_mean_device_buf(sizeof(SaveMeanInvStdDataType) * M);
SimpleDeviceMem save_inv_std_device_buf(sizeof(SaveMeanInvStdDataType) * M);
#endif

using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
SaveMeanInvStdDataType,
PassThrough,
Rank,
NumReduceDim>;
Expand Down Expand Up @@ -84,14 +90,21 @@ int main(int argc, char* argv[])
{0, 1}, // gammaStrides
{0, 1}, // betaStrides
{Stride, 1}, // yStrides
{1}, // save_mean Strides
{1}, // save_inv_std Strides
{1}, // reduceDims
1e-4,
x_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
y_device_buf.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
save_mean_device_buf.GetDeviceBuffer(),
save_inv_std_device_buf.GetDeviceBuffer(),
#else
nullptr,
nullptr,
#endif
PassThrough{});

auto invoker_ptr = op_ptr->MakeInvokerPointer();
Expand All @@ -109,6 +122,10 @@ int main(int argc, char* argv[])
std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N +
sizeof(BetaDataType) * N + sizeof(YDataType) * M * N;

#ifdef SAVE_MEAN_INV_STD
num_byte += sizeof(SaveMeanInvStdDataType) * M * 2;
#endif

float gb_per_sec = num_byte / 1.E6 / ave_time;

std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
Expand Down Expand Up @@ -140,17 +157,24 @@ int main(int argc, char* argv[])

auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // lengths
{Stride, 1}, // xStrides
{1}, // gammaStrides
{1}, // betaStrides
{0, 1}, // gammaStrides
{0, 1}, // betaStrides
{Stride, 1}, // yStrides
{1}, // save_mean Strides
{1}, // save_inv_std Strides
{1}, // reduceDims
1e-4,
x_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
y_device_buf.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
save_mean_device_buf.GetDeviceBuffer(),
save_inv_std_device_buf.GetDeviceBuffer(),
#else
nullptr,
nullptr,
#endif
PassThrough{});

auto invoker_ptr = op_ptr->MakeInvokerPointer();
Expand Down
122 changes: 78 additions & 44 deletions client_example/18_groupnorm/groupnorm_swish.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@

#include "ck/library/tensor_operation_instance/gpu/normalization_swish.hpp"

using XDataType = ck::half_t;
using GammaDataType = float;
using BetaDataType = float;
using YDataType = ck::half_t;
using ComputeDataType = float;
using Swish = ck::tensor_operation::element_wise::Swish;
using XDataType = ck::half_t;
using GammaDataType = float;
using BetaDataType = float;
using YDataType = ck::half_t;
using SaveMeanInvStdDataType = float;
using Swish = ck::tensor_operation::element_wise::Swish;

#define SAVE_MEAN_INV_STD

constexpr int Rank = 5;
constexpr int NumReduceDim = 3;
Expand Down Expand Up @@ -49,19 +51,24 @@ int main(int argc, char* argv[])
std::size_t xy_size = N * H * W * G * C;
std::size_t gamma_beta_size = G * C;

std::vector<ck::index_t> xy_strides = {H * W * G * C, W * G * C, G * C, C, 1};
std::vector<ck::index_t> gamma_beta_strides = {0, 0, 0, C, 1};
std::vector<ck::index_t> xy_strides = {H * W * G * C, W * G * C, G * C, C, 1};
std::vector<ck::index_t> gamma_beta_strides = {0, 0, 0, C, 1};
std::vector<ck::index_t> save_mean_inv_std_strides = {G, 1};

SimpleDeviceMem x_device_buf(sizeof(XDataType) * xy_size);
SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_beta_size);
SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * gamma_beta_size);
SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size);
#ifdef SAVE_MEAN_INV_STD
SimpleDeviceMem save_mean_device_buf(sizeof(SaveMeanInvStdDataType) * N * G);
SimpleDeviceMem save_inv_std_device_buf(sizeof(SaveMeanInvStdDataType) * N * G);
#endif

using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
SaveMeanInvStdDataType,
Swish,
Rank,
NumReduceDim>;
Expand All @@ -75,19 +82,26 @@ int main(int argc, char* argv[])
const auto& generic_op_ptr = op_ptrs[0];

auto generic_argument_ptr =
generic_op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths
xy_strides, // xStrides
gamma_beta_strides, // gammaStrides
gamma_beta_strides, // betaStrides
xy_strides, // yStrides
{1, 2, 4}, // reduceDims
generic_op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths
xy_strides, // xStrides
gamma_beta_strides, // gammaStrides
gamma_beta_strides, // betaStrides
xy_strides, // yStrides
save_mean_inv_std_strides, // save_mean Strides
save_mean_inv_std_strides, // save_inv_std Strides
{1, 2, 4}, // reduceDims
1e-6,
x_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
y_device_buf.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
save_mean_device_buf.GetDeviceBuffer(),
save_inv_std_device_buf.GetDeviceBuffer(),
#else
nullptr,
nullptr,
#endif
Swish{});

if(!generic_op_ptr->IsSupportedArgument(generic_argument_ptr.get()))
Expand All @@ -107,21 +121,29 @@ int main(int argc, char* argv[])

for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths
xy_strides, // xStrides
gamma_beta_strides, // gammaStrides
gamma_beta_strides, // betaStrides
xy_strides, // yStrides
{1, 2, 4}, // reduceDims
1e-6,
x_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
y_device_buf.GetDeviceBuffer(),
nullptr,
nullptr,
Swish{});
auto& op_ptr = op_ptrs[i];
auto argument_ptr =
op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths
xy_strides, // xStrides
gamma_beta_strides, // gammaStrides
gamma_beta_strides, // betaStrides
xy_strides, // yStrides
save_mean_inv_std_strides, // save_mean Strides
save_mean_inv_std_strides, // save_inv_std Strides
{1, 2, 4}, // reduceDims
1e-6,
x_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
y_device_buf.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
save_mean_device_buf.GetDeviceBuffer(),
save_inv_std_device_buf.GetDeviceBuffer(),
#else
nullptr,
nullptr,
#endif
Swish{});

auto invoker_ptr = op_ptr->MakeInvokerPointer();

Expand All @@ -139,6 +161,10 @@ int main(int argc, char* argv[])
sizeof(XDataType) * xy_size + sizeof(GammaDataType) * gamma_beta_size +
sizeof(BetaDataType) * gamma_beta_size + sizeof(YDataType) * xy_size;

#ifdef SAVE_MEAN_INV_STD
num_byte += sizeof(SaveMeanInvStdDataType) * N * G * 2;
#endif

float gb_per_sec = num_byte / 1.E6 / ave_time;

std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
Expand Down Expand Up @@ -169,20 +195,28 @@ int main(int argc, char* argv[])
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;

auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths
xy_strides, // xStrides
gamma_beta_strides, // gammaStrides
gamma_beta_strides, // betaStrides
xy_strides, // yStrides
{1, 2, 4}, // reduceDims
1e-6,
x_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
y_device_buf.GetDeviceBuffer(),
nullptr,
nullptr,
Swish{});
auto argument_ptr =
op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths
xy_strides, // xStrides
gamma_beta_strides, // gammaStrides
gamma_beta_strides, // betaStrides
xy_strides, // yStrides
save_mean_inv_std_strides, // save_mean Strides
save_mean_inv_std_strides, // save_inv_std Strides
{1, 2, 4}, // reduceDims
1e-6,
x_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
y_device_buf.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
save_mean_device_buf.GetDeviceBuffer(),
save_inv_std_device_buf.GetDeviceBuffer(),
#else
nullptr,
nullptr,
#endif
Swish{});

auto invoker_ptr = op_ptr->MakeInvokerPointer();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,15 @@ void host_gemm_layernorm(Tensor<HDataType>& h_m_n,
BetaDataType,
HDataType,
AccDataType,
AccDataType,
HElementOp,
2,
1>;

Tensor<EMeanVarDataType> e_m_n(HostTensorDescriptor{M, N});
Tensor<AccDataType> c_m_n(HostTensorDescriptor{M, N});
Tensor<AccDataType> save_mean({M});
Tensor<AccDataType> save_inv_std({M});

auto ref_gemm = ReferenceGemm{};
auto ref_gemm_invoker = ref_gemm.MakeInvoker();
Expand All @@ -145,7 +148,7 @@ void host_gemm_layernorm(Tensor<HDataType>& h_m_n,
auto ref_layernorm_invoker = ref_layernorm.MakeInvoker();

auto ref_layernorm_argument = ref_layernorm.MakeArgument(
e_m_n, gamma_n, beta_n, h_m_n, h_element_op, {M, N}, {1}, epsilon);
e_m_n, gamma_n, beta_n, h_m_n, save_mean, save_inv_std, h_element_op, {M, N}, {1}, epsilon);
ref_layernorm_invoker.Run(ref_layernorm_argument);
}

Expand Down
19 changes: 12 additions & 7 deletions example/27_layernorm/layernorm_fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@

#include "common.hpp"

using XDataType = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t;
using YDataType = ck::half_t;
using ComputeDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using XDataType = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t;
using YDataType = ck::half_t;
using SaveMeanInvStdDataType = float;
using ComputeDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

#define SAVE_MEAN_INV_STD

constexpr int Rank = 2;
constexpr int NumReduceDim = 1;
Expand All @@ -19,6 +22,7 @@ using DeviceInstance =
BetaDataType,
ComputeDataType,
YDataType,
SaveMeanInvStdDataType,
PassThrough,
Rank,
NumReduceDim,
Expand All @@ -33,7 +37,8 @@ using DeviceInstance =
8, // GammaScalarPerVector
1, // BetaVecDim (0=M, 1=K)
8, // BetaScalarPerVector
8>; // OutScalarPerVector
8, // YScalarPerVector
1>; // SaveMeanInvStdScalarPerVector
#include "run_layernorm_example.inc"

int main() { return run_groupnorm_example<DeviceInstance>(); }
Loading

0 comments on commit 3696fe1

Please sign in to comment.