diff --git a/example/66_complex_contraction_bilinear/run_complex_contraction_bilinear_example.inc b/example/66_complex_contraction_bilinear/run_complex_contraction_bilinear_example.inc
old mode 100755
new mode 100644
index b548427548..82ac0a15e1
--- a/example/66_complex_contraction_bilinear/run_complex_contraction_bilinear_example.inc
+++ b/example/66_complex_contraction_bilinear/run_complex_contraction_bilinear_example.inc
@@ -127,44 +127,47 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    switch(init_method)
    {
-        case 0: break;
-        case 1:
+    case 0: break;
+    case 1:
-            a_ms_ks_re.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
-            b_ns_ks_re.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
-            d_ms_ns_re.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
+        a_ms_ks_re.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
+        b_ns_ks_re.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
+        d_ms_ns_re.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
-            a_ms_ks_img.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
-            b_ns_ks_img.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
-            d_ms_ns_img.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
-            break;
+        a_ms_ks_img.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
+        b_ns_ks_img.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
+        d_ms_ns_img.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
+        break;
-        default:
-            a_ms_ks_re.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
-            b_ns_ks_re.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
-            d_ms_ns_re.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
+    default:
+        a_ms_ks_re.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+        b_ns_ks_re.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
+        d_ms_ns_re.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
-            a_ms_ks_img.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
-            b_ns_ks_img.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
-            d_ms_ns_img.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
+        a_ms_ks_img.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+        b_ns_ks_img.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
+        d_ms_ns_img.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
-            break;
+        break;
    }

    DeviceMem a_device_buf_re(sizeof(ADataType) * a_ms_ks_re.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf_re(sizeof(BDataType) * b_ns_ks_re.mDesc.GetElementSpaceSize());
    DeviceMem d_device_buf_re(sizeof(DDataType) * d_ms_ns_re.mDesc.GetElementSpaceSize());
-    DeviceMem e_device_buf_re(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
+    DeviceMem e_device_buf_re(sizeof(EDataType) *
+                              e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());

    DeviceMem a_device_buf_img(sizeof(ADataType) * a_ms_ks_img.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf_img(sizeof(BDataType) * b_ns_ks_img.mDesc.GetElementSpaceSize());
    DeviceMem d_device_buf_img(sizeof(DDataType) * d_ms_ns_img.mDesc.GetElementSpaceSize());
-    DeviceMem e_device_buf_img(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
+    DeviceMem e_device_buf_img(sizeof(EDataType) *
+                               e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());

    // Intermediate Value For E Real and Img
-    DeviceMem e_device_buf_re1(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
-    DeviceMem e_device_buf_img1(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
-
+    DeviceMem e_device_buf_re1(sizeof(EDataType) *
+                               e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
+    DeviceMem e_device_buf_img1(sizeof(EDataType) *
+                                e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
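+
+    // The complex bilinear contraction E = alpha * (A x B) + beta * D is assembled from four
+    // real contractions over the split re/img tensors:
+    //   Re(E) = alpha * (Re(A) x Re(B)) + beta * Re(D) - Im(A) x Im(B)
+    //   Im(E) = Re(A) x Im(B) + Im(A) x Re(B) + Im(D)
+    // (the follow-up passes below reset alpha/beta to +/-1 and 1). e_device_buf_re1 and
+    // e_device_buf_img1 hold the first partial result of each component; the second pass
+    // consumes that buffer as its bilinear D operand.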

    a_device_buf_re.ToDevice(a_ms_ks_re.mData.data());
    b_device_buf_re.ToDevice(b_ns_ks_re.mData.data());
@@ -181,7 +184,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    // set zero for intermediate values
    e_device_buf_re1.SetZero();
    e_device_buf_img1.SetZero();
-
+
    auto a_element_op   = AElementOp{};
    auto b_element_op   = BElementOp{};
    auto cde_element_op = CDEElementOp{alpha, beta};
@@ -189,23 +192,24 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    // device operation
    // For real Intermediate Value re_1
-    auto op = DeviceOpInstance{};
-    auto invoker = op.MakeInvoker();
-    auto argument_re1 = op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
-                                        b_device_buf_re.GetDeviceBuffer(),
-                                        std::array<const void*, 1>{d_device_buf_re.GetDeviceBuffer()},
-                                        e_device_buf_re1.GetDeviceBuffer(),
-                                        a_ms_ks_lengths,
-                                        a_ms_ks_strides,
-                                        b_ns_ks_lengths,
-                                        b_ns_ks_strides,
-                                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
-                                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
-                                        e_ms_ns_lengths,
-                                        e_ms_ns_strides,
-                                        a_element_op,
-                                        b_element_op,
-                                        cde_element_op);
+    auto op      = DeviceOpInstance{};
+    auto invoker = op.MakeInvoker();
+    auto argument_re1 =
+        op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
+                        b_device_buf_re.GetDeviceBuffer(),
+                        std::array<const void*, 1>{d_device_buf_re.GetDeviceBuffer()},
+                        e_device_buf_re1.GetDeviceBuffer(),
+                        a_ms_ks_lengths,
+                        a_ms_ks_strides,
+                        b_ns_ks_lengths,
+                        b_ns_ks_strides,
+                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
+                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
+                        e_ms_ns_lengths,
+                        e_ms_ns_strides,
+                        a_element_op,
+                        b_element_op,
+                        cde_element_op);

    if(!op.IsSupportedArgument(argument_re1))
    {
@@ -216,7 +220,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])

    float ave_time_re1 = invoker.Run(argument_re1, StreamConfig{nullptr, time_kernel});

-
    alpha = -1.f;
    beta  = 1.f;

@@ -228,21 +231,22 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    // For real Intermediate Value re_2
    // auto op = DeviceOpInstance{};
    // auto invoker = op.MakeInvoker();
-    auto argument_re2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
-                                        b_device_buf_img.GetDeviceBuffer(),
-                                        std::array<const void*, 1>{e_device_buf_re1.GetDeviceBuffer()},
-                                        e_device_buf_re.GetDeviceBuffer(),
-                                        a_ms_ks_lengths,
-                                        a_ms_ks_strides,
-                                        b_ns_ks_lengths,
-                                        b_ns_ks_strides,
-                                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
-                                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
-                                        e_ms_ns_lengths,
-                                        e_ms_ns_strides,
-                                        a_element_op,
-                                        b_element_op,
-                                        cde_element_op);
+    auto argument_re2 =
+        op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
+                        b_device_buf_img.GetDeviceBuffer(),
+                        std::array<const void*, 1>{e_device_buf_re1.GetDeviceBuffer()},
+                        e_device_buf_re.GetDeviceBuffer(),
+                        a_ms_ks_lengths,
+                        a_ms_ks_strides,
+                        b_ns_ks_lengths,
+                        b_ns_ks_strides,
+                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
+                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
+                        e_ms_ns_lengths,
+                        e_ms_ns_strides,
+                        a_element_op,
+                        b_element_op,
+                        cde_element_op);

    if(!op.IsSupportedArgument(argument_re2))
    {
@@ -253,7 +257,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])

    float ave_time_re2 = invoker.Run(argument_re2, StreamConfig{nullptr, time_kernel});

-
    alpha = 1.f;
    beta  = 1.f;

@@ -261,22 +264,22 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    b_element_op   = BElementOp{};
    cde_element_op = CDEElementOp{alpha, beta};
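+
+    // Imaginary part: img1 = Re(A) x Im(B) + Im(D), then Im(E) = Im(A) x Re(B) + img1.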
-    auto argument_img1 = op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
-                                         b_device_buf_img.GetDeviceBuffer(),
-                                         std::array<const void*, 1>{d_device_buf_img.GetDeviceBuffer()},
-                                         e_device_buf_img1.GetDeviceBuffer(),
-                                         a_ms_ks_lengths,
-                                         a_ms_ks_strides,
-                                         b_ns_ks_lengths,
-                                         b_ns_ks_strides,
-                                         std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
-                                         std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
-                                         e_ms_ns_lengths,
-                                         e_ms_ns_strides,
-                                         a_element_op,
-                                         b_element_op,
-                                         cde_element_op);
-
+    auto argument_img1 =
+        op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
+                        b_device_buf_img.GetDeviceBuffer(),
+                        std::array<const void*, 1>{d_device_buf_img.GetDeviceBuffer()},
+                        e_device_buf_img1.GetDeviceBuffer(),
+                        a_ms_ks_lengths,
+                        a_ms_ks_strides,
+                        b_ns_ks_lengths,
+                        b_ns_ks_strides,
+                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
+                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
+                        e_ms_ns_lengths,
+                        e_ms_ns_strides,
+                        a_element_op,
+                        b_element_op,
+                        cde_element_op);

    if(!op.IsSupportedArgument(argument_img1))
    {
@@ -290,23 +293,22 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    alpha = 1.f;
    beta  = 1.f;

-    auto argument_img2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
-                                         b_device_buf_re.GetDeviceBuffer(),
-                                         std::array<const void*, 1>{e_device_buf_img1.GetDeviceBuffer()},
-                                         e_device_buf_img.GetDeviceBuffer(),
-                                         a_ms_ks_lengths,
-                                         a_ms_ks_strides,
-                                         b_ns_ks_lengths,
-                                         b_ns_ks_strides,
-                                         std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
-                                         std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
-                                         e_ms_ns_lengths,
-                                         e_ms_ns_strides,
-                                         a_element_op,
-                                         b_element_op,
-                                         cde_element_op);
-
-
+    auto argument_img2 =
+        op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
+                        b_device_buf_re.GetDeviceBuffer(),
+                        std::array<const void*, 1>{e_device_buf_img1.GetDeviceBuffer()},
+                        e_device_buf_img.GetDeviceBuffer(),
+                        a_ms_ks_lengths,
+                        a_ms_ks_strides,
+                        b_ns_ks_lengths,
+                        b_ns_ks_strides,
+                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
+                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
+                        e_ms_ns_lengths,
+                        e_ms_ns_strides,
+                        a_element_op,
+                        b_element_op,
+                        cde_element_op);

    if(!op.IsSupportedArgument(argument_img2))
    {
@@ -317,7 +319,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])

    float ave_time_img2 = invoker.Run(argument_img2, StreamConfig{nullptr, time_kernel});

-
    ck::index_t M = ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{});

@@ -331,9 +332,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
                            sizeof(DDataType) * M * N + sizeof(EDataType) * M * N * 2;

-    float ave_time = ave_time_img2 + ave_time_img1 + ave_time_re2 + ave_time_re1 ;
+    float ave_time = ave_time_img2 + ave_time_img1 + ave_time_re2 + ave_time_re1;

-    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
@@ -343,7 +344,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    e_device_buf_img.FromDevice(e_ms_ns_device_result_img.mData.data());

    auto isRealOk = 0;
-    auto isImgOk = 0;
+    auto isImgOk  = 0;

    if(do_verification)
    {
@@ -366,17 +367,16 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
        auto ref_op      = ReferenceOpInstance{};
        auto ref_invoker = ref_op.MakeInvoker();

-        auto ref_argument_re =
-            ref_op.MakeArgument(a_ms_ks_re, b_ns_ks_re, c_ms_ns_host_result_re, a_element_op, b_element_op);
+        auto ref_argument_re = ref_op.MakeArgument(
+            a_ms_ks_re, b_ns_ks_re, c_ms_ns_host_result_re, a_element_op, b_element_op);

        ref_invoker.Run(ref_argument_re);

        alpha = 1.f;
        beta  = 1.f;
-
+
        cde_element_op = CDEElementOp{alpha, beta};
-
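+        // Host reference: the elementwise loops below fold each reference contraction into
+        // the host result using the same alpha/beta schedule as the device passes.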
        for(size_t m0 = 0; m0 < e_ms_ns_host_result_re.mDesc.GetLengths()[0]; ++m0)
        {
            for(size_t m1 = 0; m1 < e_ms_ns_host_result_re.mDesc.GetLengths()[1]; ++m1)
@@ -395,11 +395,11 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])

        alpha = 1.f;
        beta  = -1.f;
-
+
        cde_element_op = CDEElementOp{alpha, beta};

-        auto ref_argument_re1 =
-            ref_op.MakeArgument(a_ms_ks_img, b_ns_ks_img, c_ms_ns_host_result_re1, a_element_op, b_element_op);
+        auto ref_argument_re1 = ref_op.MakeArgument(
+            a_ms_ks_img, b_ns_ks_img, c_ms_ns_host_result_re1, a_element_op, b_element_op);

        ref_invoker.Run(ref_argument_re1);

@@ -419,23 +419,20 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
            }
        }

-        isRealOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;
-
-
-
+        isRealOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;

        // Img Part Verification
        Tensor<EDataType> c_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides);
        Tensor<EDataType> c_ms_ns_host_result_img1(e_ms_ns_lengths, e_ms_ns_strides);

-        auto ref_argument_img =
-            ref_op.MakeArgument(a_ms_ks_re, b_ns_ks_img, c_ms_ns_host_result_img, a_element_op, b_element_op);
-
+        auto ref_argument_img = ref_op.MakeArgument(
+            a_ms_ks_re, b_ns_ks_img, c_ms_ns_host_result_img, a_element_op, b_element_op);
+
        ref_invoker.Run(ref_argument_img);

        alpha = 1.f;
        beta  = 1.f;
-
+
        cde_element_op = CDEElementOp{alpha, beta};

        for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0)
@@ -454,9 +451,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
            }
        }

-        auto ref_argument_img1 =
-            ref_op.MakeArgument(a_ms_ks_img, b_ns_ks_re, c_ms_ns_host_result_img1, a_element_op, b_element_op);
-
+        auto ref_argument_img1 = ref_op.MakeArgument(
+            a_ms_ks_img, b_ns_ks_re, c_ms_ns_host_result_img1, a_element_op, b_element_op);
+
        ref_invoker.Run(ref_argument_img1);

        for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0)
@@ -475,7 +472,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
            }
        }

-        isImgOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;
+        isImgOk =
+            ck::utils::check_err(e_ms_ns_device_result_img, e_ms_ns_host_result_img) ? 0 : 1;

    return (isRealOk && isImgOk);
}
diff --git a/example/ck_tile/02_layernorm2d/CMakeLists.txt b/example/ck_tile/02_layernorm2d/CMakeLists.txt
index bac5f45cd3..feae5f791d 100644
--- a/example/ck_tile/02_layernorm2d/CMakeLists.txt
+++ b/example/ck_tile/02_layernorm2d/CMakeLists.txt
@@ -1,4 +1,21 @@
+set(EXAMPLE_LAYERNORM2D_FWD "tile_example_layernorm2d_fwd")
 # not using add_example_executable() to add this target, since we don't want this to have
 # to be included in "make all/install/check"
-add_executable(tile_example_layernorm2d_fwd EXCLUDE_FROM_ALL layernorm2d_fwd.cpp)
-target_compile_options(tile_example_layernorm2d_fwd PRIVATE -DSAVE_MEAN_INV_STD)
\ No newline at end of file
+message("adding example ${EXAMPLE_LAYERNORM2D_FWD}")
+file(GLOB INSTANCE_SRCS instances/*.cpp)
+add_executable(${EXAMPLE_LAYERNORM2D_FWD} EXCLUDE_FROM_ALL layernorm2d_fwd.cpp)
+target_include_directories(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
+target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${INSTANCE_SRCS})
+
+set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS)
+
+# NOTE: turn off undefined-func-template so the sources compile without explicitly declaring
+# every function specialization
+list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
+
+target_compile_options(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS})
+
+# TODO: we have to turn off this global property, otherwise the progress bar generated by
+# cmake prints too many files and fails with "execvp: /bin/sh: Argument list too long".
+# Note that this property affects the whole build, not just this target.
+# TODO: consider generating a makefile ourselves instead.
+set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
diff --git a/example/ck_tile/02_layernorm2d/README.md b/example/ck_tile/02_layernorm2d/README.md
index 66b16c1b7f..405325a2a1 100644
--- a/example/ck_tile/02_layernorm2d/README.md
+++ b/example/ck_tile/02_layernorm2d/README.md
@@ -6,8 +6,7 @@ This folder contains example for Layernorm2D forward using ck_tile tile-programm
 ```
 # in the root of ck_tile
 mkdir build && cd build
-# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
-sh ../script/cmake-ck-dev.sh ../ <arch>
+sh ../script/cmake-ck-dev.sh ../ # you can replace this with gfx90a, gfx942, ...
 make tile_example_layernorm2d_fwd -j
 ```
This will result in an executable `build/bin/tile_example_layernorm2d_fwd`
@@ -20,4 +19,4 @@ args:
 -e     epsilon (default:1e-5)
 -v     cpu validation or not (default:1)
 -prec  precision (default:fp16)
-```
\ No newline at end of file
+```
diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_api.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_api.cpp
new file mode 100644
index 0000000000..f2f51de5d9
--- /dev/null
+++ b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_api.cpp
@@ -0,0 +1,155 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
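+
+// Dispatch policy: bucket on the row length n, then within a bucket pick the instance whose
+// vector width divides n (n % 8, % 4, % 2), falling back to a scalar variant for odd n; the
+// largest sizes use the two-pass (*_tp) instances. Column legend: rm rn tm tn vn pd mv 2p.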
+ +#include +#include "layernorm2d_fwd.hpp" + +template +using trait_ = layernorm2d_fwd_traits_; + +template +float layernorm2d_fwd_b16_(layernorm2d_fwd_traits /*t*/, + layernorm2d_fwd_args a, + const ck_tile::stream_config& s) +{ +#if 1 + float r = -1; + // clang-format off + // rm rn tm tn vn pd mv 2p + if(a.n <= 64) { + r = layernorm2d_fwd_>(s, a); + } + else if(a.n <= 128) { + if (a.n % 2 == 0) + r = layernorm2d_fwd_>(s, a); + else + r = layernorm2d_fwd_>(s, a); + } + else if(a.n <= 256) { + if (a.n % 4 == 0) + r = layernorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = layernorm2d_fwd_>(s, a); + else + r = layernorm2d_fwd_>(s, a); + } + else if(a.n <= 512) { + if (a.n % 8 == 0) + r = layernorm2d_fwd_>(s, a); + else if (a.n % 4 == 0) + r = layernorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = layernorm2d_fwd_>(s, a); + else + r = layernorm2d_fwd_>(s, a); + } + else if(a.n <= 768) { + if (a.n % 4 == 0) + r = layernorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = layernorm2d_fwd_>(s, a); + else + r = layernorm2d_fwd_>(s, a); + } + else if(a.n <= 1024) { + if (a.n % 8 == 0) + r = layernorm2d_fwd_>(s, a); + else if (a.n % 4 == 0) + r = layernorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = layernorm2d_fwd_>(s, a); + else + r = layernorm2d_fwd_>(s, a); + } + else if(a.n <= 1536) { + if (a.n % 8 == 0) + r = layernorm2d_fwd_>(s, a); + else if (a.n % 4 == 0) + r = layernorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = layernorm2d_fwd_>(s, a); + else + r = layernorm2d_fwd_>(s, a); + } + else if(a.n <= 2048) { + if (a.n % 8 == 0) + r = layernorm2d_fwd_>(s, a); + else if (a.n % 4 == 0) + r = layernorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = layernorm2d_fwd_>(s, a); + else + r = layernorm2d_fwd_>(s, a); + } + else if(a.n <= 3072) { + if (a.n % 8 == 0) + r = layernorm2d_fwd_>(s, a); + else if (a.n % 4 == 0) + r = layernorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = layernorm2d_fwd_>(s, a); + else + r = layernorm2d_fwd_>(s, a); + } + else if(a.n <= 4096) { + if (a.n % 8 == 0) + r = layernorm2d_fwd_>(s, a); + else if (a.n % 4 == 0) + r = layernorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = layernorm2d_fwd_>(s, a); + else + r = layernorm2d_fwd_>(s, a); + } + else if(a.n > 4096) { + if (a.n % 8 == 0) + r = layernorm2d_fwd_>(s, a); + else if (a.n % 4 == 0) + r = layernorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = layernorm2d_fwd_>(s, a); + else + r = layernorm2d_fwd_>(s, a); + } + return r; +#else + return layernorm2d_fwd_>(s, a); +#endif + // clang-format on +} + +float layernorm2d_fwd(layernorm2d_fwd_traits t, + layernorm2d_fwd_args a, + const ck_tile::stream_config& s) +{ + + float r = -1; + if(t.data_type.compare("fp16") == 0) + { + return layernorm2d_fwd_b16_(t, a, s); + } + else if(t.data_type.compare("bf16") == 0) + { + return layernorm2d_fwd_b16_(t, a, s); + } + if(r < 0) + throw std::runtime_error("Without supported instances!"); + + return r; +} diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n1024_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n1024_instance.cpp new file mode 100644 index 0000000000..2a20d1e057 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n1024_instance.cpp @@ -0,0 +1,22 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
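+
+// Explicit instantiations for the bf16 n1024 bucket. Each n bucket lives in its own
+// translation unit, so the dispatch file only needs forward declarations (hence
+// -Wno-undefined-func-template in CMakeLists.txt) and rebuilds stay small.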
+ +#include "layernorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd mv 2p +#if 0 +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); + +template float layernorm2d_fwd_>(const S&, A); +#endif + +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n1536_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n1536_instance.cpp new file mode 100644 index 0000000000..d043efc86c --- /dev/null +++ b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n1536_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "layernorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd mv 2p +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n2048_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n2048_instance.cpp new file mode 100644 index 0000000000..a6ffc8cd2f --- /dev/null +++ b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n2048_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "layernorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd mv 2p +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n256_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n256_instance.cpp new file mode 100644 index 0000000000..80beeca67b --- /dev/null +++ b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n256_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "layernorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd mv 2p +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n3072_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n3072_instance.cpp new file mode 100644 index 0000000000..b362a550a0 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n3072_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "layernorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd mv 2p +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n4096_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n4096_instance.cpp new file mode 100644 index 0000000000..9c2d78999c --- /dev/null +++ b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n4096_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "layernorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd mv 2p +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n4096_tp_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n4096_tp_instance.cpp new file mode 100644 index 0000000000..c0c75f878b --- /dev/null +++ b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n4096_tp_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "layernorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd mv 2p +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n512_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n512_instance.cpp new file mode 100644 index 0000000000..1bcd0f8a7e --- /dev/null +++ b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n512_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "layernorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd mv 2p +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n64_n128_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n64_n128_instance.cpp new file mode 100644 index 0000000000..6b25fce8c2 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n64_n128_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "layernorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd mv 2p +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n768_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n768_instance.cpp new file mode 100644 index 0000000000..c4400f0f24 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n768_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "layernorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd mv 2p +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n1024_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n1024_instance.cpp new file mode 100644 index 0000000000..7f0e4898cb --- /dev/null +++ b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n1024_instance.cpp @@ -0,0 +1,22 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "layernorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd mv 2p +#if 0 +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); + +template float layernorm2d_fwd_>(const S&, A); +#endif + +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n1536_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n1536_instance.cpp new file mode 100644 index 0000000000..8c3a42cc4f --- /dev/null +++ b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n1536_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "layernorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd mv 2p +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n2048_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n2048_instance.cpp new file mode 100644 index 0000000000..04d8bc1533 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n2048_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "layernorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd mv 2p +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n256_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n256_instance.cpp new file mode 100644 index 0000000000..c325747494 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n256_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "layernorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd mv 2p +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n3072_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n3072_instance.cpp new file mode 100644 index 0000000000..c71db57a6a --- /dev/null +++ b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n3072_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "layernorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd mv 2p +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_instance.cpp new file mode 100644 index 0000000000..f3ca0932ef --- /dev/null +++ b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "layernorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd mv 2p +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_tp_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_tp_instance.cpp new file mode 100644 index 0000000000..242f1d2dd5 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_tp_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "layernorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd mv 2p +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n512_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n512_instance.cpp new file mode 100644 index 0000000000..e3bfa8e3a4 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n512_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "layernorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd mv 2p +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n64_n128_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n64_n128_instance.cpp new file mode 100644 index 0000000000..90d960cf09 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n64_n128_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "layernorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd mv 2p +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n768_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n768_instance.cpp new file mode 100644 index 0000000000..0960a95c31 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n768_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "layernorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd mv 2p +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +template float layernorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_instance_common.hpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_instance_common.hpp new file mode 100644 index 0000000000..22895e8edd --- /dev/null +++ b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_instance_common.hpp @@ -0,0 +1,67 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
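+
+// Shared launcher used by every instance translation unit: it maps a compile-time trait_
+// onto a Layernorm2dFwdPipelineProblem, selects the one-pass or two-pass pipeline, and
+// launches the resulting kernel. The layernorm2d_fwd_*_instance.cpp files explicitly
+// instantiate layernorm2d_fwd_ from this header.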
+
+#include <iostream>
+#include "layernorm2d_fwd.hpp"
+#include <type_traits>
+
+#pragma once
+
+using S = ck_tile::stream_config;
+using A = layernorm2d_fwd_args;
+
+template <typename DataType_,
+          ck_tile::index_t Repeat_M_,
+          ck_tile::index_t Repeat_N_,
+          ck_tile::index_t ThreadPerBlock_M_,
+          ck_tile::index_t ThreadPerBlock_N_,
+          ck_tile::index_t Vector_N_,
+          bool kPadN_,
+          bool kSaveMeanInvStd_,
+          bool kTwoPass_>
+using trait_ = layernorm2d_fwd_traits_<DataType_,
+                                       Repeat_M_,
+                                       Repeat_N_,
+                                       ThreadPerBlock_M_,
+                                       ThreadPerBlock_N_,
+                                       Vector_N_,
+                                       kPadN_,
+                                       kSaveMeanInvStd_,
+                                       kTwoPass_>;
+
+template <typename Traits_>
+float layernorm2d_fwd_(const S& s, A a)
+{
+    using DataType = typename Traits_::DataType;
+
+    using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem<
+        typename LayerNormTypeConfig<DataType>::XDataType,
+        typename LayerNormTypeConfig<DataType>::GammaDataType,
+        typename LayerNormTypeConfig<DataType>::BetaDataType,
+        typename LayerNormTypeConfig<DataType>::ComputeDataType,
+        typename LayerNormTypeConfig<DataType>::YDataType,
+        typename LayerNormTypeConfig<DataType>::MeanDataType,
+        typename LayerNormTypeConfig<DataType>::InvStdDataType,
+        typename Traits_::Shape,
+        Traits_::kPadN,
+        Traits_::kSaveMeanInvStd,
+        Traits_::kTwoPass>;
+
+    using OnePassPipeline = ck_tile::Layernorm2dFwdPipelineOnePass<PipelineProblem>;
+    using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass<PipelineProblem>;
+    using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
+
+    using Kernel = ck_tile::Layernorm2dFwd<Pipeline>;
+
+    const dim3 grids                       = Kernel::GridSize(a);
+    constexpr dim3 blocks                  = Kernel::BlockSize();
+    constexpr ck_tile::index_t kBlockPerCu = 1;
+
+    auto kargs = Kernel::MakeKargs(a);
+    if(s.log_level_ > 0)
+        std::cout << ", " << Kernel::GetName() << std::flush;
+
+    return ck_tile::launch_kernel(
+        s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
+}
diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
index 35f291e060..4f12d91032 100644
--- a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
+++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
@@ -2,161 +2,120 @@
 #include "layernorm2d_fwd.hpp"
 #include

-// Host API implementation
-float layernorm2d_fwd(layernorm2d_fwd_traits t,
-                      layernorm2d_fwd_args a,
-                      const ck_tile::stream_config& s)
+// different threshold for different dtype
+template <typename DataType>
+auto get_elimit()
 {
-    if(t.data_type.compare("fp16") == 0)
-    {
-        using XDataType     = ck_tile::half_t;
-        using YDataType     = ck_tile::half_t;
-        using GammaDataType = ck_tile::half_t;
-        using BetaDataType  = ck_tile::half_t;
-#ifdef SAVE_MEAN_INV_STD
-        using MeanDataType   = ck_tile::half_t;
-        using InvStdDataType = ck_tile::half_t;
-#else
-        using MeanDataType   = ck_tile::null_type;
-        using InvStdDataType = ck_tile::null_type;
-#endif
-        using ComputeDataType = float;
-
-        using thread_tile = ck_tile::sequence<4, 4>;
-        using warp_tile   = ck_tile::sequence<8, 128>;
-        using block_tile  = ck_tile::sequence<32, 128>;
-
-        using Shape = ck_tile::TileLayernorm2dShape<thread_tile, warp_tile, block_tile>;
-
-        using PipelineProblem = ck_tile::BlockLayernorm2dFwdProblem<XDataType,
-                                                                    GammaDataType,
-                                                                    BetaDataType,
-                                                                    ComputeDataType,
-                                                                    YDataType,
-                                                                    MeanDataType,
-                                                                    InvStdDataType,
-                                                                    Shape>;
-
-        using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;
-
-        auto kargs = Kernel::MakeKargs(
-            a.p_x, a.p_gamma, a.p_beta, a.p_y, a.p_mean, a.p_invStd, a.epsilon, a.M, a.N);
-
-        const dim3 grids      = Kernel::GridSize(a.M);
-        constexpr dim3 blocks = Kernel::BlockSize();
-
-        constexpr ck_tile::index_t kBlockPerCu = Shape::kMWarpPerBlock * Shape::kNWarpPerBlock;
-
-        float ave_time = ck_tile::launch_kernel(
-            s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
-
-        return ave_time;
-    }
+    double rtol = 1e-2;
+    double atol = 1e-2;
+    return ck_tile::make_tuple(rtol, atol);
+}

-    return 0;
+template <>
+auto get_elimit<ck_tile::bf16_t>()
+{
+    double rtol = 1e-2;
+    double atol = 1e-2;
+    return ck_tile::make_tuple(rtol, atol);
 }

 auto create_args(int argc, char* argv[])
 {
     ck_tile::ArgParser arg_parser;
     arg_parser.insert("m", "3328", "m dimension")
-        .insert("n", "4096", "m dimension")
+        .insert("n", "4096", "n dimension")
+        .insert("stride", "-1", "stride per row, if -1 then equal to n")
to n") .insert("e", "1e-5", "epsilon") + .insert("save_mv", "0", "save mean/variance(invstd) or not. set to 1 in training case") .insert("v", "1", "cpu validation or not") - .insert("prec", "fp16", "precision"); + .insert("kname", "1", "print kernel name or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "5", "cold iter") + .insert("repeat", "20", "hot iter"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); } -int main(int argc, char* argv[]) +template +bool run(const ck_tile::ArgParser& arg_parser) { - - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - return -1; - + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t stride = arg_parser.get_int("stride"); + if(stride < 0) + stride = n; float epsilon = arg_parser.get_float("e"); - ck_tile::index_t M = arg_parser.get_int("m"); - ck_tile::index_t N = arg_parser.get_int("n"); std::string data_type = arg_parser.get_str("prec"); + int kname = arg_parser.get_int("kname"); int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); - using XDataType = ck_tile::half_t; - using YDataType = ck_tile::half_t; - using GammaDataType = ck_tile::half_t; - using BetaDataType = ck_tile::half_t; -#ifdef SAVE_MEAN_INV_STD - using MeanDataType = ck_tile::half_t; - using InvStdDataType = ck_tile::half_t; -#else - using MeanDataType = ck_tile::null_type; - using InvStdDataType = ck_tile::null_type; -#endif - using ComputeDataType = float; + assert(stride >= n); - // host verify - ck_tile::HostTensor x_host({M, N}); - ck_tile::HostTensor gamma_host({N}); - ck_tile::HostTensor beta_host({N}); + using TypeConfig = LayerNormTypeConfig; + + using XDataType = typename TypeConfig::XDataType; + using YDataType = typename TypeConfig::YDataType; + using GammaDataType = typename TypeConfig::GammaDataType; + using BetaDataType = typename TypeConfig::BetaDataType; + + using MeanDataType = + std::conditional_t; + using InvStdDataType = + std::conditional_t; - ck_tile::HostTensor y_host_ref({M, N}); - ck_tile::HostTensor y_host_dev({M, N}); + using ComputeDataType = typename TypeConfig::ComputeDataType; - ck_tile::HostTensor mean_host_ref({M}); - ck_tile::HostTensor invStd_host_ref({M}); + // host verify + ck_tile::HostTensor x_host({m, n}, {stride, 1}); + ck_tile::HostTensor gamma_host({n}); + ck_tile::HostTensor beta_host({n}); + + ck_tile::HostTensor y_host_ref({m, n}, {stride, 1}); + ck_tile::HostTensor y_host_dev({m, n}, {stride, 1}); -#ifdef SAVE_MEAN_INV_STD - ck_tile::HostTensor mean_host_dev({M}); - ck_tile::HostTensor invStd_host_dev({M}); -#endif + ck_tile::HostTensor mean_host_ref({m}); + ck_tile::HostTensor invStd_host_ref({m}); - ck_tile::FillUniformDistribution{-5.f, 5.f}(x_host); - ck_tile::FillUniformDistribution{-5.f, 5.f}(gamma_host); - ck_tile::FillUniformDistribution{-5.f, 5.f}(beta_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(gamma_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(beta_host); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); -#ifdef SAVE_MEAN_INV_STD - ck_tile::DeviceMem mean_buf(mean_host_dev.get_element_space_size_in_bytes()); 

-    layernorm2d_fwd_traits traits{data_type};
+    std::cout << "[" << data_type << "]"
+              << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
+
+    layernorm2d_fwd_traits traits{data_type, SaveMeanVar};

     layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
                               gamma_buf.GetDeviceBuffer(),
                               beta_buf.GetDeviceBuffer(),
                               y_buf.GetDeviceBuffer(),
-#ifdef SAVE_MEAN_INV_STD
-                              mean_buf.GetDeviceBuffer(),
-                              invStd_buf.GetDeviceBuffer(),
-#else
                               nullptr,
                               nullptr,
-#endif
                               epsilon,
-                              M,
-                              N};
+                              m,
+                              n,
+                              stride};

-    float ave_time = layernorm2d_fwd(traits, args, ck_tile::stream_config{nullptr, true});
+    float ave_time = layernorm2d_fwd(
+        traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});

-    std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N +
-                           sizeof(BetaDataType) * N + sizeof(YDataType) * M * N;
+    std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n +
+                           sizeof(BetaDataType) * n + sizeof(YDataType) * m * n;

     float gb_per_sec = num_byte / 1.E6 / ave_time;

-    std::cout << "[" << data_type << "]"
-              << " m:" << M << ", n:" << N << ", " << ave_time << " ms, " << gb_per_sec << " GB/s"
-              << std::flush;
+    std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;

     bool pass = true;

@@ -174,20 +133,59 @@ int main(int argc, char* argv[])

         y_buf.FromDevice(y_host_dev.data());

-        pass = ck_tile::check_err(y_host_dev, y_host_ref);
+        auto [rtol, atol] = get_elimit<DataType>();
+        if(stride == n)
+        {
+            pass = ck_tile::check_err(
+                y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
+        }
+        else
+        {
+            for(int i_r = 0; i_r < m; i_r++)
+            {
+                std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * stride,
+                                                      y_host_dev.begin() + i_r * stride + n);
+                std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * stride,
+                                                      y_host_ref.begin() + i_r * stride + n);
+                pass &= ck_tile::check_err(y_host_dev_row,
+                                           y_host_ref_row,
+                                           std::string("OUT[") + std::to_string(i_r) +
+                                               std::string("] Error: Incorrect results!"),
+                                           rtol,
+                                           atol);
+            }
+        }
+
+        std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
+    }

-#ifdef SAVE_MEAN_INV_STD
-        mean_buf.FromDevice(mean_host_dev.data());
-        pass &= ck_tile::check_err(mean_host_dev, mean_host_ref);
+    return pass;
+}

-        invStd_buf.FromDevice(invStd_host_dev.data());
-        pass &= ck_tile::check_err(invStd_host_dev, invStd_host_ref);
-#endif
+int main(int argc, char* argv[])
+{
+    auto [result, arg_parser] = create_args(argc, argv);
+    if(!result)
+        return -1;

-        std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
+    const std::string data_type = arg_parser.get_str("prec");
+    int save_mv                 = arg_parser.get_int("save_mv");
+    if(data_type == "fp16" && save_mv)
+    {
+        return run<ck_tile::half_t, true>(arg_parser) ? 0 : -2;
+    }
+    else if(data_type == "fp16" && !save_mv)
+    {
+        return run<ck_tile::half_t, false>(arg_parser) ? 0 : -2;
+    }
+    else if(data_type == "bf16" && save_mv)
+    {
+        return run<ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
+    }
+    else if(data_type == "bf16" && !save_mv)
+    {
+        return run<ck_tile::bf16_t, false>(arg_parser) ? 0 : -2;
     }
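+    // unrecognized (prec, save_mv) combination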
-    std::cout << std::endl << std::flush;
-
-    return !pass;
+    return -3;
 }
diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp
index 4d1aac0994..861e4a0230 100644
--- a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp
+++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp
@@ -8,23 +8,114 @@
 #include "ck_tile/ops/layernorm2d.hpp"
 #include <string>

-struct layernorm2d_fwd_traits
+template <typename DataType>
+struct LayerNormTypeConfig;
+
+template <>
+struct LayerNormTypeConfig<ck_tile::half_t>
 {
-    std::string data_type;
+    using XDataType       = ck_tile::half_t;
+    using YDataType       = ck_tile::half_t;
+    using GammaDataType   = ck_tile::half_t;
+    using BetaDataType    = ck_tile::half_t;
+    using MeanDataType    = ck_tile::half_t;
+    using InvStdDataType  = ck_tile::half_t;
+    using ComputeDataType = float;
+};
+
+template <>
+struct LayerNormTypeConfig<ck_tile::bf16_t>
+{
+    using XDataType       = ck_tile::bf16_t;
+    using YDataType       = ck_tile::bf16_t;
+    using GammaDataType   = ck_tile::bf16_t;
+    using BetaDataType    = ck_tile::bf16_t;
+    using MeanDataType    = ck_tile::bf16_t;
+    using InvStdDataType  = ck_tile::bf16_t;
+    using ComputeDataType = float;
+};
+
+// runtime args
+struct layernorm2d_fwd_args : public ck_tile::Layernorm2dFwdHostArgs
+{
+};
+
+// this is used to pattern-match the internal kernel implementation, not to instantiate a kernel
+template <typename DataType_,
+          ck_tile::index_t Repeat_M_,
+          ck_tile::index_t Repeat_N_,
+          ck_tile::index_t ThreadPerBlock_M_,
+          ck_tile::index_t ThreadPerBlock_N_,
+          ck_tile::index_t Vector_N_,
+          bool kPadN_,
+          bool kSaveMeanInvStd_,
+          bool kTwoPass_>
+struct layernorm2d_fwd_traits_
+{
+    using DataType = ck_tile::remove_cvref_t<DataType_>;
+
+    static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
+    static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
+    static constexpr ck_tile::index_t total_warps =
+        (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
+
+    // num of warps along m
+    static constexpr ck_tile::index_t BlockWarps_M = []() {
+        if constexpr(is_warp_per_row)
+        {
+            static_assert(warpSize % ThreadPerBlock_N_ == 0);
+            return total_warps * (warpSize / ThreadPerBlock_N_);
+        }
+        else
+        {
+            // static_assert(warpSize % ThreadPerBlock_M_ == 0);
+            return total_warps / (ThreadPerBlock_N_ / warpSize);
+        }
+    }();
+
+    // num of warps along n
+    static constexpr ck_tile::index_t BlockWarps_N = []() {
+        if constexpr(is_warp_per_row)
+        {
+            static_assert(warpSize % ThreadPerBlock_N_ == 0);
+            return 1;
+        }
+        else
+        {
+            static_assert(ThreadPerBlock_N_ % warpSize == 0);
+            return ThreadPerBlock_N_ / warpSize;
+        }
+    }();
+
+    static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
+    static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
+
+    static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
+    static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
+
+    static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
+    static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
+
+    using BlockTile  = ck_tile::sequence<Block_M, Block_N>;
+    using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
+    using WarpTile   = ck_tile::sequence<Warp_M, Warp_N>;
+    using Vector     = ck_tile::sequence<1, Vector_N_>;
+
+    using Shape = ck_tile::Layernorm2dShape<BlockTile, BlockWarps, WarpTile, Vector>;
+
+    static constexpr bool kPadN           = kPadN_;
+    static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
+    static constexpr bool kTwoPass        = kTwoPass_;
 };

-struct layernorm2d_fwd_args
+template <typename Traits_>
+float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a);
+
+// This is the public API; it will be generated by a script
+struct layernorm2d_fwd_traits
 {
-    const void* p_x;
-    const void* p_gamma;
-    const void* p_beta;
-    void* p_y;
-    void* p_mean;
-    void* p_invStd;
-    float epsilon;
-    ck_tile::index_t M;
-    ck_tile::index_t
N; + std::string data_type; + bool save_mean_var; }; -// host API float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/02_layernorm2d/script/perf_test.sh b/example/ck_tile/02_layernorm2d/script/perf_test.sh new file mode 100755 index 0000000000..bfb7f9ffe5 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/script/perf_test.sh @@ -0,0 +1,38 @@ + +# run from top of ck folder +EXE=build/bin/tile_example_layernorm2d_fwd + +$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 + +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 \ No newline at end of file diff --git a/example/ck_tile/02_layernorm2d/script/smoke_test.sh b/example/ck_tile/02_layernorm2d/script/smoke_test.sh new file mode 100755 index 0000000000..dcd40fda40 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/script/smoke_test.sh @@ -0,0 +1,31 @@ +#!/bin/sh +# call from top of CK folder +EXE=./build/bin/tile_example_layernorm2d_fwd + +for pr_i in "fp16" "bf16" ; do +$EXE -prec=$pr_i -m=99 -n=13 +$EXE -prec=$pr_i -m=17 -n=16 +$EXE -prec=$pr_i -m=1 -n=100 +$EXE -prec=$pr_i -m=4 -n=128 +$EXE -prec=$pr_i -m=80 -n=127 +$EXE -prec=$pr_i -m=22 -n=255 -stride=256 +$EXE -prec=$pr_i -m=7 -n=599 +$EXE -prec=$pr_i -m=19 -n=512 +$EXE -prec=$pr_i -m=33 -n=313 -stride=1000 +$EXE -prec=$pr_i -m=11 -n=510 +$EXE -prec=$pr_i -m=171 -n=676 -stride=818 +$EXE -prec=$pr_i -m=91 -n=636 +$EXE -prec=$pr_i -m=12 -n=768 -stride=800 +$EXE -prec=$pr_i -m=100 -n=766 -stride=812 +$EXE -prec=$pr_i -m=31 -n=1024 +$EXE -prec=$pr_i -m=64 -n=1000 -stride=1004 +$EXE -prec=$pr_i -m=8 -n=1501 +$EXE -prec=$pr_i -m=3 -n=1826 +$EXE -prec=$pr_i -m=5 -n=2040 +$EXE -prec=$pr_i -m=7 -n=2734 +$EXE -prec=$pr_i -m=1 -n=3182 +$EXE 
-prec=$pr_i -m=9 -n=4096
$EXE -prec=$pr_i -m=3 -n=8192
$EXE -prec=$pr_i -m=1 -n=10547
$EXE -prec=$pr_i -m=3 -n=17134
done
diff --git a/example/ck_tile/05_reduce/CMakeLists.txt b/example/ck_tile/05_reduce/CMakeLists.txt
new file mode 100644
index 0000000000..6caa38d50d
--- /dev/null
+++ b/example/ck_tile/05_reduce/CMakeLists.txt
@@ -0,0 +1,19 @@
+set(EXAMPLE_REDUCE "tile_example_reduce")
+# not using add_example_executable() to add this target, since we don't want this to have
+# to be included in "make all/install/check"
+message("adding example ${EXAMPLE_REDUCE}")
+
+add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL reduce.cpp)
+target_include_directories(${EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
+set(EXAMPLE_REDUCE_COMPILE_OPTIONS)
+
+# NOTE: turn off undefined-func-template so the sources compile without explicitly declaring
+# every function specialization
+list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
+
+target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS})
+
+# TODO: we have to turn off this global property, otherwise the progress bar generated by
+# cmake prints too many files and fails with "execvp: /bin/sh: Argument list too long".
+# Note that this property affects the whole build, not just this target.
+# TODO: consider generating a makefile ourselves instead.
+set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
\ No newline at end of file
diff --git a/example/ck_tile/05_reduce/reduce.cpp b/example/ck_tile/05_reduce/reduce.cpp
new file mode 100644
index 0000000000..7973a8dfdb
--- /dev/null
+++ b/example/ck_tile/05_reduce/reduce.cpp
@@ -0,0 +1,110 @@
+#include "ck_tile/host.hpp"
+#include "reduce.hpp"
+#include
+
+auto create_args(int argc, char* argv[])
+{
+    ck_tile::ArgParser arg_parser;
+    arg_parser.insert("m", "3328", "m dimension")
+        .insert("n", "4096", "n dimension")
+        .insert("v", "1", "cpu validation or not")
+        .insert("prec", "fp16", "precision")
+        .insert("warmup", "5", "cold iter")
+        .insert("repeat", "20", "hot iter");
+
+    bool result = arg_parser.parse(argc, argv);
+    return std::make_tuple(result, arg_parser);
+}
+
+template <typename DataType>
+bool run(const ck_tile::ArgParser& arg_parser)
+{
+    using ADataType   = DataType;
+    using AccDataType = float;
+    using BDataType   = DataType;
+
+    ck_tile::index_t m = arg_parser.get_int("m");
+    ck_tile::index_t n = arg_parser.get_int("n");
+    int do_validation  = arg_parser.get_int("v");
+    int warmup         = arg_parser.get_int("warmup");
+    int repeat         = arg_parser.get_int("repeat");
+
+    ck_tile::HostTensor<ADataType> a_host({m, n});
+    ck_tile::HostTensor<BDataType> b_host_ref({m});
+    ck_tile::HostTensor<BDataType> b_host_dev({m});
+
+    ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_host);
+
+    ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem b_buf(b_host_dev.get_element_space_size_in_bytes());
+
+    a_buf.ToDevice(a_host.data());
+
+    using BlockWarps = ck_tile::sequence<4, 1>;
+    using BlockTile  = ck_tile::sequence<128, 128>;
+    using WarpTile   = ck_tile::sequence<32, 128>;
+    using ThreadTile = ck_tile::sequence<8, 8>;
+
+    constexpr ck_tile::index_t kBlockSize  = 256;
+    constexpr ck_tile::index_t kBlockPerCu = 1;
+    ck_tile::index_t kGridSize             = (m / BlockTile::at(ck_tile::number<0>{}));
+    std::cout << "grid size " << kGridSize << std::endl;
+
+    using Kernel = ck_tile::Reduce<ADataType,
+                                   AccDataType,
+                                   BDataType,
+                                   BlockWarps,
+                                   BlockTile,
+                                   WarpTile,
+                                   ThreadTile>;
+
+    float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
+                                   ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
+                                       Kernel{},
+                                       kGridSize,
+                                       kBlockSize,
+                                       0,
+                                       static_cast<ADataType*>(a_buf.GetDeviceBuffer()),
+                                       static_cast<BDataType*>(b_buf.GetDeviceBuffer()),
+                                       m,
+                                       n));
+
+    std::size_t num_btype = sizeof(ADataType) * m * n + sizeof(BDataType) * m;
+
+    float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+    std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
+
+    bool pass = true;
+
+    if(do_validation)
+    {
+        // reference
+        ck_tile::reference_reduce<ADataType, AccDataType, BDataType>(a_host, b_host_ref);
+        b_buf.FromDevice(b_host_dev.mData.data());
+        pass = ck_tile::check_err(b_host_dev, b_host_ref);
+
+        std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
+    }
+
+    return pass;
+}
+
+int main(int argc, char* argv[])
+{
+    auto [result, arg_parser] = create_args(argc, argv);
+    if(!result)
+        return -1;
+
+    const std::string data_type = arg_parser.get_str("prec");
+
+    if(data_type == "fp16")
+    {
+        return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
+    }
+    if(data_type == "bf16")
+    {
+        return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
+    }
+    return -3;
+}
diff --git a/example/ck_tile/05_reduce/reduce.hpp b/example/ck_tile/05_reduce/reduce.hpp
new file mode 100644
index 0000000000..e36b468951
--- /dev/null
+++ b/example/ck_tile/05_reduce/reduce.hpp
@@ -0,0 +1,118 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/common.hpp"
+
+#include "ck_tile/ops/reduce/block/block_reduce.hpp"
+
+namespace ck_tile {
+
+template <typename ADataType,
+          typename AccDataType,
+          typename BDataType,
+          typename BlockWarps, // num warps along seq<M, N>
+          typename BlockTile,  // block size, seq<M, N>
+          typename WarpTile,   // warp size, seq<M, N>
+          typename ThreadTile> // contiguous pixels(vector size) along seq<M, N>
+struct Reduce
+{
+    static constexpr index_t Block_M = BlockTile::at(number<0>{});
+    static constexpr index_t Block_N = BlockTile::at(number<1>{});
+
+    static constexpr index_t Warp_M = WarpTile::at(number<0>{});
+    static constexpr index_t Warp_N = WarpTile::at(number<1>{});
+
+    static constexpr index_t Thread_M = ThreadTile::at(number<0>{});
+    static constexpr index_t Thread_N = ThreadTile::at(number<1>{});
+
+    static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
+    static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{});
+
+    static constexpr index_t ThreadPerWarp_M = Warp_M / Thread_M;
+    static constexpr index_t ThreadPerWarp_N = Warp_N / Thread_N;
+
+    static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
+    static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
+
+    __device__ static constexpr auto MakeABlockTileDistribution()
+    {
+        return make_static_tile_distribution(
+            tile_distribution_encoding<
+                sequence<>,
+                tuple, sequence>,
+                tuple, sequence<1, 2>>,
+                tuple, sequence<2, 2>>,
+                sequence<1, 1, 2, 2>,
+                sequence<0, 3, 0, 3>>{});
+    }
+
+    __device__ void operator()(const ADataType* p_a, BDataType* p_b, index_t M, index_t N) const
+    {
+        const auto a_m_n = make_naive_tensor_view(
+            p_a, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{});
+
+        const auto iM = get_block_id() * Block_M;
+
+        // A window
+        auto a_block_window = make_tile_window(a_m_n,
+                                               make_tuple(number<Block_M>{}, number<Block_N>{}),
+                                               {iM, 0},
+                                               MakeABlockTileDistribution());
+
+        const auto f_reduce = [](const auto& v0, const auto& v1) { return v0 + v1; };
+
+        const ADataType reduce_init_value = 0;
+
+        constexpr auto reduce_dims = sequence<1>{};
+
+        // Acc tile
+        // TODO: support cross warp reduction
+        auto acc_block_tensor = decltype(block_tile_reduce(
+            load_tile(a_block_window), reduce_dims, f_reduce, reduce_init_value)){};
+
+        // init Acc tile
+        tile_elementwise_inout(
+            [&](auto& acc) { acc = type_convert(reduce_init_value); },
+            acc_block_tensor);
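+
+        // Sweep the row in Block_N-wide steps: each iteration loads one tile and folds it
+        // into the per-thread accumulator; lanes are combined only once afterwards via
+        // block_tile_reduce_sync.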
a_block_tensor = load_tile(a_block_window); + + // FIXME: support cross warp reduction + block_tile_reduce(acc_block_tensor, a_block_tensor, reduce_dims, f_reduce); + + move_tile_window(a_block_window, {0, Block_N}); + + iN += Block_N; + + } while(iN < N); + + // FIXME: support cross warp reduction + block_tile_reduce_sync(acc_block_tensor, f_reduce); + + // convert acc_block_tensor to b_block_tensor + const auto b_block_tensor = tile_elementwise_in( + [](const auto& acc) { return type_convert(acc); }, acc_block_tensor); + + // B + const auto b_m = make_naive_tensor_view_packed( + p_b, make_tuple(M), number<32>{}); + + // B window + auto b_block_window = make_tile_window(b_m, make_tuple(number{}), {iM}); + + // store B tile + store_tile(b_block_window, b_block_tensor); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index fe1e9c9edf..ec4a175d35 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -6,3 +6,4 @@ add_subdirectory(01_fmha) add_subdirectory(02_layernorm2d) add_subdirectory(03_gemm) add_subdirectory(04_img2col) +add_subdirectory(05_reduce) diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 4cddf6faa9..d96f14710b 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -52,6 +52,7 @@ #include "ck_tile/core/tensor/update_tile.hpp" #include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/functional_with_tuple.hpp" #include "ck_tile/core/utility/ignore.hpp" #include "ck_tile/core/utility/magic_div.hpp" #include "ck_tile/core/utility/philox_rand.hpp" diff --git a/include/ck_tile/core/arch/utility.hpp b/include/ck_tile/core/arch/utility.hpp index 42508e66a6..a88780459b 100644 --- a/include/ck_tile/core/arch/utility.hpp +++ b/include/ck_tile/core/arch/utility.hpp @@ -59,4 +59,47 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta) #endif } +template +CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane) +{ +#if 0 + return __shfl(v_local, src_lane); +#elif 1 + if constexpr(sizeof(int32_t) > sizeof(T)) + { + union packet + { + int32_t x; + T v; + }; + packet p; + p.v = v_local; + packet p_remote; + p_remote.x = __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast(p)); + + return p_remote.v; + } + else if constexpr(sizeof(int32_t) == sizeof(T)) + { + const int32_t v_remote_tmp = + __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast(v_local)); + + return bit_cast(v_remote_tmp); + } + else + { + static_assert(sizeof(T) % sizeof(int32_t) == 0, "wrong!"); + constexpr index_t elm = sizeof(T) / sizeof(int32_t); + using vector_type = thread_buffer; + auto vs = bit_cast(v_local); + auto vs_remote = vector_type{}; + static_for<0, elm, 1>{}([&](auto i_e) { + int32_t tmp = __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast(vs[i_e])); + vs_remote(i_e) = tmp; + }); + return bit_cast(vs_remote); + } +#endif +} + } // namespace ck_tile diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index a8bc27cdff..580faae925 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -32,11 +32,13 @@ #define CK_TILE_DEVICE inline __device__ #define CK_TILE_HOST_DEVICE inline __host__ __device__ #define CK_TILE_DEVICE_EXTERN __device__ +#define CK_TILE_HOST_DEVICE_EXTERN __host__ __device__ #else #define CK_TILE_HOST inline #define CK_TILE_DEVICE inline #define CK_TILE_HOST_DEVICE inline #define CK_TILE_DEVICE_EXTERN +#define 
CK_TILE_HOST_DEVICE_EXTERN #endif #ifndef CK_TILE_USE_CUSTOM_DATA_TYPE diff --git a/include/ck_tile/core/container/sequence.hpp b/include/ck_tile/core/container/sequence.hpp index acf187cfc8..4fcea9642d 100644 --- a/include/ck_tile/core/container/sequence.hpp +++ b/include/ck_tile/core/container/sequence.hpp @@ -1111,4 +1111,126 @@ CK_TILE_HOST_DEVICE constexpr auto generate_array(F&& f, number) typename arithmetic_sequence_gen<0, N, 1>::type{}); } +namespace impl { +template +struct reverse_slice_sequence_impl; + +template +struct reverse_slice_sequence_impl, + sequence, + sequence, + SliceSize> +{ + using old_scan = + reverse_slice_sequence_impl, sequence, sequence, SliceSize>; + + static constexpr auto slice_size = old_scan::remaining_slice_sizes::front().value; + static constexpr auto slice_length = + std::conditional_t, number>::value; + + using dim_lengths = + typename sequence_merge, typename old_scan::dim_lengths>::type; + using dim_slices = + typename sequence_merge, typename old_scan::dim_slices>::type; + using remaining_slice_sizes = typename sequence_merge< + std::conditional_t, sequence>, + typename old_scan::remaining_slice_sizes>::type; + + // the first idx that sliced length not equal to original length + static constexpr index_t _flag = + slice_length != x && remaining_slice_sizes{}.front().value == 1; + static constexpr index_t _split_flag = std::conditional_t, number<0>>::value; + static constexpr index_t _split_idx = + std::conditional_t<_split_flag, number, number<0>>::value; + + static constexpr index_t split_flag = _split_flag || old_scan::split_flag; + static constexpr index_t split_idx = std:: + conditional_t, number<_split_idx>>::value; +}; + +template +struct reverse_slice_sequence_impl, sequence, sequence, SliceSize> +{ + static constexpr auto slice_size = SliceSize; + static constexpr auto slice_length = + std::conditional_t, number>::value; + + using dim_lengths = sequence; + using dim_slices = sequence; + using remaining_slice_sizes = + std::conditional_t, sequence>; + + // the first idx that sliced length not equal to original length + static constexpr index_t _flag = + slice_length != x && remaining_slice_sizes{}.front().value == 1; + static constexpr index_t split_flag = std::conditional_t, number<0>>::value; + static constexpr index_t split_idx = + std::conditional_t, number<0>>::value; +}; +} // namespace impl + +// clang-format off +// input a sequence(with optional mask), and the SliceSize : size per slice +// output the sequence each slice, and number of slices +// +// e.g. 
<2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0 +// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2 +// <4, 2, 4, 1, 6>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2 +// <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices , slice_idx: 1 +// +// <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1 slices , slice_idx: 0 +// <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2 slices , slice_idx: 0 +// <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4 slices , slice_idx: 0 +// <4, 2, 8>, 8 -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8 slices , slice_idx: 1 +// <4, 2, 8>, 4 -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2 +// <4, 2, 8>, 2 -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2 +// <4, 2, 8>, 1 -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2 +// +// <4, 2, 1, 4, 2> / 4 -> +// mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 8 slices , slice_idx: 0 +// +// return tuple, slice_index is at which index will start +// have split slices (right -> left) +// or the first index that sliced length is different from the original length +// clang-format on +template ::type> +constexpr auto reverse_slice_sequence(Seq, + number, + Mask = typename uniform_sequence_gen::type{}) +{ + static_assert(Seq::size() == Mask::size()); + using sliced_type = + impl::reverse_slice_sequence_impl::type, + SliceSize>; + static_assert(sliced_type::remaining_slice_sizes::front().value == 1, + "can not evenly divide this sequence, please check"); + return make_tuple(typename sliced_type::dim_lengths{}, + typename sliced_type::dim_slices{}, + number{}); +} + +template ::type> +constexpr auto slice_sequence(Seq, + number, + Mask = typename uniform_sequence_gen::type{}) +{ + constexpr auto r = + reverse_slice_sequence(Seq{}.reverse(), number{}, Mask{}.reverse()); + return make_tuple(r[number<0>{}].reverse(), + r[number<1>{}].reverse(), + number{}] - 1>{}); +} + } // namespace ck_tile diff --git a/include/ck_tile/core/container/tuple.hpp b/include/ck_tile/core/container/tuple.hpp index cb8c2c70c6..598dfeea3e 100644 --- a/include/ck_tile/core/container/tuple.hpp +++ b/include/ck_tile/core/container/tuple.hpp @@ -488,6 +488,26 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y, f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{}); } +namespace detail { + +template +CK_TILE_HOST_DEVICE constexpr auto embed_tuples_impl(F f, const X& x, sequence) +{ + return concat_tuple(f(x.at(number{}))...); +} + +} // namespace detail + +// make sure F return at least a tuple +// e.g. 
x : tuple, f will return tuple +// this function will return +template +CK_TILE_HOST_DEVICE constexpr auto embed_tuples(F f, const X& x) +{ + return detail::embed_tuples_impl( + f, x, typename arithmetic_sequence_gen<0, X::size(), 1>::type{}); +} + // By default unroll to the flatten template CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& t) diff --git a/include/ck_tile/core/tensor/static_distributed_tensor.hpp b/include/ck_tile/core/tensor/static_distributed_tensor.hpp index 299a74bc08..29c20bed00 100644 --- a/include/ck_tile/core/tensor/static_distributed_tensor.hpp +++ b/include/ck_tile/core/tensor/static_distributed_tensor.hpp @@ -187,4 +187,18 @@ set_tile_if(static_distributed_tensor& out_ten }); } +// this function used inside span loop over +template +CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number) +{ + constexpr auto y_size = reduce_on_sequence(YLengths{}, multiplies{}, number<1>{}); + constexpr auto y_packs = number{}; + static_assert(y_size % y_packs == 0); + constexpr auto y_slice_size = y_size / y_packs; + + constexpr auto slice_info = slice_sequence(YLengths{}, number{}); + constexpr auto unpacks = slice_info[number<1>{}]; + return unpacks; +} + } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/sweep_tile.hpp b/include/ck_tile/core/tensor/sweep_tile.hpp index f1511f11d2..f82f6b5bcd 100644 --- a/include/ck_tile/core/tensor/sweep_tile.hpp +++ b/include/ck_tile/core/tensor/sweep_tile.hpp @@ -8,6 +8,7 @@ #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/tensor/tile_distribution.hpp" #include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/functional_with_tuple.hpp" #include "ck_tile/core/utility/type_traits.hpp" namespace ck_tile { @@ -27,4 +28,281 @@ CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f) }); } +// unpacked span, this version support span with unpack(multi-arg) functor +// +template < + typename TileDistributedSpan_, // tile_distributed_span<...> + typename F, // signature: F(tile_distributed_index<...>) + typename Unpacks = typename uniform_sequence_gen::type> +CK_TILE_DEVICE void sweep_tile_uspan(TileDistributedSpan_, const F& f, Unpacks = {}) +{ + using DstrSpan = remove_cvref_t; + + static_uford{}( + [&](auto... dstr_idx_impl) { f(detail::make_tile_distributed_index(dstr_idx_impl)...); }); +} + +namespace impl { + +template +struct sweep_tile_impl; + +template +struct sweep_tile_impl> +{ + CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks() const + { + constexpr auto spans = DistributedTensor::get_distributed_spans(); + constexpr auto y_lengths = typename decltype(spans[number{}])::Impl{}; + constexpr auto x_unpacks = number{})>{}; + constexpr auto y_unpacks = get_y_unpacks_from_x_unpacks(y_lengths, x_unpacks); + return y_unpacks; + } + CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const + { + constexpr auto spans = DistributedTensor::get_distributed_spans(); + constexpr auto u = + static_uford{}])::Impl, decltype(get_y_unpacks())>{}; + return u.get_num_of_access() * + sweep_tile_impl>{} + .get_num_of_access(); + } + template + CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, const SpanIdx& span_idx) const + { + constexpr auto spans = DistributedTensor::get_distributed_spans(); + + sweep_tile_uspan( + spans[number{}], + [&](auto... 
i_idx) { + const auto next_span_idx = embed_tuples( + [&](auto si) { return make_tuple(concat_tuple(si, make_tuple(i_idx))...); }, + span_idx); + sweep_tile_impl>{}( + f, next_span_idx); + }, + get_y_unpacks()); + } + template + CK_TILE_HOST_DEVICE constexpr void + operator()(const F& f, const SpanIdx& span_idx, number) const + { + constexpr auto spans = DistributedTensor::get_distributed_spans(); + constexpr auto u = + static_uford{}])::Impl, decltype(get_y_unpacks())>{}; + constexpr auto access_stride = + sweep_tile_impl>{} + .get_num_of_access(); + constexpr auto curr_i_access = number{}; + constexpr auto next_i_access = number{}; + u( + [&](auto... i_idx) { + const auto next_span_idx = embed_tuples( + [&](auto si) { + return make_tuple(concat_tuple( + si, make_tuple(detail::make_tile_distributed_index(i_idx)))...); + }, + span_idx); + sweep_tile_impl>{}( + f, next_span_idx, next_i_access); + }, + curr_i_access); + } +}; + +template +struct sweep_tile_impl> +{ + CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const { return 1; } + template + CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, const SpanIdx& span_idx) const + { + unpack(f, span_idx); + } + template + CK_TILE_HOST_DEVICE constexpr void + operator()(const F& f, const SpanIdx& span_idx, number) const + { + unpack(f, span_idx); + } +}; + +template +struct sweep_tile_impl_0; + +// TODO: support empty tuple to remove this "entry-point" like function +template +struct sweep_tile_impl_0> +{ + CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks() const + { + constexpr auto spans = DistributedTensor::get_distributed_spans(); + constexpr auto y_lengths = typename decltype(spans[number{}])::Impl{}; + constexpr auto x_unpacks = number{})>{}; + constexpr auto y_unpacks = get_y_unpacks_from_x_unpacks(y_lengths, x_unpacks); + return y_unpacks; + } + CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const + { + constexpr auto spans = DistributedTensor::get_distributed_spans(); + constexpr auto u = + static_uford{}])::Impl, decltype(get_y_unpacks())>{}; + return u.get_num_of_access() * + sweep_tile_impl>{} + .get_num_of_access(); + } + template + CK_TILE_HOST_DEVICE constexpr void operator()(const F& f) const + { + constexpr auto spans = DistributedTensor::get_distributed_spans(); + sweep_tile_uspan( + spans[number{}], + [&](auto... i_idx) { + constexpr auto next_span_idx = make_tuple(make_tuple(i_idx)...); + sweep_tile_impl>{}( + f, next_span_idx); + }, + get_y_unpacks()); + } + template + CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, number) const + { + constexpr auto spans = DistributedTensor::get_distributed_spans(); + constexpr auto u = + static_uford{}])::Impl, decltype(get_y_unpacks())>{}; + constexpr auto access_stride = + sweep_tile_impl>{} + .get_num_of_access(); + constexpr auto curr_i_access = number{}; + constexpr auto next_i_access = number{}; + u( + [&](auto... i_idx) { + constexpr auto next_span_idx = + make_tuple(make_tuple(detail::make_tile_distributed_index(i_idx))...); + sweep_tile_impl>{}( + f, next_span_idx, next_i_access); + }, + curr_i_access); + } +}; + +} // namespace impl + +/* + * Enhanced sweep-tile utility, can control unpacks along each X-dim + * the lambda function argument is the distributed-idx, which can directly + * plugged into the distributed tensor as setter/getter + * + * e.g. 
below function, y with the type DistributedTensor, r is row scale + * + * // sweep tile 1 by 1 + * sweep_tile([&](auto idx) { + * constexpr auto row_id = make_tuple(idx[number<0>{}]); + * y(idx) = y(idx) * r(row_id); + * }); + * + * // sweep tile with 2 pixel from last dim each function call + * sweep_tile( + * [&](auto idx_0, auto idx_1) { + * constexpr auto row_id = make_tuple(idx_0[number<0>{}]); + * y(idx_0) = y(idx_0) * r(row_id); + * y(idx_1) = y(idx_1) * r(row_id); + * }, + * sequence<1, 2>{}); + * + * // sweep tile with 2x2 pixel each function call + * sweep_tile( + * [&](auto idx_00, auto idx_01, auto idx_10, auto idx_11) { + * constexpr auto row_id0 = make_tuple(idx_00[number<0>{}]); + * constexpr auto row_id1 = make_tuple(idx_10[number<0>{}]); + * y(idx_00) = y(idx_00) * r(row_id0); + * y(idx_01) = y(idx_01) * r(row_id0); + * y(idx_10) = y(idx_10) * r(row_id1); + * y(idx_11) = y(idx_11) * r(row_id1); + * }, + * sequence<2, 2>{}); + * + * TODO: do we need constexpr? lambda function could be non-constexpr + */ +template ::type> +CK_TILE_HOST_DEVICE constexpr void sweep_tile(const F& f, UnpacksPerXDim = {}) +{ + constexpr auto spans = DistributedTensor::get_distributed_spans(); + + impl::sweep_tile_impl_0::type>{}(f); +} + +template ::type> +CK_TILE_HOST_DEVICE constexpr void +sweep_tile(const DistributedTensor&, const F& f, UnpacksPerXDim = {}) +{ + sweep_tile(f, UnpacksPerXDim{}); +} + +/* + * construct a sweep tile instance, which support issue the lambda one by one + * Note that this struct will hold the lambda functor, but will not hold the distributed tensor + * the functionality is the same as sweep_tile() + */ +template ::type> +struct tile_sweeper +{ + using DistributedTensor = remove_cvref_t; + using F = remove_cvref_t; + using UnpacksPerXDim = remove_cvref_t; + + CK_TILE_HOST_DEVICE tile_sweeper(const F& f_, UnpacksPerXDim = {}) : f(f_) {} + CK_TILE_HOST_DEVICE tile_sweeper(const DistributedTensor&, const F& f_, UnpacksPerXDim = {}) + : f(f_) + { + } + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access() + { + constexpr auto spans = DistributedTensor::get_distributed_spans(); + constexpr auto tmp = + impl::sweep_tile_impl_0::type>{}; + return tmp.get_num_of_access(); + } + + CK_TILE_HOST_DEVICE void operator()() const + { + sweep_tile(f, UnpacksPerXDim{}); + } + + template + CK_TILE_HOST_DEVICE void operator()(number) const + { + constexpr auto spans = DistributedTensor::get_distributed_spans(); + + impl::sweep_tile_impl_0::type>{}( + f, number{}); + } + F f; +}; + +// partial deduction is not allowed +// template +// CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const F&, U = {})->tile_sweeper; + +// deduction guide +template ::type> +CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const T&, const F&, U = {})->tile_sweeper; + } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_distribution.hpp b/include/ck_tile/core/tensor/tile_distribution.hpp index 24c932f0a6..7761be492d 100644 --- a/include/ck_tile/core/tensor/tile_distribution.hpp +++ b/include/ck_tile/core/tensor/tile_distribution.hpp @@ -17,6 +17,14 @@ namespace ck_tile { +namespace detail { +template +CK_TILE_HOST_DEVICE auto get_partition_index(Distribution) +{ + return Distribution::_get_partition_index(); +} +} // namespace detail + // distributed span template struct tile_distributed_span @@ -83,6 +91,21 @@ struct tile_distribution CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_p() { return NDimP; } CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_r() { 
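+        // R ("replication") dims are partition coordinates whose lanes all map
+        // onto the same X element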
return NDimR; } + CK_TILE_HOST_DEVICE static auto _get_partition_index() + { + // only support warp-tile and block-tile + static_assert(NDimP == 1 or NDimP == 2, "wrong!"); + + if constexpr(NDimP == 1) + { + return array{get_lane_id()}; + } + else if constexpr(NDimP == 2) + { + return array{get_warp_id(), get_lane_id()}; + } + } + CK_TILE_HOST_DEVICE static constexpr auto get_lengths() { #if 0 @@ -149,6 +172,16 @@ struct tile_distribution } #endif + template + CK_TILE_HOST_DEVICE auto + calculate_index(const PartitionIndex& ps_idx = _get_partition_index()) const + { + const auto ps_ys_idx = container_concat(ps_idx, array{0}); + const auto window_adaptor_thread_coord_tmp = + make_tensor_adaptor_coordinate(ps_ys_to_xs_, ps_ys_idx); + return window_adaptor_thread_coord_tmp.get_bottom_index(); + } + CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans() { constexpr auto distributed_spans_impl = DstrEncode::detail::distributed_spans_lengthss_; @@ -421,6 +454,7 @@ struct tile_distribution_detail } // namespace detail +#if 0 // this returns a constexpr tile_distribution template CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_) @@ -457,6 +491,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistribution detail::tile_distribution_detail>>{ ps_ys_to_xs_adaptor, ys_to_d_descriptor}; } +#endif // this returns a static tile_distribution template @@ -499,129 +534,6 @@ CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistr //*********************************************************************************** namespace detail { - -template -CK_TILE_HOST_DEVICE auto get_partition_index(Distribution) -{ - // only support warp-tile and block-tile - static_assert(Distribution::NDimP == 1 or Distribution::NDimP == 2, "wrong!"); - - if constexpr(Distribution::NDimP == 1) - { - return array{get_lane_id()}; - } - else if constexpr(Distribution::NDimP == 2) - { - return array{get_warp_id(), get_lane_id()}; - } -} - -template -struct reverse_slice_sequence_impl; - -template -struct reverse_slice_sequence_impl, - sequence, - sequence, - SliceSize> -{ - using old_scan = - reverse_slice_sequence_impl, sequence, sequence, SliceSize>; - - static constexpr auto slice_size = old_scan::remaining_slice_sizes::front().value; - static constexpr auto slice_length = - std::conditional_t, number>::value; - - using dim_lengths = - typename sequence_merge, typename old_scan::dim_lengths>::type; - using dim_slices = - typename sequence_merge, typename old_scan::dim_slices>::type; - using remaining_slice_sizes = typename sequence_merge< - std::conditional_t, sequence>, - typename old_scan::remaining_slice_sizes>::type; - - // the first idx that sliced length not equal to original length - static constexpr index_t _flag = - slice_length != x && remaining_slice_sizes{}.front().value == 1; - static constexpr index_t _split_flag = std::conditional_t, number<0>>::value; - static constexpr index_t _split_idx = - std::conditional_t<_split_flag, number, number<0>>::value; - - static constexpr index_t split_flag = _split_flag || old_scan::split_flag; - static constexpr index_t split_idx = std:: - conditional_t, number<_split_idx>>::value; -}; - -template -struct reverse_slice_sequence_impl, sequence, sequence, SliceSize> -{ - static constexpr auto slice_size = SliceSize; - static constexpr auto slice_length = - std::conditional_t, number>::value; - - using dim_lengths = sequence; - using dim_slices = sequence; - using remaining_slice_sizes = 
- std::conditional_t, sequence>; - - // the first idx that sliced length not equal to original length - static constexpr index_t _flag = - slice_length != x && remaining_slice_sizes{}.front().value == 1; - static constexpr index_t split_flag = std::conditional_t, number<0>>::value; - static constexpr index_t split_idx = - std::conditional_t, number<0>>::value; -}; - -// clang-format off -// input a sequence(with optional mask), and the SliceSize : size per slice -// output the sequence each slice, and number of slices -// -// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0 -// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2 -// <4, 2, 4, 1, 6>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2 -// <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices , slice_idx: 1 -// -// <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1 slices , slice_idx: 0 -// <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2 slices , slice_idx: 0 -// <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4 slices , slice_idx: 0 -// <4, 2, 8>, 8 -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8 slices , slice_idx: 1 -// <4, 2, 8>, 4 -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2 -// <4, 2, 8>, 2 -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2 -// <4, 2, 8>, 1 -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2 -// -// <4, 2, 1, 4, 2> / 4 -> -// mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 8 slices , slice_idx: 0 -// -// return tuple, slice_index is at which index will start -// have split slices (right -> left) -// or the first index that sliced length is different from the original length -// clang-format on -template ::type> -constexpr auto reverse_slice_sequence(Seq, - number, - Mask = typename uniform_sequence_gen::type{}) -{ - static_assert(Seq::size() == Mask::size()); - using sliced_type = - reverse_slice_sequence_impl::type, - SliceSize>; - static_assert(sliced_type::remaining_slice_sizes::front().value == 1, - "can not evenly divide this sequence, please check"); - return make_tuple(typename sliced_type::dim_lengths{}, - typename sliced_type::dim_slices{}, - number{}); -} - // // slice tensor from x_dim, result in split in y_dim, not p_dim. // We don't support slice cross p_dim (aka, slice different threads) diff --git a/include/ck_tile/core/utility/functional_with_tuple.hpp b/include/ck_tile/core/utility/functional_with_tuple.hpp new file mode 100644 index 0000000000..4b40403190 --- /dev/null +++ b/include/ck_tile/core/utility/functional_with_tuple.hpp @@ -0,0 +1,173 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +// This file should not be included inside tuple.hpp! + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include +#include + +namespace ck_tile { + +namespace detail { + +// RemainLengths: sequence<...> +// Orders: sequence<...> +template +struct static_uford_impl +{ + CK_TILE_HOST_DEVICE constexpr static_uford_impl() + { + static_assert(RemainLengths::size() > 0, "wrong! 
should not get here"); + static_assert(RamainUnpacks::size() > 0, "wrong! should not get here"); + } + + template + CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds) const + { + constexpr index_t pack_len = RamainUnpacks::front(); + static_for<0, RemainLengths::front(), pack_len>{}([=](auto I) { + constexpr auto new_pack = generate_tuple( + [&](auto idx_) { + constexpr auto i_new_pack = number{}; + constexpr auto i_pre_pack = number{}; + return CurrentUnpackIds{}.at(i_pre_pack).push_back(i_new_pack); + }, + number{}); + + static_uford_impl{}(f, new_pack); + }); + } +}; + +template +struct static_uford_impl, sequence<>, Orders> +{ + template + CK_TILE_HOST_DEVICE constexpr void operator()(F f, PackedId) const + { + constexpr auto origin_packs = transform_tuples( + [](auto pack_) { return decltype(pack_)::reorder_old_to_new(Orders{}); }, PackedId{}); + unpack(f, origin_packs); + } +}; + +template +struct static_uford_one_shot_impl +{ + template + CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds, number) const + { + constexpr auto r_lens_stride = + reverse_exclusive_scan_sequence(RemainLengths{}, multiplies{}, number<1>{}); + constexpr auto r_upks_stride = + reverse_exclusive_scan_sequence(RamainUnpacks{}, multiplies{}, number<1>{}); + + constexpr index_t current_stride = r_lens_stride.front() / r_upks_stride.front(); + constexpr index_t pack_len = RamainUnpacks::front(); + constexpr index_t current_idx = (current_acc / current_stride) * pack_len; + + constexpr auto new_pack = generate_tuple( + [&](auto idx_) { + constexpr auto i_new_pack = number{}; + constexpr auto i_pre_pack = number{}; + return CurrentUnpackIds{}.at(i_pre_pack).push_back(i_new_pack); + }, + number{}); + + static_uford_one_shot_impl{}(f, new_pack, number{}); + } +}; + +template +struct static_uford_one_shot_impl, sequence<>, Orders> +{ + template + CK_TILE_HOST_DEVICE constexpr void operator()(F f, PackedId, number) const + { + constexpr auto origin_packs = transform_tuples( + [](auto pack_) { return decltype(pack_)::reorder_old_to_new(Orders{}); }, PackedId{}); + unpack(f, origin_packs); + } +}; + +} // namespace detail + +// TODO: we may unify static_ford/static_uford in the future +// +// loop over nd space(sequence) with packs +// you must make sure the function passed in has same number of argument +// +// e.g. +// Lengths=seq<2, 3, 4>, Unpacks=<1, 1, 2> +// static_uford{}([&](auto i_0, auto i_1){}); // require 2 args(packs) +// +// loop #0, i_0=seq<0, 0, 0>, i_1=<0, 0, 1> +// loop #1, i_0=seq<0, 0, 2>, i_1=<0, 0, 3> +// loop #2, i_0=seq<0, 1, 0>, i_1=<0, 1, 1> +// loop #3, i_0=seq<0, 1, 2>, i_1=<0, 1, 3> +// loop #4, i_0=seq<0, 2, 0>, i_1=<0, 2, 1> +// loop #5, i_0=seq<0, 2, 2>, i_1=<0, 2, 3> +// loop #6, i_0=seq<1, 0, 0>, i_1=<1, 0, 1> +// ... +template ::type, + class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type> +struct static_uford +{ + static constexpr index_t num_packs = reduce_on_sequence(Unpacks{}, multiplies{}, number<1>{}); + + CK_TILE_HOST_DEVICE constexpr static_uford() + { + static_assert(Lengths::size() > 0, "wrong! Lengths is empty"); + static_assert(Lengths::size() == Unpacks::size(), "wrong! inconsistent size"); + static_assert(Lengths::size() == Orders::size(), "wrong! 
inconsistent size"); + static_for<0, Lengths::size(), 1>{}( + [&](auto i) { static_assert(Lengths{}.at(i) % Unpacks{}.at(i) == 0); }); + } + + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access() + { + using L_ = decltype(Lengths{} / Unpacks{}); + + return reduce_on_sequence(L_{}, multiplies{}, number<1>{}); + } + + // F signature: F(sequence<...> multi_id...) + // multi_id is the unordered multi-index + template + CK_TILE_HOST_DEVICE constexpr void operator()(F f) const + { + constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{}); + constexpr auto ordered_unpacks = Unpacks::reorder_new_to_old(Orders{}); + detail::static_uford_impl{}( + f, make_tuple(sequence<>{})); + } + + // this version is friendly for issue function one by one + template + CK_TILE_HOST_DEVICE constexpr void operator()(F f, number) const + { + static_assert(i_access < get_num_of_access()); + constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{}); + constexpr auto ordered_unpacks = Unpacks::reorder_new_to_old(Orders{}); + detail::static_uford_one_shot_impl{}( + f, make_tuple(sequence<>{}), number{}); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index b382710b19..dbc1f5d23a 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -21,7 +21,7 @@ #include "ck_tile/host/reference/reference_batched_softmax.hpp" #include "ck_tile/host/reference/reference_gemm.hpp" #include "ck_tile/host/reference/reference_im2col.hpp" -#include "ck_tile/host/reference/reference_layernorm2d.hpp" +#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp" #include "ck_tile/host/reference/reference_reduce.hpp" #include "ck_tile/host/reference/reference_softmax.hpp" #include "ck_tile/host/stream_config.hpp" diff --git a/include/ck_tile/host/reference/reference_layernorm2d.hpp b/include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp similarity index 100% rename from include/ck_tile/host/reference/reference_layernorm2d.hpp rename to include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp diff --git a/include/ck_tile/ops/layernorm2d.hpp b/include/ck_tile/ops/layernorm2d.hpp index 3b66645ed4..2a403b0f49 100644 --- a/include/ck_tile/ops/layernorm2d.hpp +++ b/include/ck_tile/ops/layernorm2d.hpp @@ -4,6 +4,9 @@ #pragma once #include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp" -#include "ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp" -#include "ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp" +#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp" +#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp" +#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp" +#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp" +#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp index 468df793da..cebe5131a7 100644 --- a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp @@ -5,37 +5,57 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" -#include "ck_tile/ops/welford/thread/thread_welford.hpp" -#include "ck_tile/ops/welford/warp/warp_welford.hpp" namespace ck_tile { -// TODO: 
Extract some type to wrapper class -template -struct Layernorm2dFwd +// host side args +struct Layernorm2dFwdHostArgs { - using Problem = ck_tile::remove_cvref_t; + const void* p_x; + const void* p_gamma; + const void* p_beta; - using XDataType = ck_tile::remove_cvref_t; - using GammaDataType = ck_tile::remove_cvref_t; - using BetaDataType = ck_tile::remove_cvref_t; - using ComputeDataType = ck_tile::remove_cvref_t; - using YDataType = ck_tile::remove_cvref_t; - using MeanDataType = ck_tile::remove_cvref_t; - using InvStdDataType = ck_tile::remove_cvref_t; + void* p_y; + void* p_mean; + void* p_invStd; - static constexpr bool kHasGamma = !std::is_same_v; - static constexpr bool kHasBeta = !std::is_same_v; - static constexpr bool kSaveMean = !std::is_same_v; - static constexpr bool kSaveInvStd = !std::is_same_v; + float epsilon; - static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock; - static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock; - static constexpr bool kPadM = Problem::kPadM; - static constexpr bool kPadN = Problem::kPadN; + index_t m; + index_t n; + index_t stride; // row_stride +}; - static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp; - static constexpr ck_tile::index_t kNPerThread = Problem::BlockShape::kNPerThread; +// TODO: Extract some type to wrapper class +template +struct Layernorm2dFwd +{ + using Pipeline = remove_cvref_t; + using Problem = typename Pipeline::Problem; + + using XDataType = remove_cvref_t; + using GammaDataType = remove_cvref_t; + using BetaDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using YDataType = remove_cvref_t; + using MeanDataType = remove_cvref_t; + using InvStdDataType = remove_cvref_t; + + static constexpr bool kHasGamma = !std::is_same_v; + static constexpr bool kHasBeta = !std::is_same_v; + static constexpr bool kSaveMeanInvStd = Problem::kSaveMeanInvStd; + static constexpr bool kSaveMean = Problem::kSaveMeanInvStd; + static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd; + + static constexpr index_t Block_M = Problem::BlockShape::Block_M; + static constexpr index_t Block_N = Problem::BlockShape::Block_N; + static constexpr bool kPadM = false; // always no need to pad along M + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kTwoPass = Problem::kTwoPass; + + static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N; + static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; + static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N; static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; @@ -52,400 +72,177 @@ struct Layernorm2dFwd float epsilon; - ck_tile::index_t M; - ck_tile::index_t N; + index_t m; + index_t n; + index_t stride; // row_stride }; + using Hargs = Layernorm2dFwdHostArgs; - CK_TILE_HOST static constexpr Kargs MakeKargs(const void* p_x, - const void* p_gamma, - const void* p_beta, - void* p_y, - void* p_mean, - void* p_invStd, - float epsilon, - ck_tile::index_t M, - ck_tile::index_t N) + CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs) { - return Kargs{p_x, p_gamma, p_beta, p_y, p_mean, p_invStd, epsilon, M, N}; + return Kargs{hargs.p_x, + hargs.p_gamma, + hargs.p_beta, + hargs.p_y, + hargs.p_mean, + hargs.p_invStd, + hargs.epsilon, + hargs.m, + hargs.n, + hargs.stride}; } - CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t M) { return M / kMPerBlock; } - - CK_TILE_HOST static constexpr auto 
BlockSize() { return Problem::BlockShape::kBlockSize; } - - CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution() + CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs) { - using S = typename Problem::BlockShape; - - return make_static_tile_distribution( - tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 1>>, - sequence<1, 2>, - sequence<2, 2>>{}); + return (hargs.m + Block_M - 1) / Block_M; } - CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution() - { - using S = typename Problem::BlockShape; - - return make_static_tile_distribution( - tile_distribution_encoding< - sequence, - tuple>, - tuple, sequence<0, 1>>, - tuple, sequence<1, 1>>, - sequence<1>, - sequence<2>>{}); - } - - CK_TILE_DEVICE static int GetWelfordMaxCount(int N) - { - constexpr ck_tile::index_t kNThreadPerBlock = kNPerBlock / kNPerThread; - - int thread_id_n = get_thread_id() % kNThreadPerBlock; - int max_count = - __builtin_amdgcn_readfirstlane(N < kNPerBlock ? 0 : kNPerThread * (N / kNPerBlock)); - int n_per_block_tail_loop = - __builtin_amdgcn_readfirstlane(N - max_count * kNThreadPerBlock); - - if(n_per_block_tail_loop > 0) - { - int thread_max_n = (thread_id_n + 1) * kNPerThread; - int delta = thread_max_n - n_per_block_tail_loop; - delta = clamp(thread_max_n - n_per_block_tail_loop, 0, kNPerThread); - max_count += kNPerThread - delta; - } - - return max_count; - } + CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } - template - CK_TILE_DEVICE static auto InvSqrt(const DistributedTensor& in_dstr_tensor, - const ComputeDataType epsilon) - { - // TODO: Investigate fast inverse square root algorithm with epsilon - constexpr auto spans = DistributedTensor::get_distributed_spans(); - - DistributedTensor out_dstr_tensor; + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + template <> struct t2s { static constexpr const char * name = "fp8"; }; + template <> struct t2s { static constexpr const char * name = "bf8"; }; + // clang-format on - sweep_tile_span(spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - out_dstr_tensor(i_idx) = type_convert(1.0f) / - ck_tile::sqrt(in_dstr_tensor[i_idx] + epsilon); - }); - - return out_dstr_tensor; - } + // in byte + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); } - template - CK_TILE_DEVICE std::enable_if_t - TwoPassLayernorm2dFwd(XBlockWindow& x_block_window, - GammaBlockWindow& gamma_block_window, - BetaBlockWindow& beta_block_window, - YBlockWindow& y_block_window, - MeanBlockWindow& mean_block_window, - InvStdBlockWindow& inv_std_block_window, - ComputeDataType epsilon, - ck_tile::index_t N) const + CK_TILE_HOST static std::string GetName() { - // TODO - Optimize tail loop to reduce move_tile_window() - index_t num_n_tile_iteration = - __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, kNPerBlock)); - - int welford_max_count = GetWelfordMaxCount(N); - ThreadWelford thread_welford{welford_max_count}; - - using XTensorType = decltype(load_tile(x_block_window)); - auto mean_compute_block_tensor = - thread_welford.template MakeInitialMeanVarDistributedTensor(); - auto var_compute_block_tensor = - thread_welford.template 
MakeInitialMeanVarDistributedTensor(); - - clear_tile(mean_compute_block_tensor); - clear_tile(var_compute_block_tensor); - - for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) - { - const auto x_block_tensor = load_tile(x_block_window); - - thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor); - move_tile_window(x_block_window, {0, kNPerBlock}); - } - - // TODO: support cross warp Welford - WarpMergeWelford{}( - mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_); - - auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon); - - if constexpr(kSaveMean) - store_tile(mean_block_window, cast_tile(mean_compute_block_tensor)); - if constexpr(kSaveInvStd) - store_tile(inv_std_block_window, - cast_tile(inv_std_compute_block_tensor)); - - // reverse read x to reuse cache - ck_tile::index_t stride_to_right_most_window = - N % kNPerBlock == 0 ? N - kNPerBlock : N - N % kNPerBlock; - - move_tile_window(x_block_window, {0, -kNPerBlock}); - move_tile_window(gamma_block_window, {stride_to_right_most_window}); - move_tile_window(beta_block_window, {stride_to_right_most_window}); - move_tile_window(y_block_window, {0, stride_to_right_most_window}); - - // Normalization - for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) - { - const auto x_block_tensor = load_tile(x_block_window); - const auto gamma_block_tensor = load_tile(gamma_block_window); - const auto beta_block_tensor = load_tile(beta_block_window); - - constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans(); - - auto y_block_tensor = - make_static_distributed_tensor(x_block_tensor.get_tile_distribution()); - - sweep_tile_span(x_spans[I1], [&](auto idx1) { - constexpr auto j_idx = make_tuple(idx1); - const auto gamma = type_convert(gamma_block_tensor[j_idx]); - const auto beta = type_convert(beta_block_tensor[j_idx]); - - sweep_tile_span(x_spans[I0], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - constexpr auto i_j_idx = make_tuple(idx0, idx1); - - const auto mean = mean_compute_block_tensor[i_idx]; - const auto inv_std = inv_std_compute_block_tensor[i_idx]; - - const auto x = type_convert(x_block_tensor[i_j_idx]); - auto y = (x - mean) * inv_std * gamma + beta; - - y_block_tensor(i_j_idx) = type_convert(y); - }); - }); - - store_tile(y_block_window, y_block_tensor); - - move_tile_window(x_block_window, {0, -kNPerBlock}); - move_tile_window(gamma_block_window, {-kNPerBlock}); - move_tile_window(beta_block_window, {-kNPerBlock}); - move_tile_window(y_block_window, {0, -kNPerBlock}); - } - } - - template - CK_TILE_DEVICE std::enable_if_t - OnePassLayernorm2dFwd(XBlockWindow& x_block_window, - GammaBlockWindow& gamma_block_window, - BetaBlockWindow& beta_block_window, - YBlockWindow& y_block_window, - MeanBlockWindow& mean_block_window, - InvStdBlockWindow& inv_std_block_window, - ComputeDataType epsilon, - ck_tile::index_t N) const - { - int welford_max_count = GetWelfordMaxCount(N); - ThreadWelford thread_welford{welford_max_count}; - - using XTensorType = decltype(load_tile(x_block_window)); - auto mean_compute_block_tensor = - thread_welford.template MakeInitialMeanVarDistributedTensor(); - auto var_compute_block_tensor = - thread_welford.template MakeInitialMeanVarDistributedTensor(); - - clear_tile(mean_compute_block_tensor); - clear_tile(var_compute_block_tensor); - - const auto x_block_tensor = load_tile(x_block_window); - thread_welford(x_block_tensor, 
mean_compute_block_tensor, var_compute_block_tensor); - // TODO: support cross warp Welford - WarpMergeWelford{}( - mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_); - - auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon); - - if constexpr(kSaveMean) - store_tile(mean_block_window, cast_tile(mean_compute_block_tensor)); - if constexpr(kSaveInvStd) - store_tile(inv_std_block_window, - cast_tile(inv_std_compute_block_tensor)); - - // normalize - const auto gamma_block_tensor = load_tile(gamma_block_window); - const auto beta_block_tensor = load_tile(beta_block_window); - - constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans(); - - auto y_block_tensor = - make_static_distributed_tensor(x_block_tensor.get_tile_distribution()); - - sweep_tile_span(x_spans[I1], [&](auto idx1) { - constexpr auto j_idx = make_tuple(idx1); - const auto gamma = type_convert(gamma_block_tensor[j_idx]); - const auto beta = type_convert(beta_block_tensor[j_idx]); - - sweep_tile_span(x_spans[I0], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - constexpr auto i_j_idx = make_tuple(idx0, idx1); - - const auto mean = mean_compute_block_tensor[i_idx]; - const auto inv_std = inv_std_compute_block_tensor[i_idx]; - - const auto x = type_convert(x_block_tensor[i_j_idx]); - auto y = (x - mean) * inv_std * gamma + beta; - - y_block_tensor(i_j_idx) = type_convert(y); - }); - }); - - store_tile(y_block_window, y_block_tensor); + // clang-format off + using S_ = typename Problem::BlockShape; + auto surfix = [&] () { + std::string n; + if (kPadN) n += "_pn"; + if (kSaveMeanInvStd) n += "_mv"; + if (kTwoPass) n += "_2p"; + return n; }(); + + #define _SS_ std::string + #define _TS_ std::to_string + return _SS_("layernorm2d_fwd_") + _SS_(t2s::name) + "_" + + _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" + + _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" + + _SS_(Pipeline::name) + surfix; + #undef _SS_ + #undef _TS_ + // clang-format on } CK_TILE_DEVICE void operator()(Kargs kargs) const { - const auto x_m_n = [&]() { - const auto x_dram_naive = make_naive_tensor_view( + const auto iM = get_block_id() * Block_M; + + const auto x_window = [&]() { + const auto tmp_ = make_naive_tensor_view( static_cast(kargs.p_x), - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.N, 1), - number{}, + make_tuple(kargs.m, kargs.n), + make_tuple(kargs.stride, 1), + number{}, number<1>{}); - return pad_tensor_view(x_dram_naive, - make_tuple(number{}, number{}), - sequence{}); + // NOTE: we don't do any pad in this kernel for loading, assume that inside kernel will + // check the max count dynamically + const auto tmp2_ = pad_tensor_view( + tmp_, make_tuple(number{}, number{}), sequence{}); + return make_tile_window( + tmp2_, make_tuple(number{}, number{}), {iM, 0}); }(); - const auto gamma_n = [&]() { - const auto gamma_dram_naive = make_naive_tensor_view( + const auto gamma_window = [&]() { + const auto tmp_ = make_naive_tensor_view( static_cast(kargs.p_gamma), - make_tuple(kargs.N), + make_tuple(kargs.n), make_tuple(1), - number{}, + number{}, number<1>{}); - return pad_tensor_view( - gamma_dram_naive, make_tuple(number{}), sequence{}); + const auto tmp2_ = + pad_tensor_view(tmp_, make_tuple(number{}), sequence{}); + + return make_tile_window(tmp2_, make_tuple(number{}), {0}); }(); - const auto beta_n = [&]() { - const auto gamma_dram_naive = 
make_naive_tensor_view( + const auto beta_window = [&]() { + const auto tmp_ = make_naive_tensor_view( static_cast(kargs.p_beta), - make_tuple(kargs.N), + make_tuple(kargs.n), make_tuple(1), - number{}, + number{}, number<1>{}); - return pad_tensor_view( - gamma_dram_naive, make_tuple(number{}), sequence{}); + const auto tmp2_ = + pad_tensor_view(tmp_, make_tuple(number{}), sequence{}); + return make_tile_window(tmp2_, make_tuple(number{}, number{}), {0}); }(); - const auto iM = get_block_id() * kMPerBlock; - - constexpr auto xDstr = MakeXBlockTileDistribution(); - - auto x_block_window = make_tile_window( - x_m_n, make_tuple(number{}, number{}), {iM, 0}, xDstr); - - const auto y_m_n = [&]() { - const auto y_dram_naive = make_naive_tensor_view( + auto y_window = [&]() { + auto tmp_ = make_naive_tensor_view( static_cast(kargs.p_y), - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.N, 1), - number{}, + make_tuple(kargs.m, kargs.n), + make_tuple(kargs.stride, 1), + number{}, number<1>{}); - return pad_tensor_view(y_dram_naive, - make_tuple(number{}, number{}), - sequence{}); + auto tmp2_ = pad_tensor_view( + tmp_, make_tuple(number{}, number{}), sequence{}); + return make_tile_window( + tmp2_, make_tuple(number{}, number{}), {iM, 0}); }(); - auto y_block_window = make_tile_window( - y_m_n, make_tuple(number{}, number{}), {iM, 0}); - - constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution(); - constexpr auto betaDstr = gammaDstr; - - auto gamma_block_window = - make_tile_window(gamma_n, make_tuple(number{}), {0}, gammaDstr); - - auto beta_block_window = make_tile_window( - beta_n, make_tuple(number{}, number{}), {0}, betaDstr); - - auto mean_block_window = [&]() { + auto mean_window = [&]() { if constexpr(kSaveMean) { const auto mean_m = [&]() { const auto mean_dram_naive = make_naive_tensor_view_packed( static_cast(kargs.p_mean), - make_tuple(kargs.M), + make_tuple(kargs.m), number<1>{}); return pad_tensor_view( - mean_dram_naive, make_tuple(number{}), sequence{}); + mean_dram_naive, make_tuple(number{}), sequence{}); }(); - - return make_tile_window(mean_m, make_tuple(number{}), {iM}); + return make_tile_window(mean_m, make_tuple(number{}), {iM}); } else - return make_null_tile_window(make_tuple(number{})); + return make_null_tile_window(make_tuple(number{})); }(); - auto inv_std_block_window = [&]() { + auto inv_std_window = [&]() { if constexpr(kSaveInvStd) { const auto inv_std_m = [&]() { const auto inv_std_dram_naive = make_naive_tensor_view_packed( static_cast(kargs.p_invStd), - make_tuple(kargs.M), + make_tuple(kargs.m), number<1>{}); return pad_tensor_view( - inv_std_dram_naive, make_tuple(number{}), sequence{}); + inv_std_dram_naive, make_tuple(number{}), sequence{}); }(); - - return make_tile_window(inv_std_m, make_tuple(number{}), {iM}); + return make_tile_window(inv_std_m, make_tuple(number{}), {iM}); } else - return make_null_tile_window(make_tuple(number{})); + return make_null_tile_window(make_tuple(number{})); }(); - if(kargs.N <= kNPerBlock) - OnePassLayernorm2dFwd(x_block_window, - gamma_block_window, - beta_block_window, - y_block_window, - mean_block_window, - inv_std_block_window, - static_cast(kargs.epsilon), - kargs.N); - else - TwoPassLayernorm2dFwd(x_block_window, - gamma_block_window, - beta_block_window, - y_block_window, - mean_block_window, - inv_std_block_window, - static_cast(kargs.epsilon), - kargs.N); + __shared__ char smem[GetSmemSize()]; + + Pipeline{}(x_window, + gamma_window, + beta_window, + y_window, + mean_window, + inv_std_window, + 
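+                   // epsilon arrives in ComputeDataType; kargs.n is the runtime row
+                   // length, and smem backs the cross-warp Welford reduction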
static_cast(kargs.epsilon), + kargs.n, + smem); } }; diff --git a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp new file mode 100644 index 0000000000..e4b60331eb --- /dev/null +++ b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { +/* +// clang-format off + +4-level descriptor: BlockTile-> WarpPerBlock-> WarpTile-> Vector + + Block_N (Warp_N * WarpPerBlock_N * Repeat_N ) + +<----------------------< Repeat_N(2)>--------------------->+ + | | + +<-- -->+ + Warp_N + +--------------+--------------+--------------+--------------+----+----------------+ + Warp_M | wrap_0 | wrap_1 | | ^ ^ + +--------------+--------------+ | | + | wrap_2 | wrap_3 | | v + +--------------+--------------+--------------+--------------+----+ Block_M + | | | + + + | + | | | v + +--------------+--------------+--------------+--------------+ + + + each Warp-tile (e.g 16 thrd per row) + + Vector_N (contiguous pixels each thrd holds along N, or vector size) + +-----------+-----------+-----------+-----------+-----------+ + | thrd_0 | thrd_1 | thrd_2 | thrd_3 | ... Vector_M + +-----------+-----------+-----------+-----------+-----------+ + | thrd_16 | thrd_17 | thrd_18 | thrd_19 | ... + +-----------+-----------+-----------+-----------+-----------+ +// clang-format on +*/ +template + typename WarpPerBlock_, // num warps along seq + typename WarpTile_, // warp size, seq + typename Vector_, // contiguous pixels(vector size) along seq + index_t BlockSize_ = + warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})> +struct Layernorm2dShape +{ + // block size + static constexpr index_t Block_M = BlockTile_::at(number<0>{}); + static constexpr index_t Block_N = BlockTile_::at(number<1>{}); + + // num warps along seq, within each block + static constexpr index_t WarpPerBlock_M = WarpPerBlock_::at(number<0>{}); + static constexpr index_t WarpPerBlock_N = WarpPerBlock_::at(number<1>{}); + + // warp size + static constexpr index_t Warp_M = WarpTile_::at(number<0>{}); + static constexpr index_t Warp_N = WarpTile_::at(number<1>{}); + + static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0); + static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0); + // repeat of each thread along seq + static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M); + static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N); + + // vector size along seq + static constexpr index_t Vector_M = Vector_::at(number<0>{}); + static constexpr index_t Vector_N = Vector_::at(number<1>{}); + + static_assert(Warp_M % Vector_M == 0); + static_assert(Warp_N % Vector_N == 0); + // num of threads along seq, within each warp + static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; + static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N; + + static constexpr index_t BlockSize = BlockSize_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp b/include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp deleted file mode 100644 index 707a38f621..0000000000 --- a/include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp +++ /dev/null @@ -1,34 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro 
Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core/utility/type_traits.hpp" - -namespace ck_tile { - -template -struct BlockLayernorm2dFwdProblem -{ - using XDataType = remove_cvref_t; - using GammaDataType = remove_cvref_t; - using BetaDataType = remove_cvref_t; - using ComputeDataType = remove_cvref_t; - using YDataType = remove_cvref_t; - using MeanDataType = remove_cvref_t; - using InvStdDataType = remove_cvref_t; - using BlockShape = remove_cvref_t; - static constexpr bool kPadM = kPadM_; - static constexpr bool kPadN = kPadN_; -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp new file mode 100644 index 0000000000..6661cddf43 --- /dev/null +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/welford/block/block_welford_problem.hpp" +#include "ck_tile/ops/welford/block/block_welford.hpp" + +namespace ck_tile { + +struct Layernorm2dFwdPipelineDefaultPolicy +{ + template + CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution() + { + using S = typename Problem::BlockShape; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 2>>, + sequence<1, 1, 2, 2>, + sequence<0, 3, 0, 3>>{}); + } + template + CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution() + { + using S = typename Problem::BlockShape; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, + tuple>, + tuple, sequence<0, 1>>, + tuple, sequence<1, 2>>, + sequence<1, 1>, + sequence<0, 3>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelford() + { + using P_ = BlockWelfordProblem; + + return BlockWelford{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordSync() + { + using P_ = BlockWelfordProblem; + + return BlockWelfordSync{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordCrossWarpSync() + { + using P_ = BlockWelfordProblem; + + return BlockWelfordCrossWarpSync{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + if constexpr(Problem::kNeedCrossWarpSync) + { + using P_ = BlockWelfordProblem; + + using block_welford = BlockWelford; + using x_block_tile = + decltype(make_static_distributed_tensor( + MakeXBlockTileDistribution())); + using mean_var_block_tile = + decltype(block_welford::template MakeMeanVarBlockTile()); + + return GetBlockWelfordCrossWarpSync() + .template GetSmemSize(); + } + else + { + return 1; // zero size arrays are an extension + } + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp new file mode 100644 index 0000000000..d73bcb29e4 --- /dev/null +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp @@ -0,0 +1,119 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
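+// One-pass pipeline: used when a whole row fits in a single block tile, so X is
+// loaded exactly once; Welford runs thread -> cross-lane -> cross-warp and the
+// normalization is computed from registers. The two-pass variant further below
+// re-reads X for the normalization sweep when N exceeds the block tile.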
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp" +#include +#include + +namespace ck_tile { + +template +struct Layernorm2dFwdPipelineOnePass +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + + using XDataType = ck_tile::remove_cvref_t; + using GammaDataType = ck_tile::remove_cvref_t; + using BetaDataType = ck_tile::remove_cvref_t; + using ComputeDataType = ck_tile::remove_cvref_t; + using YDataType = ck_tile::remove_cvref_t; + using MeanDataType = ck_tile::remove_cvref_t; + using InvStdDataType = ck_tile::remove_cvref_t; + + static constexpr bool kHasGamma = !std::is_same_v; + static constexpr bool kHasBeta = !std::is_same_v; + static constexpr bool kSaveMean = Problem::kSaveMeanInvStd; + static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd; + + static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; + static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM + static constexpr bool kPadN = Problem::kPadN; + + static constexpr const char* name = []() { + if constexpr(kNeedCrossWarpSync) + return "bpr"; // block per row + else + return "wpr"; // warp per row + }(); + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_DEVICE auto operator()(const XWindow& x_window_, + const GammaWindow& gamma_window_, + const BetaWindow& beta_window_, + YWindow& y_window, + MeanWindow& mean_window, + InvStdWindow& inv_std_window, + ComputeDataType epsilon, + ck_tile::index_t row_size, + void* smem) const + { + const auto x_window = + make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution()); + const auto gamma_window = make_tile_window( + gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution()); + const auto beta_window = make_tile_window( + beta_window_, Policy::template MakeGammaBetaBlockTileDistribution()); + + const auto x = load_tile(x_window); + int cur_count = 0; + int max_count = + block_tile_welford_calculate_max_count(row_size); + auto block_welford = Policy::template GetBlockWelford(); + auto block_welford_sync = Policy::template GetBlockWelfordSync(); + auto block_welford_cross_warp_sync = + Policy::template GetBlockWelfordCrossWarpSync(); + + // load gamma/beta (TODO: support no gamma/beta?) 
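+        // (The loads below are distributed loads: each thread only receives
+        // the gamma/beta pixels its footprint in the tile distribution needs.)
+        //
+        // The Welford reduction that follows is staged; as a summary of the
+        // helpers in block_welford.hpp: per element, welford_update() does
+        //   delta = x - mean; mean += delta / count; var += delta * (x - mean);
+        // then block_welford_sync merges the partial (mean, var, count)
+        // triples across lanes with warp shuffles, block_welford_cross_warp_sync
+        // merges across warps through the smem scratch, and the 1/count
+        // scaling of var happens last in block_tile_welford_post_scale_var().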
+ const auto gamma = load_tile(gamma_window); + const auto beta = load_tile(beta_window); + + // compute welford each-thread->cross-lane->cross-warp + auto [mean, var] = block_welford(x, cur_count, max_count); + block_welford_sync(mean, var, cur_count); + block_welford_cross_warp_sync(mean, var, cur_count, smem); + block_tile_welford_post_scale_var(var, cur_count); + + // compute inv-std + auto inv_std = tile_elementwise_in( + [&](const auto& v_) { + return type_convert(1.0f) / (sqrt(v_) + epsilon); + }, + var); + + if constexpr(kSaveMean) + store_tile(mean_window, cast_tile(mean)); + if constexpr(kSaveInvStd) + store_tile(inv_std_window, cast_tile(inv_std)); + + // layernorm computation + auto y = make_static_distributed_tensor(x.get_tile_distribution()); + sweep_tile(y, [&, mean_ = mean](auto idx) { + constexpr auto i_idx = make_tuple(idx[number<0>{}]); + constexpr auto j_idx = make_tuple(idx[number<1>{}]); + + const auto gamma_ = type_convert(gamma[j_idx]); + const auto beta_ = type_convert(beta[j_idx]); + + const auto x_ = type_convert(x[idx]); + auto y_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_; + + y(idx) = type_convert(y_); + }); + store_tile(y_window, y); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp new file mode 100644 index 0000000000..8e9f8e81e4 --- /dev/null +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +struct Layernorm2dFwdPipelineProblem +{ + using XDataType = remove_cvref_t; + using GammaDataType = remove_cvref_t; + using BetaDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using YDataType = remove_cvref_t; + using MeanDataType = remove_cvref_t; + using InvStdDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; + + static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1; + static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1; + + static constexpr bool kPadN = kPadN_; + static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_; + static constexpr bool kTwoPass = kTwoPass_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp new file mode 100644 index 0000000000..dcbfc87dab --- /dev/null +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp @@ -0,0 +1,160 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
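+
+// Overview: unlike the one-pass pipeline, this variant streams the row twice.
+// Pass 1 walks the row tile-by-tile and only accumulates Welford mean/var;
+// pass 2 re-reads the tiles from right to left, so the last tile, which is
+// still hot in cache, is consumed first, and applies the normalization.
+// A worked example with assumed sizes Block_N = 256 and row_size = 1000:
+// num_n_tile_iteration = ceil(1000 / 256) = 4 and
+// stride_to_right_most_window = 1000 - 1000 % 256 = 768, so pass 2 visits
+// the windows at column offsets 768, 512, 256, 0.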
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp" +#include +#include + +namespace ck_tile { + +template +struct Layernorm2dFwdPipelineTwoPass +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + + using XDataType = ck_tile::remove_cvref_t; + using GammaDataType = ck_tile::remove_cvref_t; + using BetaDataType = ck_tile::remove_cvref_t; + using ComputeDataType = ck_tile::remove_cvref_t; + using YDataType = ck_tile::remove_cvref_t; + using MeanDataType = ck_tile::remove_cvref_t; + using InvStdDataType = ck_tile::remove_cvref_t; + + static constexpr bool kHasGamma = !std::is_same_v; + static constexpr bool kHasBeta = !std::is_same_v; + static constexpr bool kSaveMean = Problem::kSaveMeanInvStd; + static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd; + + static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; + static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM + static constexpr bool kPadN = Problem::kPadN; + + static constexpr const char* name = []() { + if constexpr(kNeedCrossWarpSync) + return "bpr"; // block per row + else + return "wpr"; // warp per row + }(); + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_DEVICE auto operator()(const XWindow& x_window_, + const GammaWindow& gamma_window_, + const BetaWindow& beta_window_, + YWindow& y_window, + MeanWindow& mean_window, + InvStdWindow& inv_std_window, + ComputeDataType epsilon, + ck_tile::index_t row_size, + void* smem) const + { + auto x_window = + make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution()); + auto gamma_window = make_tile_window( + gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution()); + auto beta_window = make_tile_window( + beta_window_, Policy::template MakeGammaBetaBlockTileDistribution()); + + // Problem::BlockShape + static constexpr index_t Block_N = Problem::BlockShape::Block_N; + index_t num_n_tile_iteration = + __builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N)); + + // total number of count assume current iter have no pad(only last iter has pad) + constexpr index_t count_per_iter = + Problem::BlockShape::Repeat_N * Problem::BlockShape::Vector_N; + const index_t last_iter_n = row_size - (num_n_tile_iteration - 1) * Block_N; + + int cur_count = 0; + int max_count = + (num_n_tile_iteration - 1) * count_per_iter + + block_tile_welford_calculate_max_count(last_iter_n); + auto block_welford = Policy::template GetBlockWelford(); + auto block_welford_sync = Policy::template GetBlockWelfordSync(); + auto block_welford_cross_warp_sync = + Policy::template GetBlockWelfordCrossWarpSync(); + + using XTensorType = decltype(load_tile(x_window)); + auto mean = block_welford.template MakeMeanVarBlockTile(); + auto var = block_welford.template MakeMeanVarBlockTile(); + + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + const auto x = load_tile(x_window); + block_welford(x, mean, var, cur_count, max_count); + move_tile_window(x_window, {0, Block_N}); + } + + block_welford_sync(mean, var, cur_count); + block_welford_cross_warp_sync(mean, var, cur_count, smem); + block_tile_welford_post_scale_var(var, cur_count); + + // compute inv-std + auto inv_std = tile_elementwise_in( + [&](const auto& v_) { + return type_convert(1.0f) / (sqrt(v_) + epsilon); + }, + var); + + if 
constexpr(kSaveMean) + store_tile(mean_window, cast_tile(mean)); + if constexpr(kSaveInvStd) + store_tile(inv_std_window, cast_tile(inv_std)); + + // reverse read x to reuse cache + ck_tile::index_t stride_to_right_most_window = + row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N; + + // x_window.foo(); + // gamma_window.foo(); + move_tile_window(x_window, {0, -Block_N}); + move_tile_window(gamma_window, {stride_to_right_most_window}); + move_tile_window(beta_window, {stride_to_right_most_window}); + move_tile_window(y_window, {0, stride_to_right_most_window}); + + // layernorm computation + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + const auto x = load_tile(x_window); + // load gamma/beta (TODO: support no gamma/beta?) + const auto gamma = load_tile(gamma_window); + const auto beta = load_tile(beta_window); + + auto y = make_static_distributed_tensor(x.get_tile_distribution()); + + sweep_tile(y, [&, mean_ = mean](auto idx) { + constexpr auto i_idx = make_tuple(idx[number<0>{}]); + constexpr auto j_idx = make_tuple(idx[number<1>{}]); + + const auto gamma_ = type_convert(gamma[j_idx]); + const auto beta_ = type_convert(beta[j_idx]); + + const auto x_ = type_convert(x[idx]); + auto y_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_; + + y(idx) = type_convert(y_); + }); + + store_tile(y_window, y); + + move_tile_window(x_window, {0, -Block_N}); + move_tile_window(gamma_window, {-Block_N}); + move_tile_window(beta_window, {-Block_N}); + move_tile_window(y_window, {0, -Block_N}); + } + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp b/include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp deleted file mode 100644 index 1ff541d844..0000000000 --- a/include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp +++ /dev/null @@ -1,35 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" - -namespace ck_tile { -template // Sequence<... -struct TileLayernorm2dShape -{ - static constexpr index_t kMPerThread = ThreadTile::at(number<0>{}); - static constexpr index_t kNPerThread = ThreadTile::at(number<1>{}); - - static constexpr index_t kMPerWarp = WarpTile::at(number<0>{}); - static constexpr index_t kNPerWarp = WarpTile::at(number<1>{}); - - static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread; - static constexpr index_t kNThreadPerWarp = kNPerWarp / kNPerThread; - - static constexpr index_t kMPerBlock = BlockTile::at(number<0>{}); - static constexpr index_t kNPerBlock = BlockTile::at(number<1>{}); - - static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp; - static constexpr index_t kNWarpPerBlock = kNPerBlock / kNPerWarp; - - // TODO - kNNumWarps can only be 1 if we don't support cross warp welford - static_assert(kNWarpPerBlock == 1); - - static constexpr index_t kBlockSize = warpSize * kMWarpPerBlock * kNWarpPerBlock; -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/reduce/block/block_reduce.hpp b/include/ck_tile/ops/reduce/block/block_reduce.hpp index 682d60d872..63c364331d 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved.
 
 #pragma once
diff --git a/include/ck_tile/ops/welford.hpp b/include/ck_tile/ops/welford.hpp
index dffaad7501..ebf9406837 100644
--- a/include/ck_tile/ops/welford.hpp
+++ b/include/ck_tile/ops/welford.hpp
@@ -3,6 +3,7 @@
 
 #pragma once
 
+#include "ck_tile/ops/welford/block/block_welford.hpp"
+#include "ck_tile/ops/welford/block/block_welford_problem.hpp"
 #include "ck_tile/ops/welford/thread/thread_welford.hpp"
-#include "ck_tile/ops/welford/warp/warp_welford.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
diff --git a/include/ck_tile/ops/welford/block/block_welford.hpp b/include/ck_tile/ops/welford/block/block_welford.hpp
new file mode 100644
index 0000000000..55d55402d8
--- /dev/null
+++ b/include/ck_tile/ops/welford/block/block_welford.hpp
@@ -0,0 +1,362 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/welford/thread/thread_welford.hpp"
+
+namespace ck_tile {
+
+template <typename Problem_>
+struct BlockWelford
+{
+    using Problem         = remove_cvref_t<Problem_>;
+    using XDataType       = typename Problem::XDataType;
+    using ComputeDataType = typename Problem::ComputeDataType;
+
+    CK_TILE_DEVICE constexpr BlockWelford() {}
+
+    // [CAUTION] - max_count_ deals with the padding problem.
+    // max_count_ depends on the caller; e.g. naive and split-N welford have
+    // different calculations of max_count_
+    // -> use block_tile_welford_calculate_max_count to compute it
+    template <typename XDistributedTensor_,
+              typename MeanDistributedTensor_,
+              typename VarDistributedTensor_>
+    CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
+                                   MeanDistributedTensor_& mean_tensor,
+                                   VarDistributedTensor_& var_tensor,
+                                   int& cur_count_, // -> prefer to init as zero
+                                   const int& max_count_)
+    {
+        constexpr auto I0 = number<0>{};
+        constexpr auto I1 = number<1>{};
+
+        constexpr auto spans = XDistributedTensor_::get_distributed_spans();
+
+        sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
+            if(cur_count_ < max_count_)
+            {
+                ++cur_count_;
+
+                sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
+                    constexpr auto in_dstr_idx  = make_tuple(dstr_idx_i0, dstr_idx_i1);
+                    constexpr auto out_dstr_idx = make_tuple(dstr_idx_i0);
+
+                    auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
+
+                    welford_update(
+                        mean_tensor(out_dstr_idx), var_tensor(out_dstr_idx), x, cur_count_);
+                });
+            }
+        });
+    }
+
+    template <typename XDistributedTensor_>
+    CK_TILE_DEVICE static auto MakeMeanVarBlockTile()
+    {
+        static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>,
+                      "wrong!");
+
+        constexpr auto reduce_dims = sequence<1>{};
+
+        constexpr auto dstr =
+            make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
+                XDistributedTensor_::get_tile_distribution()
+                    .get_static_tile_distribution_encoding(),
+                reduce_dims));
+
+        auto tensor = make_static_distributed_tensor<ComputeDataType>(dstr);
+
+        return tensor;
+    }
+
+    template <typename XDistributedTensor_>
+    CK_TILE_DEVICE auto
+    operator()(const XDistributedTensor_& x_tensor, int& cur_count_, const int& max_count_)
+    {
+        auto mean_tensor = MakeMeanVarBlockTile<XDistributedTensor_>();
+        auto var_tensor  = MakeMeanVarBlockTile<XDistributedTensor_>();
+        clear_tile(mean_tensor);
+        clear_tile(var_tensor);
+
+        (*this)(x_tensor, mean_tensor, var_tensor, cur_count_, max_count_);
+
+        return ck_tile::make_tuple(mean_tensor, var_tensor);
+    }
+};
+
+template <typename Problem_>
+struct BlockWelfordSync
+{
+    using Problem = remove_cvref_t<Problem_>;
+
+    template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
+    CK_TILE_DEVICE void
+    operator()(MeanDistributedTensor_& mean_tensor, VarDistributedTensor_& var_tensor, int& count)
+    {
+        using Dstr       = typename MeanDistributedTensor_::StaticTileDistribution;
+        using DstrEncode = typename Dstr::DstrEncode;
+        using DstrEncodeDetail =
typename DstrEncode::detail; + + static_assert(std::is_same_v, + "wrong!"); + + constexpr index_t NDimP = Dstr::get_num_of_dimension_p(); + constexpr index_t NDimR = Dstr::get_num_of_dimension_r(); + + constexpr index_t idim_p_lane = NDimP - 1; + + // const auto ps_idx = make_array(get_warp_id(), get_lane_id()); + // const auto rs_idx = + // mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx); + + constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size(); + static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size()); + + const int original_count = count; + + // loop over thread data + static_for<0, thread_buf_size, 1>{}([&](auto i) { + auto v_local_mean = mean_tensor.get_thread_buffer()[i]; + auto v_local_var = var_tensor.get_thread_buffer()[i]; + auto v_local_count = original_count; + + // cross-lane reduce for replication + // only reduce on R dimension correspond to lane + // (lane id maps to this R dimension) + static_for<0, NDimR, 1>{}([&](auto idim_r) { + // FIXME: nasty to use does_p_own_r_ + if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r]) + { + constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r]; + + constexpr index_t lid_over_rid_derivative = + DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r]; + + static_assert(is_power_of_two_integer(r_length), + "wrong! only support power of 2 reduction"); + + constexpr index_t nstage = integer_log2_floor(r_length); + + // reduction sweep forward + static_for<0, nstage, 1>{}([&](auto istage) { + // xor + index_t src_lane = + (__lane_id()) ^ + (number{}.value); + + // pull data from remote lane + const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane); + const auto v_remote_var = warp_shuffle(v_local_var, src_lane); + const auto v_remote_count = warp_shuffle(v_local_count, src_lane); + + // welford merge + welford_merge(v_local_mean, + v_local_var, + v_local_count, + v_remote_mean, + v_remote_var, + v_remote_count); + }); + } + }); + + mean_tensor.get_thread_buffer()(i) = v_local_mean; + var_tensor.get_thread_buffer()(i) = v_local_var; + + count = v_local_count; + }); + } +}; + +template +struct BlockWelfordCrossWarpSync +{ + using Problem = remove_cvref_t; + using BlockShape = typename Problem::BlockShape; + + template + CK_TILE_DEVICE static constexpr index_t GetReduceWarps() + { + constexpr index_t num_reduce_warps = [&]() { + using Dstr = typename MeanDistributedTensor_::StaticTileDistribution; + using DstrEncode = typename Dstr::DstrEncode; + using DstrEncodeDetail = typename DstrEncode::detail; + + constexpr index_t NDimR = Dstr::get_num_of_dimension_r(); + + constexpr index_t idim_p_warp = 0; + + index_t len_ = 1; + static_for<0, NDimR, 1>{}([&](auto idim_r) { + if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r]) + { + constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r]; + len_ *= r_length; + } + }); + return len_; + }(); + return num_reduce_warps; + } + + // return in byte + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + // constexpr auto num_reduce_warps = GetReduceWarps(); + + // data need to exchange is very small, we just pack mean+var+count -> 4dword + constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size(); + + // we need to store all data from every wave into smem + // e.g. 
2x2 reduce along N
+        //   -------------> reduce N
+        //   | w0 | w1 |   ___>   | w01 |
+        //   | w2 | w3 |          | w23 |
+        //
+        //   -> store data from every wave into LDS
+        //
+        //   e.g. 1x4 reduce along N
+        //   -------------> reduce N
+        //   | w0 | w1 | w2 | w3 |  ----->  | w0123 |
+        //
+        //   -> also store data from every wave into LDS
+        constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
+        return num_warps * 4 * thread_buf_size * sizeof(float);
+    }
+
+    template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
+    CK_TILE_DEVICE void operator()(MeanDistributedTensor_& mean_tensor,
+                                   VarDistributedTensor_& var_tensor,
+                                   int& count,
+                                   void* smem)
+    {
+        using DataType = typename MeanDistributedTensor_::DataType;
+        using Dstr     = typename MeanDistributedTensor_::StaticTileDistribution;
+        // using DstrEncode = typename Dstr::DstrEncode;
+        // using DstrEncodeDetail = typename DstrEncode::detail;
+
+        static_assert(std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
+                      "wrong!");
+
+        constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
+        static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
+
+        // Note: we always pack everything into fp32x4
+        fp32x4_t* smem_ptr              = reinterpret_cast<fp32x4_t*>(smem);
+        const index_t lane_id           = get_lane_id();
+        const index_t warp_id           = get_warp_id();
+        constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();
+        constexpr index_t num_warps     = BlockShape::BlockSize / warpSize;
+        const index_t smem_offset       = warp_id;
+
+        // skip if there is nothing to do
+        if constexpr(num_reduce_warps == 1)
+            return;
+
+        // only lane-0 within each warp stores into smem
+        if(lane_id == 0)
+        {
+            static_for<0, thread_buf_size, 1>{}([&](auto i) {
+                fp32x4_t local_scratch_;
+                local_scratch_[0] = bit_cast<float>(mean_tensor.get_thread_buffer()[i]);
+                local_scratch_[1] = bit_cast<float>(var_tensor.get_thread_buffer()[i]);
+                local_scratch_[2] = bit_cast<float>(count);
+
+                smem_ptr[smem_offset + i * num_warps] = local_scratch_;
+            });
+        }
+        block_sync_lds();
+
+        // load from smem; here we let every thread do the compute :)
+        index_t local_warp_id = warp_id / num_reduce_warps;
+        index_t local_smem_os = local_warp_id * num_reduce_warps;
+        fp32x4_t all_scratch[thread_buf_size * num_reduce_warps];
+        static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
+            static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
+                all_scratch[i_0 * num_reduce_warps + i_1] =
+                    smem_ptr[i_0 * num_warps + local_smem_os + i_1];
+            });
+        });
+        block_sync_lds(); // TODO: we don't need sync here
+
+        // const int original_count = count;
+
+        static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
+            // TODO: use descriptor for this
+            auto v_local       = all_scratch[i_0 * num_reduce_warps];
+            auto v_local_mean  = bit_cast<DataType>(v_local[0]);
+            auto v_local_var   = bit_cast<DataType>(v_local[1]);
+            auto v_local_count = bit_cast<int>(v_local[2]);
+
+            // further reduce mean/var
+            static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
+                constexpr auto i_1        = number<i_1_n1 + 1>{};
+                const fp32x4_t v_remote   = all_scratch[i_0 * num_reduce_warps + i_1];
+                const auto v_remote_mean  = bit_cast<DataType>(v_remote[0]);
+                const auto v_remote_var   = bit_cast<DataType>(v_remote[1]);
+                const auto v_remote_count = bit_cast<int>(v_remote[2]);
+
+                welford_merge(v_local_mean,
+                              v_local_var,
+                              v_local_count,
+                              v_remote_mean,
+                              v_remote_var,
+                              v_remote_count);
+            });
+
+            mean_tensor.get_thread_buffer()(i_0) = v_local_mean;
+            var_tensor.get_thread_buffer()(i_0)  = v_local_var;
+
+            count = v_local_count;
+        });
+    }
+};
+
+// compute the max count for a last-dim reduce
+// everything may have vector/repeat, so the max count could be uneven
+// TODO: specify which dim to compute and properly set up the problem
+// TODO: BlockShape: we reuse layernorm_fwd_shape :)
+template <typename BlockShape>
+CK_TILE_DEVICE constexpr index_t block_tile_welford_calculate_max_count(int row_size)
+{
+#if 0
+    using S = BlockShape;
+    index_t LastloopN = row_size % S::Block_N == 0 ? S::Block_N : row_size % S::Block_N;
+    constexpr index_t NThread = S::WarpPerBlock_N * S::ThreadPerWarp_N;
+    index_t iNLane = get_thread_id() % NThread;
+    index_t iN0 = LastloopN / (S::Vector_N * S::ThreadPerWarp_N);
+    index_t iN1 = (LastloopN % (S::Vector_N * S::ThreadPerWarp_N)) / S::Vector_N;
+    index_t N2 = (LastloopN % (S::Vector_N * S::ThreadPerWarp_N)) % S::Vector_N;
+    index_t iN3 = iNLane < iN1 ? S::Vector_N : iNLane == iN1 ? N2 : 0;
+    return iN0 * S::Vector_N + iN3;
+#endif
+    using S_ = BlockShape;
+    constexpr index_t ThreadsPerBlock_N = S_::WarpPerBlock_N * S_::ThreadPerWarp_N;
+
+    // TODO: we always assume vectorized access; row_size needs to be evenly
+    // divisible by Vector_N
+    const index_t element_per_row = row_size / S_::Vector_N;
+    index_t lane_id_n = get_thread_id() % ThreadsPerBlock_N;
+
+    index_t cnt = 0;
+    // TODO: Repeat_N cannot be too large, otherwise this loop gets expensive
+    static_for<0, S_::Repeat_N, 1>{}([&](auto) {
+        index_t _a = lane_id_n < element_per_row ? 1 : 0;
+        cnt += _a;
+        lane_id_n += ThreadsPerBlock_N;
+    });
+    return cnt * S_::Vector_N;
+}
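+
+// A worked example for block_tile_welford_calculate_max_count (the shape
+// numbers are assumptions): with WarpPerBlock_N = 1, ThreadPerWarp_N = 16,
+// Vector_N = 4, Repeat_N = 2 and row_size = 72, element_per_row = 72 / 4 = 18.
+// In repeat 0 all lane_id_n (0..15) are < 18; in repeat 1 (lane_id_n now
+// 16..31) only lanes 0 and 1 still are. So lanes 0..1 get max_count = 2 * 4 = 8
+// and lanes 2..15 get 1 * 4 = 4; summed over the 16 lanes that is
+// 2 * 8 + 14 * 4 = 72, exactly the row.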
+
+// Note: this function must be called after all accumulation is finished
+template <typename VarDistributedTensor_>
+CK_TILE_DEVICE constexpr void block_tile_welford_post_scale_var(VarDistributedTensor_& var_tensor,
+                                                                int count)
+{
+    using DataType = typename VarDistributedTensor_::DataType;
+    tile_elementwise_inout([&count](auto& x) { x = x / type_convert<DataType>(count); },
+                           var_tensor);
+}
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/welford/block/block_welford_problem.hpp b/include/ck_tile/ops/welford/block/block_welford_problem.hpp
new file mode 100644
index 0000000000..dcae1ef2ee
--- /dev/null
+++ b/include/ck_tile/ops/welford/block/block_welford_problem.hpp
@@ -0,0 +1,18 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+
+namespace ck_tile {
+
+template <typename XDataType_, typename ComputeDataType_, typename BlockShape_>
+struct BlockWelfordProblem
+{
+    using XDataType       = remove_cvref_t<XDataType_>;
+    using ComputeDataType = remove_cvref_t<ComputeDataType_>;
+    using BlockShape      = remove_cvref_t<BlockShape_>;
+};
+
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/welford/thread/thread_welford.hpp b/include/ck_tile/ops/welford/thread/thread_welford.hpp
index 2ca9a23657..4c61cdcf4b 100644
--- a/include/ck_tile/ops/welford/thread/thread_welford.hpp
+++ b/include/ck_tile/ops/welford/thread/thread_welford.hpp
@@ -7,95 +7,30 @@
 namespace ck_tile {
 
-template <typename XDataType_, typename ComputeDataType_>
-struct ThreadWelford
+template <typename T>
+CK_TILE_DEVICE void welford_update(T& mean, T& var, T x, int count)
 {
-    using XDataType       = remove_cvref_t<XDataType_>;
-    using ComputeDataType = remove_cvref_t<ComputeDataType_>;
-
-    template <typename T>
-    CK_TILE_DEVICE void Update(T& mean, T& var, T x)
-    {
-        if(ck_tile::isnan(x))
-        {
-            mean = x;
-            var  = x;
-        }
-        else
-        {
-            T delta = x - mean;
-            mean += delta / cur_count_;
-            T delta2 = x - mean;
-            var += delta * delta2;
-        }
-    }
-
-    // [CAUSION] - max_count_ is to deal with the padding problem
-    // max_count_ is depend on caller, eg: naive and splitN welford will have different
-    // calculation of max_count_
-    CK_TILE_DEVICE constexpr ThreadWelford(int max_count) : cur_count_(0), max_count_(max_count) {}
-
-    template <typename XDistributedTensor_,
-              typename MeanDistributedTensor_,
-              typename VarDistributedTensor_>
-    CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
-                                   MeanDistributedTensor_& mean_tensor,
-                                   VarDistributedTensor_& var_tensor)
-    {
-        constexpr auto I0 = number<0>{};
-        constexpr auto I1 = number<1>{};
-
-        constexpr auto spans = XDistributedTensor_::get_distributed_spans();
-
-        sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
-            if(cur_count_ < max_count_)
-            {
-                ++cur_count_;
-
-                sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
-                    constexpr auto in_dstr_idx  = make_tuple(dstr_idx_i0, dstr_idx_i1);
-                    constexpr auto out_dstr_idx = make_tuple(dstr_idx_i0);
-
-                    auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
-
-                    Update(mean_tensor(out_dstr_idx), var_tensor(out_dstr_idx), x);
-                });
-            }
-        });
-    }
-
-    template <typename XDistributedTensor_>
-    CK_TILE_DEVICE static auto MakeInitialMeanVarDistributedTensor()
-    {
-        static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>, "wrong!");
-
-        constexpr auto reduce_dims = sequence<1>{};
-
-        constexpr auto dstr =
-            make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
-                XDistributedTensor_::get_tile_distribution()
-                    .get_static_tile_distribution_encoding(),
-                reduce_dims));
-
-        auto tensor = make_static_distributed_tensor<ComputeDataType>(dstr);
-        clear_tile(tensor);
-
-        return tensor;
-    }
-
-    template <typename XDistributedTensor_>
-    CK_TILE_DEVICE auto operator()(const XDistributedTensor_& x_tensor)
-    {
-        auto mean_tensor = MakeInitialMeanVarDistributedTensor<XDistributedTensor_>();
-        auto var_tensor  = MakeInitialMeanVarDistributedTensor<XDistributedTensor_>();
-
-        (*this)(x_tensor, mean_tensor, var_tensor);
-
-        return ck_tile::make_tuple(mean_tensor, var_tensor);
-    }
-
-    int cur_count_;
-    int max_count_;
-};
+    // TODO: check nan? maybe no
+    T delta = x - mean;
+    mean += delta / count;
+    T delta2 = x - mean;
+    var += delta * delta2;
+}
+
+template <typename T>
+CK_TILE_DEVICE static void
+welford_merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b)
+{
+    int count  = count_a + count_b;
+    T count_   = type_convert<T>(count);
+    T count_a_ = type_convert<T>(count_a);
+    T count_b_ = type_convert<T>(count_b);
+    T count_b_over_count = count == 0 ? type_convert<T>(0) : count_b_ / count_;
+
+    T delta = mean_b - mean_a;
+    mean_a += delta * count_b_over_count;
+    var_a += var_b + delta * delta * count_a_ * count_b_over_count;
+    count_a = count;
+}
 } // namespace ck_tile
diff --git a/include/ck_tile/ops/welford/warp/warp_welford.hpp b/include/ck_tile/ops/welford/warp/warp_welford.hpp
deleted file mode 100644
index 687b61f430..0000000000
--- a/include/ck_tile/ops/welford/warp/warp_welford.hpp
+++ /dev/null
@@ -1,154 +0,0 @@
-// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
-
-#pragma once
-
-#include "ck_tile/core.hpp"
-
-namespace ck_tile {
-
-template <typename ComputeDataType_, bool BroadcastLane, bool GetActualVariance>
-struct WarpMergeWelford
-{
-    using ComputeDataType = remove_cvref_t<ComputeDataType_>;
-
-    template <typename T>
-    CK_TILE_DEVICE static void
-    Merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b)
-    {
-        int count  = count_a + count_b;
-        T count_   = type_convert<T>(count);
-        T count_a_ = type_convert<T>(count_a);
-        T count_b_ = type_convert<T>(count_b);
-        T count_b_over_count = count == 0 ? type_convert<T>(0) : count_b_ / count_;
-
-        T delta = mean_b - mean_a;
-        mean_a += delta * count_b_over_count;
-        var_a += var_b + delta * delta * count_a_ * count_b_over_count;
-        count_a = count;
-    }
-
-    template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
-    CK_TILE_DEVICE void
-    operator()(MeanDistributedTensor_& mean_tensor, VarDistributedTensor_& var_tensor, int& count)
-    {
-        using Dstr             = typename MeanDistributedTensor_::StaticTileDistribution;
-        using DstrEncode       = typename Dstr::DstrEncode;
-        using DstrEncodeDetail = typename DstrEncode::detail;
-
-        static_assert(std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
-                      "wrong!");
-
-        constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
-        constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
-
-        constexpr index_t idim_p_lane = NDimP - 1;
-
-        const auto ps_idx = make_array(get_warp_id(), get_lane_id());
-        const auto rs_idx =
-            mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
-
-        constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
-        static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
-
-        const int original_count = count;
-
-        // loop over thread data
-        static_for<0, thread_buf_size, 1>{}([&](auto i) {
-            auto v_local_mean  = mean_tensor.get_thread_buffer()[i];
-            auto v_local_var   = var_tensor.get_thread_buffer()[i];
-            auto v_local_count = original_count;
-
-            // cross-lane reduce for replication
-            // only reduce on the R dimension corresponding to the lane
-            // (lane id maps to this R dimension)
-            static_for<0, NDimR, 1>{}([&](auto idim_r) {
-                // FIXME: nasty to use does_p_own_r_
-                if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
-                {
-                    constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
-
-                    constexpr index_t lid_over_rid_derivative =
-                        DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
-
-                    static_assert(is_power_of_two_integer(r_length),
-                                  "wrong!
only support power of 2 reduction"); - - constexpr index_t nstage = integer_log2_floor(r_length); - - // reduction sweep forward - static_for<0, nstage, 1>{}([&](auto istage) { - constexpr index_t lid_delta = - lid_over_rid_derivative * (1 << (nstage - istage - 1)); - - // pull data from remote lane - const auto v_remote_mean = warp_shuffle_down(v_local_mean, lid_delta); - const auto v_remote_var = warp_shuffle_down(v_local_var, lid_delta); - const auto v_remote_count = warp_shuffle_down(v_local_count, lid_delta); - - // welford merge - Merge(v_local_mean, - v_local_var, - v_local_count, - v_remote_mean, - v_remote_var, - v_remote_count); - }); - } - }); - - // cross-lane broadcast for replication - // only broadcast on R dimension correspond to lane - // (lane id maps to this R dimension) - if constexpr(BroadcastLane) - { - static_for<0, NDimR, 1>{}([&](auto idim_r) { - // FIXME: nasty to use does_p_own_r_ - if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r]) - { - const index_t r_id = rs_idx[idim_r]; - - constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r]; - - constexpr index_t lid_over_rid_derivative = - DstrEncodeDetail::ps_over_rs_derivative_[NDimP - 1][idim_r]; - - static_assert(is_power_of_two_integer(r_length), - "wrong! only support power of 2 reduction"); - - constexpr index_t nstage = integer_log2_floor(r_length); - - // broadcast sweep backward - static_for<0, nstage, 1>{}([&](auto istage) { - // do I hold reduced data? - const bool do_i_hold_reduced_data = r_id < (1 << istage); - - constexpr index_t lid_delta = lid_over_rid_derivative * (1 << istage); - - // pull data from remote lane - const auto v_remote_mean = warp_shuffle_up(v_local_mean, lid_delta); - const auto v_remote_var = warp_shuffle_up(v_local_var, lid_delta); - const auto v_remote_count = warp_shuffle_up(v_local_count, lid_delta); - - // decide whether to update local data with remote data - v_local_mean = do_i_hold_reduced_data ? v_local_mean : v_remote_mean; - v_local_var = do_i_hold_reduced_data ? v_local_var : v_remote_var; - v_local_count = do_i_hold_reduced_data ? v_local_count : v_remote_count; - }); - } - }); - } - - mean_tensor.get_thread_buffer()(i) = v_local_mean; - - if constexpr(GetActualVariance) - var_tensor.get_thread_buffer()(i) = v_local_var / v_local_count; - else - var_tensor.get_thread_buffer()(i) = v_local_var; - - count = v_local_count; - }); - } -}; - -} // namespace ck_tile
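The deleted WarpMergeWelford::Merge survives unchanged as the free function welford_merge in thread_welford.hpp. As a quick host-side sanity check of the merge identity it relies on, here is a standalone sketch that restates the same math in plain C++ (nothing below is part of the diff): merging the Welford states of two halves must reproduce a single pass over the whole sequence.

    #include <cassert>
    #include <cmath>
    #include <vector>

    // host-side restatement of welford_update from this diff
    static void update(double& mean, double& var, double x, int count)
    {
        double delta = x - mean;
        mean += delta / count;
        var += delta * (x - mean);
    }

    // host-side restatement of welford_merge from this diff
    static void merge(double& mean_a, double& var_a, int& count_a,
                      double mean_b, double var_b, int count_b)
    {
        int count = count_a + count_b;
        double f  = count == 0 ? 0.0 : double(count_b) / count;
        double d  = mean_b - mean_a;
        mean_a += d * f;
        var_a  += var_b + d * d * count_a * f;
        count_a = count;
    }

    int main()
    {
        std::vector<double> xs = {1, 2, 3, 4, 5, 6, 7, 8};

        // one pass over everything: mean = 4.5, var (sum of sq. dev.) = 42
        double m = 0, v = 0;
        for(int i = 0; i < 8; ++i) update(m, v, xs[i], i + 1);

        // two halves, then merge
        double ma = 0, va = 0, mb = 0, vb = 0;
        int ca = 0, cb = 0;
        for(int i = 0; i < 4; ++i) update(ma, va, xs[i], ++ca);
        for(int i = 4; i < 8; ++i) update(mb, vb, xs[i], ++cb);
        merge(ma, va, ca, mb, vb, cb);

        assert(std::abs(ma - m) < 1e-12 && std::abs(va - v) < 1e-12);
        return 0;
    }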