diff --git a/mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h b/mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h index 3f5ee596819a..53c5718c1da5 100644 --- a/mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h +++ b/mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h @@ -285,8 +285,9 @@ class BasePopulateParams { // class PopulateParams : public BasePopulateParams { private: - static constexpr size_t nInitParameters = 30; - static const InitParamsNonAccel initParameters[nInitParameters]; +#define NonAccel_DECLARATIONS_GEN +#include "mlir/Dialect/Rock/Tuning/QuickTuningPerfconfigs.inc" +#undef NonAccel_DECLARATIONS_GEN // if can't select config from above , use this config to do // padding kernel for example , GemmK/block is 16 , if your gemmK is 13 , we // add more 3 gemmk @@ -387,19 +388,9 @@ class PopulateParamsAccel : public BasePopulateParams { // Xdlops interface // class PopulateParamsXDL : public PopulateParamsAccel { - static constexpr size_t nInitParameters = 40; - // Initial tuning parameters for forward convolution and backward - // convolution. - static const InitParamsAccel initParameters[nInitParameters]; - - static constexpr size_t nInitParametersFp16 = 40; - // Tuning parameters for fp16/bf16 convolutions. - static const InitParamsAccel initParametersFp16[nInitParametersFp16]; - - static constexpr size_t nInitParametersForward8Bit = 40; - // Tuning parameters for i8 convolutions. - static const InitParamsAccel - initParametersForward8Bit[nInitParametersForward8Bit]; +#define XDL_DECLARATIONS_GEN +#include "mlir/Dialect/Rock/Tuning/QuickTuningPerfconfigs.inc" +#undef XDL_DECLARATIONS_GEN public: std::vector @@ -425,14 +416,9 @@ class PopulateParamsXDL : public PopulateParamsAccel { // class PopulateParamsWmma : public PopulateParamsAccel { private: - static constexpr size_t nInitParametersFp16 = 30; - // Tuning parameters for fp16/bf16 convolutions. - static const InitParamsAccel initParametersFp16[nInitParametersFp16]; - - static constexpr size_t nInitParametersForward8Bit = 30; - // Tuning parameters for i8 convolutions. - static const InitParamsAccel - initParametersForward8Bit[nInitParametersForward8Bit]; +#define Wmma_DECLARATIONS_GEN +#include "mlir/Dialect/Rock/Tuning/QuickTuningPerfconfigs.inc" +#undef Wmma_DECLARATIONS_GEN public: std::vector diff --git a/mlir/include/mlir/Dialect/Rock/Tuning/QuickTuningPerfconfigs.inc b/mlir/include/mlir/Dialect/Rock/Tuning/QuickTuningPerfconfigs.inc new file mode 100644 index 000000000000..87108f23834b --- /dev/null +++ b/mlir/include/mlir/Dialect/Rock/Tuning/QuickTuningPerfconfigs.inc @@ -0,0 +1,353 @@ +// THIS IS AN AUTOGENERATED FILE. +// DO NOT EDIT THIS FILE DIRECTLY! + +// clang-format off +#ifdef NonAccel_DEFINITIONS_GEN + +// BEGIN_GEMM_NonAccel_f32_DEFS +const InitParamsNonAccel PopulateParams::initParametersGemm[PopulateParams::nInitParametersGemm] = { + {64,64,32,8,2,2,1}, + {128,64,128,4,4,2,1}, + {128,128,128,4,2,4,1}, + {128,128,128,16,4,4,1}, + {64,32,128,16,2,4,1}, + {128,128,128,16,2,2,1}, + {64,64,128,8,2,4,1}, + {128,128,128,16,2,4,1}, + {256,32,128,8,2,2,1}, + {64,32,64,8,4,2,1}, + {64,32,32,4,2,4,1}, + {128,32,32,16,2,2,1}, + {64,64,64,8,2,2,1}, + {64,64,32,16,4,4,1} +}; +// END_GEMM_NonAccel_f32_DEFS + +// BEGIN_CONV_NonAccel_f32_DEFS +const InitParamsNonAccel PopulateParams::initParametersConv[PopulateParams::nInitParametersConv] = { + {128,128,32,16,2,4,1}, + {64,64,128,16,2,2,1}, + {64,128,32,8,2,2,1}, + {128,32,32,16,2,2,1}, + {64,64,128,16,4,4,1}, + {64,32,64,4,2,2,1}, + {64,128,64,16,2,4,1}, + {64,32,64,8,2,2,1}, + {128,32,64,8,2,2,1}, + {256,32,32,16,2,2,1}, + {64,32,32,16,4,4,1}, + {64,64,128,4,4,4,1}, + {128,128,128,4,2,4,1}, + {128,64,32,16,2,2,1}, + {64,128,64,16,4,2,1}, + {256,128,128,8,2,4,1}, + {64,32,32,8,2,4,1}, + {128,64,32,4,2,2,1}, + {128,128,64,4,2,2,1}, + {128,128,128,16,2,4,1}, + {64,128,32,16,2,4,1}, + {64,64,32,8,2,2,1}, + {64,64,64,16,2,4,1}, + {64,32,128,8,2,4,1} +}; +// END_CONV_NonAccel_f32_DEFS + +#endif + +#ifdef NonAccel_DECLARATIONS_GEN + +// BEGIN_GEMM_NonAccel_f32_DECS +static constexpr size_t nInitParametersGemm = 14; +static const InitParamsNonAccel initParametersGemm[nInitParametersGemm]; +// END_GEMM_NonAccel_f32_DECS + +// BEGIN_CONV_NonAccel_f32_DECS +static constexpr size_t nInitParametersConv = 24; +static const InitParamsNonAccel initParametersConv[nInitParametersConv]; +// END_CONV_NonAccel_f32_DECS + +#endif + +#ifdef XDL_DEFINITIONS_GEN + +// BEGIN_GEMM_XDL_f32_DEFS +const InitParamsAccel PopulateParamsXDL::initParametersGemm[PopulateParamsXDL::nInitParametersGemm] = { + {256,256,2,128,32,4,1,true,true}, + {32,32,8,16,16,8,1,true,true}, + {32,16,8,16,16,8,1,true,true}, + {64,32,4,16,16,8,1,true,true}, + {64,64,8,64,16,8,1,true,true}, + {16,64,4,16,16,4,1,true,true}, + {16,32,8,16,16,8,1,true,true}, + {64,64,8,16,16,4,1,true,true} +}; +// END_GEMM_XDL_f32_DEFS + +// BEGIN_CONV_XDL_f32_DEFS +const InitParamsAccel PopulateParamsXDL::initParametersConv[PopulateParamsXDL::nInitParametersConv] = { + {64,64,4,64,16,8,1,true,true}, + {64,64,4,16,16,4,1,true,true}, + {32,64,8,32,16,8,1,true,true}, + {64,128,4,32,32,4,1,true,true}, + {128,128,4,128,16,4,1,true,true}, + {64,256,8,16,16,1,1,true,true}, + {128,128,4,128,32,1,1,true,true}, + {256,64,2,128,32,8,1,true,true}, + {128,64,4,128,16,8,1,true,true}, + {32,16,4,16,16,8,1,true,true}, + {64,64,8,32,32,4,1,true,true}, + {256,32,8,64,16,1,1,true,true}, + {64,32,8,16,16,4,1,true,true}, + {16,16,4,16,16,8,1,true,true}, + {64,16,8,16,16,8,1,true,true}, + {64,256,4,64,16,4,1,true,true}, + {32,32,8,16,16,8,1,true,true}, + {256,128,2,64,32,8,1,true,true}, + {16,16,8,16,16,8,1,true,true}, + {64,64,8,16,16,8,1,true,true} +}; +// END_CONV_XDL_f32_DEFS + +// BEGIN_GEMM_XDL_f16_DEFS +const InitParamsAccel PopulateParamsXDL::initParametersFp16Gemm[PopulateParamsXDL::nInitParametersFp16Gemm] = { + {128,256,4,128,32,4,1,true,true}, + {16,16,8,16,16,8,1,true,true}, + {64,128,8,32,32,8,1,true,true}, + {128,128,8,64,16,8,1,true,true}, + {64,128,4,32,32,8,1,true,true}, + {128,128,4,128,16,8,1,true,true}, + {128,256,4,64,16,8,1,true,true}, + {64,16,8,16,16,8,1,true,true}, + {256,256,8,128,16,4,1,true,true}, + {16,64,8,16,16,8,1,true,true}, + {32,64,8,16,16,8,1,true,true}, + {64,64,8,32,32,8,1,true,true} +}; +// END_GEMM_XDL_f16_DEFS + +// BEGIN_CONV_XDL_f16_DEFS +const InitParamsAccel PopulateParamsXDL::initParametersFp16Conv[PopulateParamsXDL::nInitParametersFp16Conv] = { + {128,128,4,64,16,8,1,true,true}, + {64,128,4,32,32,8,1,true,true}, + {256,128,2,128,32,4,1,true,true}, + {64,32,8,16,16,8,1,true,true}, + {64,64,4,64,16,8,1,true,true}, + {256,128,4,128,32,8,1,true,true}, + {256,64,8,64,16,4,1,true,true}, + {256,64,4,128,32,8,1,true,true}, + {128,128,4,64,32,8,1,true,true}, + {128,256,4,64,16,4,1,true,true}, + {128,128,4,32,32,4,1,true,true}, + {64,32,8,32,32,8,1,true,true}, + {128,64,8,64,32,8,1,true,true}, + {128,64,4,128,16,8,1,true,true}, + {64,64,2,32,32,4,1,true,true}, + {64,64,8,32,32,8,1,true,true}, + {64,128,8,64,32,8,1,true,true}, + {16,256,4,16,16,4,1,true,true}, + {128,256,2,128,32,4,1,true,true}, + {128,128,8,128,32,8,1,true,true}, + {64,16,8,16,16,8,1,true,true}, + {128,32,8,32,32,8,1,true,true}, + {256,128,4,128,32,4,1,true,true}, + {32,32,8,16,16,8,1,true,true}, + {32,16,8,16,16,8,1,true,true}, + {256,128,8,128,16,4,1,true,true}, + {16,16,8,16,16,8,1,true,true}, + {128,256,4,64,32,8,1,true,true}, + {32,64,8,32,16,4,1,true,true} +}; +// END_CONV_XDL_f16_DEFS + +// BEGIN_GEMM_XDL_i8_DEFS +const InitParamsAccel PopulateParamsXDL::initParametersForward8BitGemm[PopulateParamsXDL::nInitParametersForward8BitGemm] = { + {64,64,16,32,16,4,1,true,true}, + {64,128,8,32,16,8,1,true,true}, + {32,64,8,16,16,16,1,true,true}, + {16,32,8,16,16,8,1,true,true}, + {32,64,16,16,16,4,1,true,true}, + {32,256,4,16,16,16,1,true,true}, + {64,128,32,64,32,4,1,true,true}, + {64,256,8,64,16,8,1,true,true}, + {256,256,8,128,128,1,1,true,true}, + {64,16,16,16,16,16,1,true,true}, + {16,32,32,16,16,8,1,true,true}, + {32,32,16,16,16,16,1,true,true}, + {64,16,8,16,16,16,1,true,true}, + {16,64,8,16,16,8,1,true,true} +}; +// END_GEMM_XDL_i8_DEFS + +// BEGIN_CONV_XDL_i8_DEFS +const InitParamsAccel PopulateParamsXDL::initParametersForward8BitConv[PopulateParamsXDL::nInitParametersForward8BitConv] = { + {64,32,4,64,16,16,1,true,true}, + {64,128,32,32,16,4,1,true,true}, + {64,16,8,32,16,8,1,true,true}, + {64,128,4,64,16,16,1,true,true}, + {256,64,4,128,32,16,1,true,true}, + {64,64,8,32,32,16,1,true,true}, + {128,64,4,32,16,16,1,true,true}, + {64,32,8,32,16,16,1,true,true}, + {128,128,4,64,16,16,1,true,true}, + {64,64,16,32,32,4,1,true,true}, + {128,128,16,128,16,8,1,true,true}, + {128,32,4,32,32,16,1,true,true}, + {64,256,4,32,16,4,1,true,true}, + {32,32,16,32,16,16,1,true,true}, + {64,64,16,16,16,16,1,true,true} +}; +// END_CONV_XDL_i8_DEFS + +#endif + +#ifdef XDL_DECLARATIONS_GEN + +// BEGIN_GEMM_XDL_f32_DECS +static constexpr size_t nInitParametersGemm = 8; +static const InitParamsAccel initParametersGemm[nInitParametersGemm]; +// END_GEMM_XDL_f32_DECS + +// BEGIN_CONV_XDL_f32_DECS +static constexpr size_t nInitParametersConv = 20; +static const InitParamsAccel initParametersConv[nInitParametersConv]; +// END_CONV_XDL_f32_DECS + +// BEGIN_GEMM_XDL_f16_DECS +static constexpr size_t nInitParametersFp16Gemm = 12; +static const InitParamsAccel initParametersFp16Gemm[nInitParametersFp16Gemm]; +// END_GEMM_XDL_f16_DECS + +// BEGIN_CONV_XDL_f16_DECS +static constexpr size_t nInitParametersFp16Conv = 29; +static const InitParamsAccel initParametersFp16Conv[nInitParametersFp16Conv]; +// END_CONV_XDL_f16_DECS + +// BEGIN_GEMM_XDL_i8_DECS +static constexpr size_t nInitParametersForward8BitGemm = 14; +static const InitParamsAccel initParametersForward8BitGemm[nInitParametersForward8BitGemm]; +// END_GEMM_XDL_i8_DECS + +// BEGIN_CONV_XDL_i8_DECS +static constexpr size_t nInitParametersForward8BitConv = 15; +static const InitParamsAccel initParametersForward8BitConv[nInitParametersForward8BitConv]; +// END_CONV_XDL_i8_DECS + +#endif + +#ifdef Wmma_DEFINITIONS_GEN + +// BEGIN_GEMM_Wmma_f16_DEFS +const InitParamsAccel PopulateParamsWmma::initParametersFp16Gemm[PopulateParamsWmma::nInitParametersFp16Gemm] = { + {128,256,2,32,32,8,1,true,true}, + {64,256,4,64,32,8,1,true,true}, + {64,16,8,16,16,8,1,true,true}, + {128,64,2,64,64,16,1,true,true}, + {128,64,8,32,64,8,1,true,true}, + {16,128,4,16,128,8,1,true,true}, + {64,128,8,64,32,8,1,true,true}, + {16,16,8,16,16,16,1,true,true}, + {32,16,8,16,16,16,1,true,true}, + {128,256,4,32,128,8,1,true,true}, + {128,16,8,32,16,8,1,true,true}, + {64,128,4,64,64,8,1,true,true}, + {64,32,8,16,32,8,1,true,true}, + {128,256,8,128,32,8,1,true,true}, + {16,16,4,16,16,8,1,true,true}, + {128,64,2,32,64,8,1,true,true}, + {16,32,4,16,32,16,1,true,true} +}; +// END_GEMM_Wmma_f16_DEFS + +// BEGIN_CONV_Wmma_f16_DEFS +const InitParamsAccel PopulateParamsWmma::initParametersFp16Conv[PopulateParamsWmma::nInitParametersFp16Conv] = { + {16,16,4,16,16,8,1,true,true}, + {256,128,8,128,32,8,1,true,true}, + {256,64,2,64,64,8,1,true,true}, + {64,64,4,32,32,8,1,true,true}, + {128,128,2,32,32,8,1,true,true}, + {64,16,8,16,16,16,1,true,true}, + {128,64,8,32,64,8,1,true,true}, + {256,256,8,64,32,8,1,true,true}, + {64,128,8,64,32,8,1,true,true}, + {128,64,2,64,32,8,1,true,true}, + {128,256,2,64,32,8,1,true,true}, + {16,16,8,16,16,8,1,true,true}, + {128,32,2,32,32,8,1,true,true}, + {128,256,8,128,32,8,1,true,true}, + {32,128,2,32,32,8,1,true,true}, + {64,256,4,32,64,8,1,true,true}, + {64,32,8,32,32,8,1,true,true}, + {64,256,2,64,64,8,1,true,true}, + {16,32,4,16,16,16,1,true,true}, + {16,32,4,16,32,8,1,true,true}, + {64,16,8,16,16,8,1,true,true}, + {256,128,4,32,64,8,1,true,true}, + {128,256,4,64,32,8,1,true,true}, + {128,128,4,64,64,8,1,true,true}, + {16,128,8,16,16,8,1,true,true}, + {128,16,8,32,16,8,1,true,true} +}; +// END_CONV_Wmma_f16_DEFS + +// BEGIN_GEMM_Wmma_i8_DEFS +const InitParamsAccel PopulateParamsWmma::initParametersForward8BitGemm[PopulateParamsWmma::nInitParametersForward8BitGemm] = { + {128,64,8,32,64,16,1,true,true}, + {64,128,8,64,32,16,1,true,true}, + {128,128,4,64,64,16,1,true,true}, + {64,32,8,16,16,16,1,true,true}, + {256,32,2,64,32,16,1,true,true}, + {256,128,4,16,128,16,1,true,true}, + {256,256,8,128,32,8,1,true,true}, + {256,128,2,128,32,16,1,true,true}, + {64,64,4,32,16,16,1,true,true}, + {16,128,8,16,16,16,1,true,true}, + {64,256,4,64,64,8,1,true,true}, + {64,32,4,16,32,16,1,true,true}, + {16,32,4,16,16,8,1,true,true}, + {16,16,4,16,16,8,1,true,true}, + {128,128,4,128,16,16,1,true,true} +}; +// END_GEMM_Wmma_i8_DEFS + +// BEGIN_CONV_Wmma_i8_DEFS +const InitParamsAccel PopulateParamsWmma::initParametersForward8BitConv[PopulateParamsWmma::nInitParametersForward8BitConv] = { + {128,64,8,32,64,16,1,true,true}, + {32,64,4,32,32,16,1,true,true}, + {64,64,4,64,16,16,1,true,true}, + {256,64,8,32,64,16,1,true,true}, + {128,32,4,64,16,16,1,true,true}, + {256,32,8,32,16,4,1,true,true}, + {32,256,4,32,16,4,1,true,true}, + {128,128,4,128,16,16,1,true,true}, + {64,16,8,32,16,16,1,true,true}, + {128,128,2,128,32,16,1,true,true}, + {256,128,2,32,128,16,1,true,true} +}; +// END_CONV_Wmma_i8_DEFS + +#endif + +#ifdef Wmma_DECLARATIONS_GEN + +// BEGIN_GEMM_Wmma_f16_DECS +static constexpr size_t nInitParametersFp16Gemm = 17; +static const InitParamsAccel initParametersFp16Gemm[nInitParametersFp16Gemm]; +// END_GEMM_Wmma_f16_DECS + +// BEGIN_CONV_Wmma_f16_DECS +static constexpr size_t nInitParametersFp16Conv = 26; +static const InitParamsAccel initParametersFp16Conv[nInitParametersFp16Conv]; +// END_CONV_Wmma_f16_DECS + +// BEGIN_GEMM_Wmma_i8_DECS +static constexpr size_t nInitParametersForward8BitGemm = 15; +static const InitParamsAccel initParametersForward8BitGemm[nInitParametersForward8BitGemm]; +// END_GEMM_Wmma_i8_DECS + +// BEGIN_CONV_Wmma_i8_DECS +static constexpr size_t nInitParametersForward8BitConv = 11; +static const InitParamsAccel initParametersForward8BitConv[nInitParametersForward8BitConv]; +// END_CONV_Wmma_i8_DECS + +#endif + diff --git a/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp b/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp index eec87ee51444..5e1342cf1831 100644 --- a/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp +++ b/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp @@ -37,39 +37,9 @@ llvm::raw_ostream &mlir::rock::operator<<(llvm::raw_ostream &os, /// Non-xdlops // clang-format off -const InitParamsNonAccel -PopulateParams::initParameters[PopulateParams::nInitParameters] = { - // blockSize M/block N/block K/block M/thread N/thread splitKFactor - {256, 128, 128, 16, 4, 4, 1}, - {256, 128, 128, 8, 4, 4, 1}, - {256, 128, 128, 4, 4, 4, 1}, - {256, 64, 64, 16, 4, 4, 1}, - {256, 32, 128, 8, 2, 2, 1}, - {256, 32, 128, 8, 2, 2, 1}, - {128, 128, 64, 16, 4, 4, 1}, - {128, 128, 64, 8, 4, 4, 1}, - {128, 128, 64, 4, 4, 4, 1}, - {128, 32, 32, 16, 2, 2, 1}, - {128, 64, 128, 16, 4, 4, 1}, - {128, 64, 128, 8, 4, 4, 1}, - {128, 64, 128, 4, 4, 4, 1}, - {128, 64, 64, 8, 2, 2, 1}, - {128, 64, 64, 16, 2, 4, 1}, - {128, 32, 32, 16, 2, 4, 1}, - {64, 64, 128, 16, 2, 4, 1}, - {64, 64, 128, 4, 2, 2, 1}, - {64, 64, 64, 16, 4, 4, 1}, - {64, 64, 64, 8, 4, 4, 1}, - {64, 64, 64, 4, 4, 4, 1}, - {64, 64, 32, 16, 4, 2, 1}, - {64, 64, 32, 8, 4, 2, 1}, - {64, 64, 32, 4, 4, 2, 1}, - {64, 32, 64, 16, 2, 4, 1}, - {64, 32, 64, 8, 2, 4, 1}, - {64, 32, 64, 4, 2, 4, 1}, - {64, 32, 32, 16, 2, 2, 1}, - {64, 32, 32, 8, 2, 2, 1}, - {64, 32, 32, 4, 2, 2, 1}}; +#define NonAccel_DEFINITIONS_GEN +#include "mlir/Dialect/Rock/Tuning/QuickTuningPerfconfigs.inc" +#undef NonAccel_DEFINITIONS_GEN // clang-format on PopulateParamsInfo PopulateParamsInfo::fromOp(RockGemmWrapperInterface op) { @@ -303,7 +273,12 @@ PopulateParams::obtainTuningParameters(RockGemmWrapperInterface op, std::vector PopulateParams::getTuningParameters(KernelType opType, Type dataTypeA, Type dataTypeB) const { - ArrayRef params = {initParameters, nInitParameters}; + ArrayRef params; + if (opType == KernelType::Gemm) { + params = {initParametersGemm, nInitParametersGemm}; + } else { + params = {initParametersConv, nInitParametersConv}; + } return std::vector(params); } @@ -433,140 +408,9 @@ PopulateParamsAccel::obtainTuningParameters(RockGemmWrapperInterface op, /// Xdlops acceleration // clang-format off -const InitParamsAccel -PopulateParamsXDL::initParameters[PopulateParamsXDL::nInitParameters] = { - // M/block N/block K/block M/wave N/wave kPack splitKFactor forceUnroll bCopyMore - {256, 256, 2, 128, 32, 4, 1, true, true}, - {256, 64, 8, 128, 32, 1, 1, true, true}, - {128, 128, 8, 64, 16, 4, 1, true, true}, - {128, 128, 4, 128, 32, 4, 1, true, true}, - {128, 128, 2, 32, 32, 8, 1, true, true}, - {128, 64, 8, 64, 16, 1, 1, true, true}, - {128, 64, 8, 32, 32, 4, 1, true, true}, - {128, 64, 8, 32, 16, 1, 1, true, true}, - {128, 64, 4, 32, 32, 4, 1, true, true}, - {128, 64, 2, 128, 32, 4, 1, true, true}, - {128, 32, 4, 128, 16, 4, 1, true, true}, - {128, 16, 4, 32, 16, 8, 1, true, true}, - {64, 256, 8, 64, 16, 4, 1, true, true}, - {64, 128, 4, 64, 32, 1, 1, true, true}, - {64, 128, 4, 64, 16, 4, 1, true , true}, - {64, 128, 4, 32, 16, 4, 1, true, true}, - {64, 128, 2, 32, 32, 8, 1, true, true}, - {64, 64, 8, 32, 32, 4, 1, true, true}, - {64, 64, 8, 16, 16, 4, 1, true, true}, - {64, 64, 8, 32, 16, 4, 1, true, true}, - {64, 64, 8, 16, 16, 8, 1, true, true}, - {64, 64, 4, 32, 16, 4, 1, true, true}, - {64, 64, 4, 16, 16, 8, 1, true, true}, - {64, 64, 8, 64, 16, 8, 1, true, true}, - {64, 32, 4, 32, 16, 8, 1, true, true}, - {64, 32, 8, 16, 16, 4, 1, true, true}, - {64, 32, 8, 16, 16, 4, 1, true, true}, - {64, 16, 8, 16, 16, 8, 1, true, true}, - {32, 128, 8, 32, 16, 1, 1, true, true}, - {32, 128, 8, 16, 16, 4, 1, true , true}, - {32, 64, 8, 32, 16, 4, 1, true, true}, - {32, 64, 4, 32, 16, 4, 1, true, true}, - {32, 32, 8, 16, 16, 8, 1, true, true}, - {32, 32, 8, 16, 16, 4, 1, true, true}, - {32, 16, 8, 16, 16, 8, 1, true, true}, - {32, 16, 4, 16, 16, 8, 1, true, true}, - {16, 32, 4, 16, 16, 4, 1, true, true}, - {16, 32, 8, 16, 16, 8, 1, true, true}, - {16, 16, 4, 16, 16, 4, 1, true, true}, - {16, 16, 8, 16, 16, 8, 1, true, true} -}; - -const InitParamsAccel -PopulateParamsXDL::initParametersFp16[PopulateParamsXDL::nInitParametersFp16] = { - // M/block N/block K/block M/wave N/wave kPack splitKFactor forceUnroll bCopyMore - {128, 256, 8, 64, 32, 4, 1, true, true}, - {128, 256, 4, 64, 32, 8, 1, true, true}, - {128, 128, 8, 128, 32, 8, 1, true, true}, - {128, 128, 8, 64, 32, 4, 1, true, true}, - {128, 128, 8, 32, 32, 8, 1, true, true}, - {128, 128, 8, 32, 16, 4, 1, true, true}, - {128, 128, 4, 128, 32, 8, 1, true, true}, - {128, 128, 4, 128, 16, 8, 1, true, true}, - {128, 128, 4, 64, 32, 8, 1, true, true}, - {128, 128, 4, 64, 16, 8, 1, true, true}, - {128, 128, 4, 32, 32, 8, 1, true, true}, - {128, 64, 4, 128, 16, 8, 1, true, true}, - {128, 64, 4, 32, 32, 8, 1, true, true}, - {128, 32, 8, 32 ,32 ,8, 1, true, true}, - {64, 128, 4, 64, 16, 8, 1, true, true}, - {64, 128, 8, 32, 32, 4, 1, true, true}, - {64, 128, 8, 32, 16, 8, 1, true, true}, - {64, 128, 8, 32, 16, 4, 1, true, true}, - {64, 128, 8, 64, 32, 4, 1, true, true}, - {64, 128, 4, 32, 16, 8, 1, true, true}, - {64, 128, 4, 32, 32, 8, 1, true, true}, - {64, 64, 8, 32, 32, 8, 1, true, true}, - {64, 64, 8, 32, 32, 8, 1, true, true}, - {64, 64, 8, 32, 16, 8, 1, true, true}, - {64, 64, 8, 16, 16, 8, 1, true, true}, - {64, 64, 4, 32, 32, 8, 1, true, true}, - {64 ,64, 2, 32, 32, 4, 1, true, true}, - {64, 32, 8, 32, 32, 8, 1, true, true}, - {64, 32, 8, 32, 16, 8, 1, true, true}, - {64, 16, 8, 16, 16, 8, 1, true, true}, - {32, 128, 8, 32, 32, 4, 1, true, true}, - {32, 64, 8, 32, 32, 8, 1, true, true}, - {32, 64, 8, 32, 16, 4, 1, true, true}, - {32, 32, 8, 32, 32, 4, 1, true, true}, - {32, 32, 8, 16, 16, 8, 1, true, true}, - {32, 16 ,8, 16, 16, 8, 1, true, true}, - {16, 128, 4, 16, 16, 8, 1, true, true}, - {16, 32, 8, 16, 16, 8, 1, true, true}, - {16, 64, 8, 16, 16, 8, 1, true, true}, - {16, 32, 8, 16 ,16 ,4, 1, true, true} -}; - -const InitParamsAccel -PopulateParamsXDL::initParametersForward8Bit[ - PopulateParamsXDL::nInitParametersForward8Bit] = { - {128, 256, 8, 128, 16, 4, 1, true, true}, - {128, 128, 16, 64, 32, 8, 1, true, true}, - {128, 128, 8, 128, 16, 8, 1, true, true}, - {128, 128, 8, 64, 16, 8, 1, true, true}, - {128, 128, 8, 32, 16, 16, 1, true, true}, - {128, 64, 32, 64, 32, 4, 1, true, true}, - {128, 64, 8, 32, 32, 16, 1, true, true}, - {128, 64, 8, 32, 16, 16, 1, true, true}, - {128, 64, 4, 32, 16, 16, 1, true, true}, - {64, 128, 32, 64, 32, 4, 1, true, true}, - {64, 128, 16, 32, 16, 4, 1, true, true}, - {64, 128, 8, 64, 16, 8, 1, true, true}, - {64, 128, 4, 32, 16, 16, 1, true , true}, - {64, 128, 8, 32, 16, 8, 1, true, true}, - {64, 64, 16, 32, 32, 4, 1, true, true}, - {64, 64, 8, 32, 32, 16, 1, true, true}, - {64, 64, 8, 32, 16, 16, 1, true, true}, - {64, 64, 4, 32, 16, 16, 1, true, true}, - {64, 64, 4, 32, 16, 8, 1, true, true}, - {64, 64, 16, 32, 16, 4, 1, true, true}, - {64, 64, 16, 16, 16, 16, 1, true, true}, - {64, 32, 16, 32, 16, 4, 1, true, true}, - {64, 32, 8, 16, 16, 16, 1, true, true}, - {64, 32, 8, 32, 16, 16, 1, true, true}, - {64, 32, 8, 32, 16, 8, 1, true, true}, - {64, 16, 8, 16, 16, 16, 1, true, true}, - {32, 256, 16, 32, 32, 4, 1, true, true}, - {32, 256, 4, 32, 16, 8, 1, true, true}, - {32, 128, 32, 32, 16, 4, 1, true, true}, - {32, 64, 32, 16, 16, 4, 1, true, true}, - {32, 64, 16, 32, 16, 4, 1, true, true}, - {32, 64, 8, 16, 16, 16, 1, true, true}, - {32, 64, 4, 32, 16, 8, 1, true, true}, - {32, 32, 32, 16, 16, 4, 1, true, true}, - {32, 32, 16, 16, 16, 8, 1, true, true}, - {32, 16, 16, 16, 16, 8, 1, true, true}, - {16, 64, 16, 16, 16, 4, 1, true, true}, - {16, 32, 16, 16, 16, 16, 1, true, true}, - {16, 16, 32, 16, 16, 4, 1, true, true}, - {16, 16, 16, 16, 16, 4, 1, true, true} -}; +#define XDL_DEFINITIONS_GEN +#include "mlir/Dialect/Rock/Tuning/QuickTuningPerfconfigs.inc" +#undef XDL_DEFINITIONS_GEN // clang-format on LogicalResult PopulateParamsXDL::isValidBlockwiseGemm( @@ -700,15 +544,28 @@ std::vector PopulateParamsXDL::getTuningParameters(KernelType opType, Type dataTypeA, Type dataTypeB, StringRef arch) const { ArrayRef params; - switch (dataTypeA.getIntOrFloatBitWidth()) { - case 8: - params = {initParametersForward8Bit, nInitParametersForward8Bit}; - break; - case 16: - params = {initParametersFp16, nInitParametersFp16}; - break; - default: - params = {initParameters, nInitParameters}; + if (opType == KernelType::Gemm) { + switch (dataTypeA.getIntOrFloatBitWidth()) { + case 8: + params = {initParametersForward8BitGemm, nInitParametersForward8BitGemm}; + break; + case 16: + params = {initParametersFp16Gemm, nInitParametersFp16Gemm}; + break; + default: + params = {initParametersGemm, nInitParametersGemm}; + } + } else { + switch (dataTypeA.getIntOrFloatBitWidth()) { + case 8: + params = {initParametersForward8BitConv, nInitParametersForward8BitConv}; + break; + case 16: + params = {initParametersFp16Conv, nInitParametersFp16Conv}; + break; + default: + params = {initParametersConv, nInitParametersConv}; + } } std::vector res; // Only return valid XDLOp params @@ -752,75 +609,9 @@ PopulateParamsXDL::getGemmParamsAttr(OpBuilder &builder, /// Wmma acceleration // clang-format off -const InitParamsAccel -PopulateParamsWmma::initParametersFp16[PopulateParamsWmma::nInitParametersFp16] = { - // M/block N/block K/block M/wave N/wave kPack splitKFactor forceUnroll bCopyMore - {256, 64, 4, 64, 64, 8, 1, true, true}, - {256, 32, 8, 64, 32, 8, 1, true, true}, - {256, 16, 8, 64, 16, 8, 1, true, true}, - {128, 128, 8, 64, 64, 8, 1, true, true}, - {128, 128, 4, 64, 64, 8, 1, true, true}, - {128, 128, 4, 64, 64, 16, 1, true, true}, - {128, 128, 2, 64, 64, 8, 1, true, true}, - {128, 64, 8, 32, 64, 8, 1, true, true}, - {128, 64, 4, 64, 64, 8, 1, true, true}, - {128, 64, 4, 64, 32, 8, 1, true, true}, - {128, 64, 4, 32, 64, 8, 1, true, true}, - {128, 32, 4, 32, 32, 8, 1, true, true}, - {128, 16, 8, 32, 16, 8, 1, true, true}, - {64, 256, 4, 64, 64, 8, 1, true, true}, - {64, 256, 2, 64, 64, 8, 1, true, true}, - {64, 128, 4, 64, 32, 8, 1, true, true}, - {64, 128, 4, 64, 64, 8, 1, true, true}, - {64, 128, 4, 32, 64, 8, 1, true, true}, - {64, 128, 2, 64, 64, 8, 1, true, true}, - {64, 64, 4, 32, 32, 8, 1, true, true}, - {64, 32, 4, 32, 32, 8, 1, true, true}, - {64, 16, 8, 16, 16, 8, 1, true, true}, - {32, 32, 8, 32, 16, 16, 1, true, true}, - {32, 32, 8, 32, 16, 8, 1, true, true}, - {32, 32, 8, 16, 16, 8, 1, true, true}, - {16, 256, 4, 16, 64, 4, 1, true, true}, - {16, 32, 8, 16, 16, 16, 1, true, true}, - {16, 32, 8, 16, 16, 8, 1, true, true}, - {16, 16, 8, 16, 16, 8, 1, true, true}, - {16, 16, 2, 16, 16, 8, 1, true, true}, -}; - -const InitParamsAccel -PopulateParamsWmma::initParametersForward8Bit[ - PopulateParamsWmma::nInitParametersForward8Bit] = { - {128, 128, 8, 64, 64, 16, 1, true, true}, - {128, 128, 4, 64, 64, 16, 1, true, true}, - {128, 128, 2, 64, 64, 16, 1, true, true}, - {128, 32, 8, 32, 32, 16, 1, true, true}, - {128, 32, 4, 32, 32, 16, 1, true, true}, - {128, 32, 2, 64, 16, 16, 1, true, true}, - {128, 64, 4, 64, 32, 16, 1, true, true}, - {128, 64, 4, 32, 64, 16, 1, true, true}, - {128, 32, 4, 32, 32, 16, 1, true, true}, - {128, 32, 2, 64, 16, 16, 1, true, true}, - {128, 16, 4, 32, 16, 16, 1, true, true}, - {64, 256, 4, 64, 64, 16, 1, true, true}, - {64, 128, 4, 32, 64, 16, 1, true, true}, - {64, 128, 4, 64, 32, 16, 1, true, true}, - {64, 64, 4, 16, 64, 16, 1, true, true}, - {64, 32, 4, 16, 32, 16, 1, true, true}, - {64, 32, 4, 32, 16, 16, 1, true, true}, - {64, 32, 2, 16, 32, 16, 1, true, true}, - {64, 16, 8, 16, 16, 16, 1, true, true}, - {32, 256, 4, 32, 64, 4, 1, true, true}, - {32, 128, 8, 32, 32, 4, 1, true, true}, - {32, 128, 4, 32, 32, 16, 1, true, true}, - {32, 32, 8, 16, 16, 4, 1, true, true}, - {16, 256, 8, 16, 64, 16, 1, true, true}, - {16, 64, 8, 16, 16, 16, 1, true, true}, - {16, 64, 8, 16, 64, 16, 1, true, true}, - {16, 16, 8, 16, 16, 16, 1, true, true}, - {16, 16, 8, 16, 16, 4, 1, true, true}, - {16, 16, 2, 16, 16, 8, 1, true, true}, - {16, 64, 2, 16, 16, 16, 1, true, true}, -}; +#define Wmma_DEFINITIONS_GEN +#include "mlir/Dialect/Rock/Tuning/QuickTuningPerfconfigs.inc" +#undef Wmma_DEFINITIONS_GEN // clang-format on LogicalResult PopulateParamsWmma::isValidBlockwiseGemm( @@ -921,15 +712,28 @@ PopulateParamsWmma::getTuningParameters(KernelType opType, Type dataTypeA, Type dataTypeB, StringRef arch) const { ArrayRef params; std::vector res; - switch (dataTypeA.getIntOrFloatBitWidth()) { - case 8: - params = {initParametersForward8Bit, nInitParametersForward8Bit}; - break; - case 16: - params = {initParametersFp16, nInitParametersFp16}; - break; - default: - return res; + if (opType == KernelType::Gemm) { + switch (dataTypeA.getIntOrFloatBitWidth()) { + case 8: + params = {initParametersForward8BitGemm, nInitParametersForward8BitGemm}; + break; + case 16: + params = {initParametersFp16Gemm, nInitParametersFp16Gemm}; + break; + default: + return res; + } + } else { + switch (dataTypeA.getIntOrFloatBitWidth()) { + case 8: + params = {initParametersForward8BitConv, nInitParametersForward8BitConv}; + break; + case 16: + params = {initParametersFp16Conv, nInitParametersFp16Conv}; + break; + default: + return res; + } } // Only return valid Wmma params const int64_t waveSize = mlir::rock::lookupArchInfo(arch).waveSize; diff --git a/mlir/test/Dialect/Rock/affix_tuning_params.mlir b/mlir/test/Dialect/Rock/affix_tuning_params.mlir index 4211056337f7..499785a59a39 100644 --- a/mlir/test/Dialect/Rock/affix_tuning_params.mlir +++ b/mlir/test/Dialect/Rock/affix_tuning_params.mlir @@ -10,9 +10,9 @@ // GRID-LABEL: rock_conv func.func @rock_conv(%filter : memref<1x128x8x3x3xf32>, %input : memref<128x1x8x32x32xf32>, %output : memref<128x1x128x30x30xf32>) { // CHECK: rock.conv - // CHECK-SAME: params = #rock.general_gemm_params + // CHECK-SAME: params = #rock.general_gemm_params // GRID: rock.gridwise_gemm - // GRID-SAME: gridSize = 900 + // GRID-SAME: gridSize = 3600 rock.conv(%filter, %input, %output) features = none { arch = "amdgcn-amd-amdhsa:gfx906", filter_layout = ["g", "k", "c", "0", "1"], @@ -29,9 +29,9 @@ func.func @rock_conv(%filter : memref<1x128x8x3x3xf32>, %input : memref<128x1x8x // GRID-LABEL: func.func @rock_conv_f16 func.func @rock_conv_f16(%filter : memref<1x128x8x3x3xf16>, %input : memref<128x1x8x32x32xf16>, %output : memref<128x1x128x30x30xf16>) { // CHECK: rock.conv - // CHECK-SAME: params = #rock.general_gemm_params + // CHECK-SAME: params = #rock.general_gemm_params // GRID: rock.gridwise_gemm - // GRID-SAME: gridSize = 900 + // GRID-SAME: gridSize = 3600 rock.conv(%filter, %input, %output) features = none { arch = "amdgcn-amd-amdhsa:gfx906", filter_layout = ["g", "k", "c", "0", "1"], @@ -49,9 +49,9 @@ func.func @rock_conv_f16(%filter : memref<1x128x8x3x3xf16>, %input : memref<128x func.func @rock_conv_i8(%filter : memref<1x128x8x3x3xi8>, %input : memref<128x1x8x32x32xi8>, %output : memref<128x1x128x30x30xi32>) { // CHECK: rock.conv // CHECK-SAME: derivedBlockSize = 256 - // CHECK-SAME: params = #rock.xdlops_gemm_derived_params + // CHECK-SAME: params = #rock.xdlops_gemm_derived_params // GRID: rock.gridwise_gemm - // GRID-SAME: gridSize = 450 + // GRID-SAME: gridSize = 900 rock.conv(%filter, %input, %output) features = mfma|dot|atomic_add { arch = "amdgcn-amd-amdhsa:gfx908", filter_layout = ["g", "k", "c", "0", "1"], @@ -69,9 +69,9 @@ func.func @rock_conv_i8(%filter : memref<1x128x8x3x3xi8>, %input : memref<128x1x func.func @rock_conv_bwd_data(%filter: memref<1x1024x1024x1x1xf32>, %input: memref<128x1x1024x14x14xf32>, %output: memref<128x1x1024x14x14xf32>) attributes {kernel = 0 : i32} { // CHECK: rock.conv_bwd_data // CHECK-SAME: derivedBlockSize = 256 - // CHECK-SAME: params = #rock.xdlops_gemm_derived_params + // CHECK-SAME: params = #rock.xdlops_gemm_derived_params // GRID: rock.gridwise_gemm - // GRID-SAME: gridSize = 392 + // GRID-SAME: gridSize = 6272 rock.conv_bwd_data(%filter, %input, %output) features = mfma|dot|atomic_add { arch = "amdgcn-amd-amdhsa:gfx908", dilations = [1 : index, 1 : index], @@ -90,9 +90,9 @@ func.func @rock_conv_bwd_data(%filter: memref<1x1024x1024x1x1xf32>, %input: memr func.func @rock_conv_bwd_data_f16(%filter: memref<1x1024x1024x1x1xf16>, %input: memref<128x1x1024x14x14xf16>, %output: memref<128x1x1024x14x14xf16>) attributes {kernel = 0 : i32} { // CHECK: rock.conv_bwd_data // CHECK-SAME: derivedBlockSize = 256 - // CHECK-SAME: params = #rock.xdlops_gemm_derived_params + // CHECK-SAME: params = #rock.xdlops_gemm_derived_params // GRID: rock.gridwise_gemm - // GRID-SAME: gridSize = 784 + // GRID-SAME: gridSize = 1568 rock.conv_bwd_data(%filter, %input, %output) features = mfma|dot|atomic_add { arch = "amdgcn-amd-amdhsa:gfx908", dilations = [1 : index, 1 : index], @@ -130,7 +130,7 @@ func.func @rock_conv_bwd_data_padMN(%filter : memref<1x64x3x1x1xf32>, %input : m // GRID-LABEL: @rock_conv_bwd_data_padMK func.func @rock_conv_bwd_data_padMK(%filter : memref<1x11x3x1x1xf32>, %input : memref<128x1x3x15x15xf32>, %output : memref<128x1x11x15x15xf32>) { // CHECK: rock.conv_bwd_data - // CHECK-SAME: params = #rock.general_gemm_params + // CHECK-SAME: params = #rock.general_gemm_params // GRID: rock.gridwise_gemm // GRID-SAME: gridSize = 450 rock.conv_bwd_data(%filter, %input, %output) features = none { @@ -150,9 +150,9 @@ func.func @rock_conv_bwd_data_padMK(%filter : memref<1x11x3x1x1xf32>, %input : m // GRID-LABEL: @rock_conv_bwd_weight func.func @rock_conv_bwd_weight(%filter : memref<1x128x8x3x3xf32>, %input : memref<128x1x8x32x32xf32>, %output : memref<128x1x128x30x30xf32>) { // CHECK: rock.conv_bwd_weight - // CHECK-SAME: params = #rock.general_gemm_params + // CHECK-SAME: params = #rock.general_gemm_params // GRID: rock.gridwise_gemm - // GRID-SAME: gridSize = 12 + // GRID-SAME: gridSize = 3 rock.conv_bwd_weight(%filter, %input, %output) features = none { arch = "amdgcn-amd-amdhsa:gfx906", numCU = 64 : i32, @@ -170,9 +170,9 @@ func.func @rock_conv_bwd_weight(%filter : memref<1x128x8x3x3xf32>, %input : memr // GRID-LABEL: @rock_conv_bwd_weight_f16 func.func @rock_conv_bwd_weight_f16(%filter : memref<1x128x8x3x3xf16>, %input : memref<128x1x8x32x32xf16>, %output : memref<128x1x128x30x30xf16>) { // CHECK: rock.conv_bwd_weight - // CHECK-SAME: params = #rock.general_gemm_params + // CHECK-SAME: params = #rock.general_gemm_params // GRID: rock.gridwise_gemm - // GRID-SAME: gridSize = 12 + // GRID-SAME: gridSize = 3 rock.conv_bwd_weight(%filter, %input, %output) features = none { arch = "amdgcn-amd-amdhsa:gfx906", numCU = 64 : i32, @@ -190,7 +190,7 @@ func.func @rock_conv_bwd_weight_f16(%filter : memref<1x128x8x3x3xf16>, %input : // GRID-LABEL: func.func @rock_conv_bwd_weight_padALL func.func @rock_conv_bwd_weight_padALL(%filter : memref<1x20x8x3x3xf32>, %input : memref<7x1x8x32x32xf32>, %output : memref<7x1x20x30x30xf32>) { // CHECK: rock.conv_bwd_weight - // CHECK-SAME: params = #rock.general_gemm_params + // CHECK-SAME: params = #rock.general_gemm_params // GRID: rock.gridwise_gemm // GRID-SAME: gridSize = 3 rock.conv_bwd_weight(%filter, %input, %output) features = none { @@ -210,7 +210,7 @@ func.func @rock_conv_bwd_weight_padALL(%filter : memref<1x20x8x3x3xf32>, %input // GRID-LABEL: @rock_conv_bwd_weight_padALL_f16 func.func @rock_conv_bwd_weight_padALL_f16(%filter : memref<1x20x8x3x3xf16>, %input : memref<7x1x8x32x32xf16>, %output : memref<7x1x20x30x30xf16>) { // CHECK: rock.conv_bwd_weight - // CHECK-SAME: params = #rock.general_gemm_params + // CHECK-SAME: params = #rock.general_gemm_params // GRID: rock.gridwise_gemm // GRID-SAME: gridSize = 3 rock.conv_bwd_weight(%filter, %input, %output) features = none { @@ -254,9 +254,9 @@ func.func @rock_conv_7x7_tuning(%arg0: memref<1x64x3x7x7xf32>, %arg1: memref<256 func.func @rock_conv_7x7(%arg0: memref<1x64x3x7x7xf32>, %arg1: memref<256x1x3x230x230xf32>, %arg2: memref<256x1x64x112x112xf32>) { // CHECK: rock.conv // CHECK-SAME: derivedBlockSize = 256 - // CHECK-SAME: params = #rock.xdlops_gemm_derived_params + // CHECK-SAME: params = #rock.xdlops_gemm_derived_params // GRID: rock.gridwise_gemm - // GRID-SAME: gridSize = 25088 + // GRID-SAME: gridSize = 12544 rock.conv(%arg0, %arg1, %arg2) features = mfma|dot|atomic_add { arch = "amdgcn-amd-amdhsa:gfx908", dilations = [1 : index, 1 : index], @@ -273,10 +273,10 @@ func.func @rock_conv_7x7(%arg0: memref<1x64x3x7x7xf32>, %arg1: memref<256x1x3x23 // GRID-LABEL: @rock_conv_bwd_weight_7x7 func.func @rock_conv_bwd_weight_7x7(%arg0: memref<1x64x3x7x7xf32>, %arg1: memref<256x1x3x230x230xf32>, %arg2: memref<256x1x64x112x112xf32>) attributes {kernel = 0 : i32} { // CHECK: rock.conv_bwd_weight - // CHECK-SAME: derivedBlockSize = 256 - // CHECK-SAME: params = #rock.xdlops_gemm_derived_params + // CHECK-SAME: derivedBlockSize = 128 + // CHECK-SAME: params = #rock.xdlops_gemm_derived_params // GRID: rock.gridwise_gemm - // GRID-SAME: gridSize = 5 + // GRID-SAME: gridSize = 20 rock.conv_bwd_weight(%arg0, %arg1, %arg2) features = mfma|dot|atomic_add { arch = "amdgcn-amd-amdhsa:gfx908", dilations = [1 : index, 1 : index], @@ -316,10 +316,10 @@ func.func @rock_conv_bwd_data_7x7_tuning(%arg0: memref<1x64x3x7x7xf32>, %arg1: m // GRID-LABEL: @rock_conv_bwd_data_7x7 func.func @rock_conv_bwd_data_7x7(%arg0: memref<1x64x3x7x7xf32>, %arg1: memref<256x1x3x230x230xf32>, %arg2: memref<256x1x64x112x112xf32>) attributes {kernel = 1 : i32} { // CHECK: rock.conv_bwd_data - // CHECK-SAME: derivedBlockSize = 128 - // CHECK-SAME: params = #rock.xdlops_gemm_derived_params + // CHECK-SAME: derivedBlockSize = 64 + // CHECK-SAME: params = #rock.xdlops_gemm_derived_params // GRID: rock.gridwise_gemm - // GRID-SAME: gridSize = 105800 + // GRID-SAME: gridSize = 211600 rock.conv_bwd_data(%arg0, %arg1, %arg2) features = mfma|dot|atomic_add { arch = "amdgcn-amd-amdhsa:gfx908", dilations = [1 : index, 1 : index], @@ -337,9 +337,9 @@ func.func @rock_conv_bwd_data_7x7(%arg0: memref<1x64x3x7x7xf32>, %arg1: memref<2 // GRID-LABEL: @rock_gemm_from_conv func.func @rock_gemm_from_conv(%a : memref<1x72x128xf32>, %b : memref<1x72x115200xf32>, %c : memref<1x128x115200xf32>) { // CHECK: rock.gemm - // CHECK-SAME: params = #rock.general_gemm_params + // CHECK-SAME: params = #rock.general_gemm_params // GRID: rock.gridwise_gemm - // GRID-SAME: gridSize = 900 + // GRID-SAME: gridSize = 7200 rock.gemm %c = tr %a * %b features = none storeMethod = set { arch = "amdgcn-amd-amdhsa:gfx906", numCU = 64 : i32 @@ -352,9 +352,9 @@ func.func @rock_gemm_from_conv(%a : memref<1x72x128xf32>, %b : memref<1x72x11520 func.func @rock_gemm_from_i8_conv(%a : memref<1x72x128xi8>, %b : memref<1x72x115200xi8>, %c : memref<1x128x115200xi32>) { // CHECK: rock.gemm // CHECK-SAME: derivedBlockSize = 256 - // CHECK-SAME: params = #rock.xdlops_gemm_derived_params + // CHECK-SAME: params = #rock.xdlops_gemm_derived_params // GRID: rock.gridwise_gemm - // GRID-SAME: gridSize = 450 + // GRID-SAME: gridSize = 3600 rock.gemm %c = tr %a * %b features = mfma|dot|atomic_add storeMethod = set { arch = "amdgcn-amd-amdhsa:gfx908", numCU = 120 : i32 @@ -370,9 +370,9 @@ func.func @rock_gemm_from_i8_conv(%a : memref<1x72x128xi8>, %b : memref<1x72x115 func.func @rock_gemm_from_i8_conv_gfx940(%a : memref<1x72x128xi8>, %b : memref<1x72x115200xi8>, %c : memref<1x128x115200xi32>) { // CHECK: rock.gemm // CHECK-SAME: derivedBlockSize = 256 - // CHECK-SAME: params = #rock.xdlops_gemm_derived_params + // CHECK-SAME: params = #rock.xdlops_gemm_derived_params // GRID: rock.gridwise_gemm - // GRID-SAME: gridSize = 3600 + // GRID-SAME: gridSize = 1800 rock.gemm %c = tr %a * %b features = mfma|dot|atomic_add storeMethod = set { arch = "amdgcn-amd-amdhsa:gfx940", numCU = 120 : i32 @@ -386,9 +386,9 @@ func.func @rock_gemm_from_i8_conv_gfx940(%a : memref<1x72x128xi8>, %b : memref<1 func.func @rock_gemm_xdlops_fp8_bf8(%a : memref<1x72x128xf8E4M3FNUZ>, %b : memref<1x72x115200xf8E5M2FNUZ>, %c : memref<1x128x115200xf32>) { // CHECK: rock.gemm // CHECK-SAME: derivedBlockSize = 256 - // CHECK-SAME: params = #rock.xdlops_gemm_derived_params + // CHECK-SAME: params = #rock.xdlops_gemm_derived_params // GRID: rock.gridwise_gemm - // GRID-SAME: gridSize = 3600 + // GRID-SAME: gridSize = 1800 rock.gemm %c = tr %a * %b features = mfma|dot|atomic_add storeMethod = set { arch = "amdgcn-amd-amdhsa:gfx940", numCU = 120 : i32 @@ -402,9 +402,9 @@ func.func @rock_gemm_xdlops_fp8_bf8(%a : memref<1x72x128xf8E4M3FNUZ>, %b : memre func.func @rock_gemm_xdlops_fp8_bf8_ocp(%a : memref<1x72x128xf8E4M3FN>, %b : memref<1x72x115200xf8E5M2>, %c : memref<1x128x115200xf32>) { // CHECK: rock.gemm // CHECK-SAME: derivedBlockSize = 256 - // CHECK-SAME: params = #rock.xdlops_gemm_derived_params + // CHECK-SAME: params = #rock.xdlops_gemm_derived_params // GRID: rock.gridwise_gemm - // GRID-SAME: gridSize = 3600 + // GRID-SAME: gridSize = 1800 rock.gemm %c = tr %a * %b features = mfma|dot|atomic_add storeMethod = set { arch = "amdgcn-amd-amdhsa:gfx940", numCU = 120 : i32 diff --git a/mlir/test/Dialect/Rock/test_packed_arithmetic.mlir b/mlir/test/Dialect/Rock/test_packed_arithmetic.mlir index 1e266c9b1629..eac3cfea1aa2 100644 --- a/mlir/test/Dialect/Rock/test_packed_arithmetic.mlir +++ b/mlir/test/Dialect/Rock/test_packed_arithmetic.mlir @@ -18,15 +18,12 @@ // VECTORIZE: vector.transfer_write %[[trunc]] // ROCDL: %[[pkrtz:.*]] = rocdl.cvt.pkrtz {{.*}}, {{.*}} : vector<2xf16> // ROCDL: llvm.store %[[pkrtz]], {{.*}} : vector<2xf16>, !llvm.ptr<5> -// LLVM: %[[extract0:.*]] = extractelement <16 x float> {{.*}}, i64 0 -// LLVM: %[[extract1:.*]] = extractelement <16 x float> {{.*}}, i64 1 +// LLVM: %[[extract0:.*]] = extractelement <4 x float> {{.*}}, i64 0 +// LLVM: %[[extract1:.*]] = extractelement <4 x float> {{.*}}, i64 1 // LLVM: tail call <2 x half> @llvm.amdgcn.cvt.pkrtz(float %[[extract0]], float %[[extract1]]) -// LLVM: %[[extract2:.*]] = extractelement <16 x float> {{.*}}, i64 2 -// LLVM: %[[extract3:.*]] = extractelement <16 x float> {{.*}}, i64 3 +// LLVM: %[[extract2:.*]] = extractelement <4 x float> {{.*}}, i64 2 +// LLVM: %[[extract3:.*]] = extractelement <4 x float> {{.*}}, i64 3 // LLVM: tail call <2 x half> @llvm.amdgcn.cvt.pkrtz(float %[[extract2]], float %[[extract3]]) -// LLVM: %[[extract14:.*]] = extractelement <16 x float> {{.*}}, i64 14 -// LLVM: %[[extract15:.*]] = extractelement <16 x float> {{.*}}, i64 15 -// LLVM: tail call <2 x half> @llvm.amdgcn.cvt.pkrtz(float %[[extract14]], float %[[extract15]]) // ASM: v_pk_add_f16 {{.*}}, {{.*}}, {{.*}} module { func.func @test_fusion(%arg0: memref<1x128x128xf16> {mhal.read_access}, %arg1: memref<1x128x128xf16> {mhal.read_access}, %arg2: memref<1x128x128xf16> {mhal.read_access}, %arg3: memref<1x128x128xf16> {mhal.write_access}) attributes {arch = "gfx942", kernel} { diff --git a/mlir/test/fusion/rock-gemm-reduce-align-tiling.mlir b/mlir/test/fusion/rock-gemm-reduce-align-tiling.mlir index 95bad400cdfc..640e604487cd 100644 --- a/mlir/test/fusion/rock-gemm-reduce-align-tiling.mlir +++ b/mlir/test/fusion/rock-gemm-reduce-align-tiling.mlir @@ -7,16 +7,16 @@ func.func @test_gemm_reduce_last_axis_fusion(%arg0: memref<1x128x64xf32>, %arg1: // CHECK: rock.blockwise_broadcast_reduce sum {{.*}} into %[[BLOCK_RED_OUT:[0-9]+]] // CHECK: %[[TR0:.+]] = rock.transform %arg2 by {{.*}} : memref<1x128x1xf32> to memref<1x128x256xf32> - // CHECK: %[[TR1:.+]] = rock.transform %[[TR0]] by {{.*}} : memref<1x128x256xf32> to memref<2x128x2x128xf32> - // CHECK: %[[TR2:.+]] = rock.transform %[[TR1]] by {{.*}} : memref<2x128x2x128xf32> to memref<2x1x2x1x128x128xf32> - // CHECK: %[[TR3:.+]] = rock.transform %[[TR2]] by {{.*}} : memref<2x1x2x1x128x128xf32> to memref<2x1x2x128x1xf32> - // CHECK: %[[TR4:.+]] = rock.transform %[[TR3]] by {{.*}} ["dim1"] at [4]>{{.*}} : memref<2x1x2x128x1xf32> to memref<2x1x2x128x128xf32> + // CHECK: %[[TR1:.+]] = rock.transform %[[TR0]] by {{.*}} : memref<1x128x256xf32> to memref<16x2x64x8x32xf32> + // CHECK: %[[TR2:.+]] = rock.transform %[[TR1]] by {{.*}} : memref<16x2x64x8x32xf32> to memref<16x2x8x1x64x32xf32> + // CHECK: %[[TR3:.+]] = rock.transform %[[TR2]] by {{.*}} : memref<16x2x8x1x64x32xf32> to memref<16x2x8x64x1xf32> + // CHECK: %[[TR4:.+]] = rock.transform %[[TR3]] by {{.*}} ["dim1"] at [4]>{{.*}} : memref<16x2x8x64x1xf32> to memref<16x2x8x64x32xf32> - // CHECK: %[[TR5:.+]] = rock.transform %[[TR4]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x128x128xf32> - // CHECK: %[[TR6:.+]] = rock.transform %[[TR5]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x128x128xf32> - // CHECK: %[[TR7:.+]] = rock.transform %[[TR6]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x128x128xf32> - // CHECK: %[[TR8:.+]] = rock.transform %[[TR7]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x4x4x4x4x2x4x2x4xf32 - // CHECK: %[[TR9:.+]] = rock.transform %[[TR8]] by {{.*}} : memref<2x1x2x4x4x4x4x2x4x2x4xf32> to memref<2x1x2x256x64xf32> + // CHECK: %[[TR5:.+]] = rock.transform %[[TR4]] by {{.*}} : memref<16x2x8x64x32xf32> to memref<16x2x8x64x32xf32> + // CHECK: %[[TR6:.+]] = rock.transform %[[TR5]] by {{.*}} : memref<16x2x8x64x32xf32> to memref<16x2x8x64x32xf32> + // CHECK: %[[TR7:.+]] = rock.transform %[[TR6]] by {{.*}} : memref<16x2x8x64x32xf32> to memref<16x2x8x64x32xf32> + // CHECK: %[[TR8:.+]] = rock.transform %[[TR7]] by {{.*}} : memref<16x2x8x64x32xf32> to memref<16x2x8x2x2x4x4x4x2x2x2xf32> + // CHECK: %[[TR9:.+]] = rock.transform %[[TR8]] by {{.*}} : memref<16x2x8x2x2x4x4x4x2x2x2xf32> to memref<16x2x8x64x32xf32> // CHECK: rock.threadwise_write_all {{.*}}%[[BLOCK_RED_OUT]] -> [](%[[TR9]]){{.*}} by atomic_add : {{.*}} rock.reduce sum %0 into %arg2 features = mfma|dot|atomic_add {axis = 2 : index, blockSize = 256 : i32, gridSize = 1 : i32} : memref<1x128x256xf32> into memref<1x128x1xf32> @@ -31,16 +31,16 @@ func.func @test_gemm_reduce_middle_axis_fusion(%arg0: memref<1x128x64xf32>, %arg // CHECK: rock.blockwise_broadcast_reduce sum {{.*}} into %[[BLOCK_RED_OUT:[0-9]+]] // CHECK: %[[TR0:.+]] = rock.transform %arg2 by {{.*}} : memref<1x1x256xf32> to memref<1x128x256xf32> - // CHECK: %[[TR1:.+]] = rock.transform %[[TR0]] by {{.*}} : memref<1x128x256xf32> to memref<2x128x2x128xf32> - // CHECK: %[[TR2:.+]] = rock.transform %[[TR1]] by {{.*}} : memref<2x128x2x128xf32> to memref<2x1x2x1x128x128xf32> - // CHECK: %[[TR3:.+]] = rock.transform %[[TR2]] by {{.*}} : memref<2x1x2x1x128x128xf32> to memref<2x1x2x1x128xf32> - // CHECK: %[[TR4:.+]] = rock.transform %[[TR3]] by {{.*}} ["dim0"] at [3]>{{.*}} : memref<2x1x2x1x128xf32> to memref<2x1x2x128x128xf32> + // CHECK: %[[TR1:.+]] = rock.transform %[[TR0]] by {{.*}} : memref<1x128x256xf32> to memref<16x2x64x8x32xf32> + // CHECK: %[[TR2:.+]] = rock.transform %[[TR1]] by {{.*}} : memref<16x2x64x8x32xf32> to memref<16x2x8x1x64x32xf32> + // CHECK: %[[TR3:.+]] = rock.transform %[[TR2]] by {{.*}} : memref<16x2x8x1x64x32xf32> to memref<16x2x8x1x32xf32> + // CHECK: %[[TR4:.+]] = rock.transform %[[TR3]] by {{.*}} ["dim0"] at [3]>{{.*}} : memref<16x2x8x1x32xf32> to memref<16x2x8x64x32xf32> - // CHECK: %[[TR5:.+]] = rock.transform %[[TR4]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x128x128xf32> - // CHECK: %[[TR6:.+]] = rock.transform %[[TR5]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x128x128xf32> - // CHECK: %[[TR7:.+]] = rock.transform %[[TR6]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x128x128xf32> - // CHECK: %[[TR8:.+]] = rock.transform %[[TR7]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x4x4x4x4x2x4x2x4xf32> - // CHECK: %[[TR9:.+]] = rock.transform %[[TR8]] by {{.*}} : memref<2x1x2x4x4x4x4x2x4x2x4xf32> to memref<2x1x2x256x64xf32> + // CHECK: %[[TR5:.+]] = rock.transform %[[TR4]] by {{.*}} : memref<16x2x8x64x32xf32> to memref<16x2x8x64x32xf32> + // CHECK: %[[TR6:.+]] = rock.transform %[[TR5]] by {{.*}} : memref<16x2x8x64x32xf32> to memref<16x2x8x64x32xf32> + // CHECK: %[[TR7:.+]] = rock.transform %[[TR6]] by {{.*}} : memref<16x2x8x64x32xf32> to memref<16x2x8x64x32xf32> + // CHECK: %[[TR8:.+]] = rock.transform %[[TR7]] by {{.*}} : memref<16x2x8x64x32xf32> to memref<16x2x8x2x2x4x4x4x2x2x2xf32> + // CHECK: %[[TR9:.+]] = rock.transform %[[TR8]] by {{.*}} : memref<16x2x8x2x2x4x4x4x2x2x2xf32> to memref<16x2x8x64x32xf32> // CHECK: rock.threadwise_write_all {{.*}}%[[BLOCK_RED_OUT]] -> [](%[[TR9]]){{.*}} by atomic_add : {{.*}} rock.reduce sum %0 into %arg2 features = mfma|dot|atomic_add {axis = 1 : index, blockSize = 256 : i32, gridSize = 1 : i32} : memref<1x128x256xf32> into memref<1x1x256xf32> @@ -61,16 +61,16 @@ func.func @test_gemm_add_reduce_fusion(%arg0: memref<1x128x64xf32>, %arg1: memre // CHECK: rock.blockwise_broadcast_reduce sum {{.*}} into %[[BLOCK_RED_OUT:[0-9]+]] // CHECK: %[[TR0:.+]] = rock.transform %arg3 by {{.*}} : memref<1x128x1xf32> to memref<1x128x256xf32> - // CHECK: %[[TR1:.+]] = rock.transform %[[TR0]] by {{.*}} : memref<1x128x256xf32> to memref<2x128x2x128xf32> - // CHECK: %[[TR2:.+]] = rock.transform %[[TR1]] by {{.*}} : memref<2x128x2x128xf32> to memref<2x1x2x1x128x128xf32> - // CHECK: %[[TR3:.+]] = rock.transform %[[TR2]] by {{.*}} : memref<2x1x2x1x128x128xf32> to memref<2x1x2x128x1xf32> - // CHECK: %[[TR4:.+]] = rock.transform %[[TR3]] by {{.*}} ["dim1"] at [4]>{{.*}} : memref<2x1x2x128x1xf32> to memref<2x1x2x128x128xf32> + // CHECK: %[[TR1:.+]] = rock.transform %[[TR0]] by {{.*}} : memref<1x128x256xf32> to memref<16x2x64x8x32xf32> + // CHECK: %[[TR2:.+]] = rock.transform %[[TR1]] by {{.*}} : memref<16x2x64x8x32xf32> to memref<16x2x8x1x64x32xf32> + // CHECK: %[[TR3:.+]] = rock.transform %[[TR2]] by {{.*}} : memref<16x2x8x1x64x32xf32> to memref<16x2x8x64x1xf32> + // CHECK: %[[TR4:.+]] = rock.transform %[[TR3]] by {{.*}} ["dim1"] at [4]>{{.*}} : memref<16x2x8x64x1xf32> to memref<16x2x8x64x32xf32> - // CHECK: %[[TR5:.+]] = rock.transform %[[TR4]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x128x128xf32> - // CHECK: %[[TR6:.+]] = rock.transform %[[TR5]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x128x128xf32> - // CHECK: %[[TR7:.+]] = rock.transform %[[TR6]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x128x128xf32> - // CHECK: %[[TR8:.+]] = rock.transform %[[TR7]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x4x4x4x4x2x4x2x4xf32> - // CHECK: %[[TR9:.+]] = rock.transform %[[TR8]] by {{.*}} : memref<2x1x2x4x4x4x4x2x4x2x4xf32> to memref<2x1x2x256x64xf32> + // CHECK: %[[TR5:.+]] = rock.transform %[[TR4]] by {{.*}} : memref<16x2x8x64x32xf32> to memref<16x2x8x64x32xf32> + // CHECK: %[[TR6:.+]] = rock.transform %[[TR5]] by {{.*}} : memref<16x2x8x64x32xf32> to memref<16x2x8x64x32xf32> + // CHECK: %[[TR7:.+]] = rock.transform %[[TR6]] by {{.*}} : memref<16x2x8x64x32xf32> to memref<16x2x8x64x32xf32> + // CHECK: %[[TR8:.+]] = rock.transform %[[TR7]] by {{.*}} : memref<16x2x8x64x32xf32> to memref<16x2x8x2x2x4x4x4x2x2x2xf32> + // CHECK: %[[TR9:.+]] = rock.transform %[[TR8]] by {{.*}} : memref<16x2x8x2x2x4x4x4x2x2x2xf32> to memref<16x2x8x64x32xf32> // CHECK: rock.threadwise_write_all {{.*}}%[[BLOCK_RED_OUT]] -> [](%[[TR9]]){{.*}} by atomic_add : {{.*}} rock.reduce sum %1 into %arg3 features = mfma|dot|atomic_add {axis = 2 : index, blockSize = 256 : i32, gridSize = 1 : i32} : memref<1x128x256xf32> into memref<1x128x1xf32> @@ -84,16 +84,16 @@ func.func @test_gemm_reduce_max(%arg0: memref<1x128x64xf32>, %arg1: memref<1x64x // CHECK: rock.blockwise_broadcast_reduce max {{.*}} into %[[BLOCK_RED_OUT:[0-9]+]] // CHECK: %[[TR0:.+]] = rock.transform %arg2 by {{.*}} : memref<1x128x1xf32> to memref<1x128x256xf32> - // CHECK: %[[TR1:.+]] = rock.transform %[[TR0]] by {{.*}} : memref<1x128x256xf32> to memref<2x128x2x128xf32> - // CHECK: %[[TR2:.+]] = rock.transform %[[TR1]] by {{.*}} : memref<2x128x2x128xf32> to memref<2x1x2x1x128x128xf32> - // CHECK: %[[TR3:.+]] = rock.transform %[[TR2]] by {{.*}} : memref<2x1x2x1x128x128xf32> to memref<2x1x2x128x1xf32> - // CHECK: %[[TR4:.+]] = rock.transform %[[TR3]] by {{.*}} ["dim1"] at [4]>{{.*}} : memref<2x1x2x128x1xf32> to memref<2x1x2x128x128xf32> - - // CHECK: %[[TR5:.+]] = rock.transform %[[TR4]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x128x128xf32> - // CHECK: %[[TR6:.+]] = rock.transform %[[TR5]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x128x128xf32> - // CHECK: %[[TR7:.+]] = rock.transform %[[TR6]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x128x128xf32> - // CHECK: %[[TR8:.+]] = rock.transform %[[TR7]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x4x4x4x4x2x4x2x4xf32> - // CHECK: %[[TR9:.+]] = rock.transform %[[TR8]] by {{.*}} : memref<2x1x2x4x4x4x4x2x4x2x4xf32> to memref<2x1x2x256x64xf32> + // CHECK: %[[TR1:.+]] = rock.transform %[[TR0]] by {{.*}} : memref<1x128x256xf32> to memref<16x2x64x8x32xf32> + // CHECK: %[[TR2:.+]] = rock.transform %[[TR1]] by {{.*}} : memref<16x2x64x8x32xf32> to memref<16x2x8x1x64x32xf32> + // CHECK: %[[TR3:.+]] = rock.transform %[[TR2]] by {{.*}} : memref<16x2x8x1x64x32xf32> to memref<16x2x8x64x1xf32> + // CHECK: %[[TR4:.+]] = rock.transform %[[TR3]] by {{.*}} ["dim1"] at [4]>{{.*}} : memref<16x2x8x64x1xf32> to memref<16x2x8x64x32xf32> + + // CHECK: %[[TR5:.+]] = rock.transform %[[TR4]] by {{.*}} : memref<16x2x8x64x32xf32> to memref<16x2x8x64x32xf32> + // CHECK: %[[TR6:.+]] = rock.transform %[[TR5]] by {{.*}} : memref<16x2x8x64x32xf32> to memref<16x2x8x64x32xf32> + // CHECK: %[[TR7:.+]] = rock.transform %[[TR6]] by {{.*}} : memref<16x2x8x64x32xf32> to memref<16x2x8x64x32xf32> + // CHECK: %[[TR8:.+]] = rock.transform %[[TR7]] by {{.*}} : memref<16x2x8x64x32xf32> to memref<16x2x8x2x2x4x4x4x2x2x2xf32> + // CHECK: %[[TR9:.+]] = rock.transform %[[TR8]] by {{.*}} : memref<16x2x8x2x2x4x4x4x2x2x2xf32> to memref<16x2x8x64x32xf32> // CHECK: rock.threadwise_write_all {{.*}}%[[BLOCK_RED_OUT]] -> [](%[[TR9]]){{.*}} by atomic_max : {{.*}} rock.reduce max %0 into %arg2 features = mfma|dot|atomic_add {axis = 2 : index, blockSize = 256 : i32, gridSize = 1 : i32} : memref<1x128x256xf32> into memref<1x128x1xf32> diff --git a/mlir/test/fusion/tosa-to-rock-gemm-reshape-add.mlir b/mlir/test/fusion/tosa-to-rock-gemm-reshape-add.mlir index 5d402f1c2c76..5def84e9c2f9 100644 --- a/mlir/test/fusion/tosa-to-rock-gemm-reshape-add.mlir +++ b/mlir/test/fusion/tosa-to-rock-gemm-reshape-add.mlir @@ -1,13 +1,13 @@ // RUN: rocmlir-driver --host-pipeline highlevel %s | rocmlir-opt --rock-affix-params --rock-conv-to-gemm --rock-gemm-to-gridwise -rock-regularize -rock-gridwise-gemm-to-blockwise -rock-linalg-align | FileCheck %s --check-prefix=CHECK_LINALG_ALIGN -// CHECK_LINALG_ALIGN-DAG: #[[AMAP:.*]] = affine_map<(d0, d1, d2) -> (d0 + d1, d2)> -// CHECK_LINALG_ALIGN-DAG: #[[AMAP1:.*]] = affine_map<(d0, d1) -> (d0 * 1000 + d1)> -// CHECK_LINALG_ALIGN-DAG: #[[MAP1:.*]] = #rock.transform_map<#[[AMAP]] by [ ["dim0"] at [0]>, ["dim1"] at [1]>] bounds = [1, 1, 1000] -> [1, 1000]> -// CHECK_LINALG_ALIGN-DAG: #[[MAP2:.*]] = #rock.transform_map<#[[AMAP1]] by [ ["dim0"] at [0]>] bounds = [1, 1000] -> [1000]> +// CHECK_LINALG_ALIGN-DAG: #[[AMAP:.*]] = affine_map<(d0, d1) -> (d0 * 1000 + d1)> +// CHECK_LINALG_ALIGN-DAG: #[[AMAP1:.*]] = affine_map<(d0, d1, d2) -> (d0 + d1, d2)> +// CHECK_LINALG_ALIGN-DAG: #[[MAP1:.*]] = #rock.transform_map<#[[AMAP]] by [ ["dim0"] at [0]>] bounds = [1, 1000] -> [1000]> +// CHECK_LINALG_ALIGN-DAG: #[[MAP2:.*]] = #rock.transform_map<#[[AMAP1]] by [ ["dim0"] at [0]>, ["dim1"] at [1]>] bounds = [1, 1, 1000] -> [1, 1000]> // CHECK_LINALG_ALIGN-COUNT-2: rock.threadwise_read_into {{.*}} // CHECK_LINALG_ALIGN: rock.threadwise_read_into {{.*}} -> [[lain:%.*]] : -// CHECK_LINALG_ALIGN: linalg.generic{{.*}} ins({{.*}}, [[lain]] :{{.*}}) outs(%[[outBuf:.*]] : memref<16xf32, #gpu.address_space>) +// CHECK_LINALG_ALIGN: linalg.generic{{.*}} ins({{.*}}, [[lain]] :{{.*}}) outs(%[[outBuf:.*]] : memref<64xf32, #gpu.address_space>) // CHECK_LINALG_ALIGN: rock.threadwise_write_all {{.*}} %[[outBuf]] -> // to test reshape is converted as transform and fused. diff --git a/mlir/test/fusion/tosa-to-rock-tp-add-tp.mlir b/mlir/test/fusion/tosa-to-rock-tp-add-tp.mlir index 5fe73d78d175..7b1c499461f0 100644 --- a/mlir/test/fusion/tosa-to-rock-tp-add-tp.mlir +++ b/mlir/test/fusion/tosa-to-rock-tp-add-tp.mlir @@ -3,7 +3,7 @@ // CHECK-DAG: #[[MAP2:.*]] = #rock.transform_map<{{.*}} by [ ["dim0", "dim2", "dim3", "dim1"] at [0, 2, 3, 1]>] bounds = [256, 28, 28, 64] -> [256, 64, 28, 28]> // CHECK-COUNT-2: rock.threadwise_read_into {{.*}} // CHECK: rock.threadwise_read_into {{.*}} -> [[lain:%.*]] : -// CHECK: linalg.generic{{.*}} ins({{.*}}, [[lain]] :{{.*}}) outs(%[[outBuf:.*]] : memref<16xf32, #gpu.address_space>) +// CHECK: linalg.generic{{.*}} ins({{.*}}, [[lain]] :{{.*}}) outs(%[[outBuf:.*]] : memref<128xf32, #gpu.address_space>) // CHECK: rock.threadwise_write_all {{.*}} %[[outBuf]] -> // to test transpose is converted as transform and fused. diff --git a/mlir/test/fusion/tosa-to-rock-tp-add.mlir b/mlir/test/fusion/tosa-to-rock-tp-add.mlir index b8b1f5109cad..fe5897d469e9 100644 --- a/mlir/test/fusion/tosa-to-rock-tp-add.mlir +++ b/mlir/test/fusion/tosa-to-rock-tp-add.mlir @@ -3,7 +3,7 @@ // CHECK-DAG: #[[MAP2:.*]] = #rock.transform_map<#map{{.*}} by [ ["{{.*}}", "{{.*}}", "{{.*}}", "{{.*}}"] at [0, 2, 3, 1]>] bounds = [256, 28, 28, 64] -> [256, 64, 28, 28]> // CHECK-COUNT-2: rock.threadwise_read_into {{.*}} // CHECK: rock.threadwise_read_into {{.*}} -> [[lain:%.*]] : -// CHECK: linalg.generic{{.*}} ins({{.*}}, [[lain]] :{{.*}}) outs(%[[outBuf:.*]] : memref<16xf32, #gpu.address_space>) +// CHECK: linalg.generic{{.*}} ins({{.*}}, [[lain]] :{{.*}}) outs(%[[outBuf:.*]] : memref<128xf32, #gpu.address_space>) // CHECK: rock.threadwise_write_all {{.*}} %[[outBuf]] -> // to test transpose is converted as transform and fused. diff --git a/mlir/utils/performance/analysis/quickTuningGen.py b/mlir/utils/performance/analysis/quickTuningGen.py new file mode 100644 index 000000000000..d8c2da10627d --- /dev/null +++ b/mlir/utils/performance/analysis/quickTuningGen.py @@ -0,0 +1,414 @@ +import os +import sys +import argparse + +import pulp +import numpy as np +import re +import glob +import pandas as pd +from sklearn.preprocessing import MinMaxScaler +from collections import defaultdict + + +class FileWriter(): + """ + A class to handle updating quick-tuning perfcofigs file. + """ + + def __init__(self, pargs): + self.op = pargs.op + self.arch = pargs.arch + + def parse_perfconfigs(self, perfconfig_str): + """ + Parses a perfconfigs in expected input format + """ + perfconfig_str = perfconfig_str.replace('v2:', '') + config_values = perfconfig_str.split(',') + total_values = len(config_values) + converted_values = [] + for idx, value in enumerate(config_values): + if idx > total_values - 3 and total_values == 9: + if value == '1': + converted_values.append('true') + else: + converted_values.append('false') + else: + converted_values.append(value) + + formated_perfconfig = ','.join(converted_values) + return formated_perfconfig + + def replace_section( + self, + file_path, + begin_marker, + end_marker, + new_content): + """ + Replaces a section between the markers with new content + """ + with open(file_path, 'r') as file: + content = file.read() + + pattern = re.compile(f'{begin_marker}.*?{end_marker}', re.DOTALL) + replacment = f'{begin_marker}\n{new_content}\n{end_marker}' + content = pattern.sub(replacment, content) + + with open(file_path, 'w') as file: + file.write(content) + + def isAccel(self, arch, datatype): + instruction_type = self.get_instruction_type(arch, datatype) + if instruction_type == "XDL" or instruction_type == "Wmma": + return True + else: + return False + + def get_instruction_type(self, arch, datatype): + """ + Determines the instruction type based on architecture and data type + """ + if arch.startswith("gfx9"): + return "XDL" + elif arch.startswith("gfx1") and datatype != "f32": + return "Wmma" + else: + return "NonAccel" + + def init_inc_file(self, file_path): + """ + Initialize an .inc file with predefined structure + """ + instruction_types_to_datatypes = { + "NonAccel": ["f32"], + "XDL": ["f32", "f16", "i8"], + "Wmma": ["f16", "i8"] + } + markers = [ + "// BEGIN_GEMM", + "// END_GEMM", + "// BEGIN_CONV", + "// END_CONV", + ] + + with open(file_path, 'w') as file: + file.write("// THIS IS AN AUTOGENERATED FILE.\n") + file.write("// DO NOT EDIT THIS FILE DIRECTLY!\n\n") + file.write("// clang-format off\n") + for instrction_type, datatypes in instruction_types_to_datatypes.items(): + file.write(f"#ifdef {instrction_type}_DEFINITIONS_GEN\n\n") + for datatype in datatypes: + for marker in markers: + file.write( + f"{marker}_{instrction_type}_{datatype}_DEFS\n\n") + file.write(f"#endif\n\n") + + file.write(f"#ifdef {instrction_type}_DECLARATIONS_GEN\n\n") + for datatype in datatypes: + for marker in markers: + file.write( + f"{marker}_{instrction_type}_{datatype}_DECS\n\n") + file.write(f"#endif\n\n") + + def get_init_params_definitions(self, arch, dtype, op): + """ + Generates initialization parameter definitions for a given data type and operation. + """ + accel_type = 'Accel' if self.isAccel(arch, dtype) else 'NonAccel' + instruction_type = self.get_instruction_type(arch, dtype) + op_cap = op.capitalize() + + if dtype == 'f32': + init_params = f"initParameters{op_cap}" + n_init_params = f"nInitParameters{op_cap}" + if not self.isAccel(arch, dtype): + instruction_type = '' + elif dtype == 'f16': + init_params = f"initParametersFp16{op_cap}" + n_init_params = f"nInitParametersFp16{op_cap}" + elif dtype == 'i8': + init_params = f"initParametersForward8Bit{op_cap}" + n_init_params = f"nInitParametersForward8Bit{op_cap}" + else: + raise ValueError("Unsupported dtype") + + return f"const InitParams{accel_type} PopulateParams{instruction_type}::{init_params}[PopulateParams{instruction_type}::{n_init_params}]" + + def get_init_params_declaration(self, arch, dtype, op): + """ + Generates initialization parameter declarations for a given data type and operation. + """ + op_cap = op.capitalize() + accel_type = 'Accel' if self.isAccel(arch, dtype) else 'NonAccel' + + if dtype == 'f32': + init_params = f"initParameters{op_cap}" + n_init_params = f"nInitParameters{op_cap}" + elif dtype == 'f16': + init_params = f"initParametersFp16{op_cap}" + n_init_params = f"nInitParametersFp16{op_cap}" + elif dtype == 'i8': + init_params = f"initParametersForward8Bit{op_cap}" + n_init_params = f"nInitParametersForward8Bit{op_cap}" + else: + raise ValueError("Unsupported dtype") + + return (f"static const InitParams{accel_type} {init_params}[{n_init_params}]", + f"static constexpr size_t {n_init_params}") + + + def update_config_file(self, result): + """ + Updates the configuration file with selected perfconfigs + """ + file_path = "../../../include/mlir/Dialect/Rock/Tuning/QuickTuningPerfconfigs.inc" + if not os.path.exists(file_path): + self.init_inc_file(file_path) + + datatype_names_defs= { + 'f32': self.get_init_params_definitions(self.arch, 'f32', self.op), + 'f16': self.get_init_params_definitions(self.arch, 'f16', self.op), + 'i8': self.get_init_params_definitions(self.arch, 'i8', self.op) + } + + for datatype, perfconfigs in result.items(): + lines = [] + datatype_name = datatype_names_defs.get(datatype) + lines.append(f"{datatype_name} = {{") + for idx, perfconfig in enumerate(perfconfigs): + formated_perfconfig = self.parse_perfconfigs(perfconfig) + if idx == len(perfconfigs) - 1: + lines.append(f" {{{formated_perfconfig}}}") + else: + lines.append(f" {{{formated_perfconfig}}},") + lines.append("};") + + new_content = '\n'.join(lines) + self.replace_section( + file_path, + f"// BEGIN_{self.op.upper()}_{self.get_instruction_type(self.arch, datatype)}_{datatype}_DEFS", + f"// END_{self.op.upper()}_{self.get_instruction_type(self.arch, datatype)}_{datatype}_DEFS", + new_content) + + datatype_names_decs = {} + datatype_n_decs = {} + + for dtype in ['f32', 'f16', 'i8']: + init_params_dec, n_params_dec = self.get_init_params_declaration(self.arch, dtype, self.op) + datatype_names_decs[dtype] = init_params_dec + datatype_n_decs[dtype] = n_params_dec + + for datatype, perfconfigs in result.items(): + lines = [] + datatype_name_dec = datatype_names_decs.get(datatype) + datatype_n_dec = datatype_n_decs.get(datatype) + lines.append(f"{datatype_n_dec} = {len(perfconfigs)};") + lines.append(f"{datatype_name_dec};") + + new_content = '\n'.join(lines) + self.replace_section( + file_path, + f"// BEGIN_{self.op.upper()}_{self.get_instruction_type(self.arch, datatype)}_{datatype}_DECS", + f"// END_{self.op.upper()}_{self.get_instruction_type(self.arch, datatype)}_{datatype}_DECS", + new_content) + + +class PerfConfigsFinder(): + """ + A class to find optimal perfconfigs based on input data + """ + + def __init__(self, combined_data, pargs): + self.th = pargs.th + self.op = pargs.op + self.input_dir = pargs.input_dir + self.arch = pargs.arch + self.df = combined_data + + def get_unique_perfconfigs_list(self, problems_to_perfconfigs): + """ + Return a unique list of perfconfigs from the provided dictonary + """ + perfconfigs_set = set() + for perfconfigs_lists in problems_to_perfconfigs.values(): + perfconfigs_set.update(perfconfigs_lists) + return list(perfconfigs_set) + + def get_top_n_perfconfigs_per_problems(self, df, targetColumns): + """ + Identifies the top perfcofnigs for each problem based on a threshold + """ + grouped = df.groupby(targetColumns) + problem_df = {} + for name, grouped_df in grouped: + max_value = grouped_df['TFlops'].max() + threshold = max_value * self.th + problem_df[name] = grouped_df[grouped_df['TFlops'] + >= threshold]['PerfConfig'] + return problem_df + + def find(self): + """ + Finds the minimal set of perfconfigs that cover all + problems using set cover optimizaiton. + Returns : A dictionary containing data types as keys and thier + corresponding selected perfconfigs. + """ + result = {} + unique_data_types = self.df['DataType'].unique() + + targetColumns = [] + if self.op == "gemm": + targetColumns = ['TransA', 'TransB', 'G', 'M', 'K', 'N'] + else: + targetColumns = [ + 'Direction', + 'InputLayout', + 'N', + 'C', + 'H', + 'W', + 'K', + 'Y', + 'X', + 'DilationH', + 'DilationW', + 'StrideH', + 'StrideW', + 'PaddingH', + 'PaddingW'] + + for data_type in unique_data_types: + df_typed = self.df[self.df['DataType'] == data_type] + problems_to_perfconfigs = self.get_top_n_perfconfigs_per_problems( + df_typed, targetColumns) + + problems = problems_to_perfconfigs.keys() + perfconfigs = self.get_unique_perfconfigs_list( + problems_to_perfconfigs) + + n = len(problems) + m = len(perfconfigs) + problem_to_index = { + problem: idx for idx, + problem in enumerate(problems)} + perfconfig_to_index = { + perfconfig: idx for idx, + perfconfig in enumerate(perfconfigs)} + + # Create coverage matrix + A = np.zeros((n, m), dtype=int) + for problem, perfconfig_list in problems_to_perfconfigs.items(): + i = problem_to_index[problem] + for perfconfig in perfconfig_list: + j = perfconfig_to_index[perfconfig] + A[i][j] = 1 + + # Linear programming model to minimize the number of perfconfigs + prob = pulp.LpProblem("SetCoverProblems", pulp.LpMinimize) + x = pulp.LpVariable.dicts("x", range(m), cat='Binary') + prob += pulp.lpSum([x[j]] for j in range(m)) + for i in range(n): + prob += pulp.lpSum([A[i][j] * x[j] + for j in range(m)]) >= 1, f"Cover_problem_{i}" + + prob.solve(pulp.PULP_CBC_CMD(msg=0)) + + selected_configs = [perfconfigs[j] + for j in range(m) if x[j].varValue == 1] + + result[data_type] = selected_configs + + return result + + +def combine_data(input_dir, no_splitK): + """ + Combine all *.debug tuning data into a single file. + """ + tsv_files = glob.glob(os.path.join(input_dir, f"*.debug")) + + dfs = [] + for file in tsv_files: + df = pd.read_csv(file, sep='\t', index_col=None) + df = df[df.columns[1:]] + dfs.append(df) + if not dfs: + return None + new_df = pd.concat(dfs, ignore_index=True) + + # Remove splitK from tuning data + if no_splitK: + df_filtered = new_df[new_df['PerfConfig'].str.split(',').str[6] == '1'] + new_df = df_filtered + + return new_df + + +def print_results(result): + """ + Print selected perfconfigs for each data type. + """ + for datatype, perfconfings in result.items(): + print(f"Datatype: {datatype}") + for idx, perfconfig in enumerate(perfconfings): + print(f" {idx + 1}: {perfconfig}") + + +def main(args=None): + """ + usage: quickTunerGen.py [-h] --input-dir INPUT_DIR --op {gemm,conv} [--th TH] --arch ARCH [--update] [--no-splitK] + usage exsample: python3 quickTunerGen.py --input-dir tunedData --op conv --arch gfx90a --update --no-splitK + """ + if args is None: + args = sys.argv[1:] + + parser = argparse.ArgumentParser(prog='quickTunerGen.py') + + parser.add_argument('--input-dir', + required=True, + type=str) + + parser.add_argument('--op', + required=True, + type=str, + choices=["gemm", "conv"]) + + parser.add_argument("--th", + required=False, + type=float, + default=0.93) + + parser.add_argument("--arch", + required=True, + type=str) + + parser.add_argument("--update", + required=False, + default=False, + action='store_true') + + parser.add_argument( + '--no-splitK', + default=False, + action='store_true', + help='Removing the spliK factor from the generated list') + + pargs = parser.parse_args() + + combined_data = combine_data(pargs.input_dir, pargs.no_splitK) + + finder = PerfConfigsFinder(combined_data, pargs) + result = finder.find() + + print_results(result) + + file_writer = FileWriter(pargs) + if pargs.update: + file_writer.update_config_file(result) + + +if __name__ == '__main__': + main(sys.argv[1:])