diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index eb1738e760..d08196924b 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -29,9 +29,9 @@ struct ProblemSize final ck::index_t N = 4096; ck::index_t K = 4096; - ck::index_t StrideA = 0; - ck::index_t StrideB = 0; - ck::index_t StrideC = 0; + ck::index_t StrideA = -1; + ck::index_t StrideB = -1; + ck::index_t StrideC = -1; }; struct ProblemSizeStreamK final @@ -40,9 +40,9 @@ struct ProblemSizeStreamK final ck::index_t N = 4096; ck::index_t K = 4096; - ck::index_t StrideA = 0; - ck::index_t StrideB = 0; - ck::index_t StrideC = 0; + ck::index_t StrideA = -1; + ck::index_t StrideB = -1; + ck::index_t StrideC = -1; ck::index_t NumSKBlocks = -1; }; @@ -52,9 +52,9 @@ struct ProblemSizeStreamK_universal final ck::index_t N = 4096; ck::index_t K = 4096; - ck::index_t StrideA = 0; - ck::index_t StrideB = 0; - ck::index_t StrideC = 0; + ck::index_t StrideA = -1; + ck::index_t StrideB = -1; + ck::index_t StrideC = -1; ck::index_t Grid_size = -1; // defaults to max occupancy ck::index_t Streamk_sel = 1; // defaults to 1-tile SK @@ -66,9 +66,9 @@ struct ProblemSizeSplitK final ck::index_t N = 4096; ck::index_t K = 4096; - ck::index_t StrideA = 0; - ck::index_t StrideB = 0; - ck::index_t StrideC = 0; + ck::index_t StrideA = -1; + ck::index_t StrideB = -1; + ck::index_t StrideC = -1; ck::index_t KBatch = 1; }; diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index f66d2adc11..fe12998e35 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -116,21 +116,21 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) }; auto f_get_default_stride = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if(stride == 0) + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) { - // give a chance if stride is zero, return a default packed stride + // give a chance if stride is -1, return a default packed stride if constexpr(std::is_same_v) { - return col; + return static_cast(col); } else { - return row; + return static_cast(row); } } else - return stride; + return static_cast(stride); }; StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); diff --git a/example/01_gemm/run_gemm_example_streamk_v2.inc b/example/01_gemm/run_gemm_example_streamk_v2.inc index 32bd3a19a6..6679f95157 100644 --- a/example/01_gemm/run_gemm_example_streamk_v2.inc +++ b/example/01_gemm/run_gemm_example_streamk_v2.inc @@ -117,9 +117,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) auto f_get_default_stride = [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { - if(stride == 0) + if(stride == -1) { - // give a chance if stride is 0, return a default packed stride + // give a chance if stride is -1, return a default packed stride if constexpr(std::is_same_v) { return static_cast(col); diff --git a/example/01_gemm/run_gemm_example_v2.inc b/example/01_gemm/run_gemm_example_v2.inc index ad7238f0dd..0bcee658b9 100644 --- a/example/01_gemm/run_gemm_example_v2.inc +++ b/example/01_gemm/run_gemm_example_v2.inc @@ -115,21 +115,21 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) }; auto f_get_default_stride = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if(stride == 0) + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) { - // give a chance if stride is zero, return a default packed stride + // give a chance if stride is -1, return a default packed stride if constexpr(std::is_same_v) { - return col; + return static_cast(col); } else { - return row; + return static_cast(row); } } else - return stride; + return static_cast(stride); }; StrideA = f_get_default_stride(M, K, StrideA, ALayout{});