Skip to content

Commit

Permalink
Update default stride (#1576)
Browse files Browse the repository at this point in the history
* Update default stride value to -1

* Fix format

* Revert "Fix format"

This reverts commit ae0c364.

---------

Co-authored-by: Harisankar Sadasivan <[email protected]>
  • Loading branch information
geyyer and hsadasiv authored Oct 21, 2024
1 parent 794f2d6 commit 3f71093
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 26 deletions.
24 changes: 12 additions & 12 deletions example/01_gemm/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ struct ProblemSize final
ck::index_t N = 4096;
ck::index_t K = 4096;

ck::index_t StrideA = 0;
ck::index_t StrideB = 0;
ck::index_t StrideC = 0;
ck::index_t StrideA = -1;
ck::index_t StrideB = -1;
ck::index_t StrideC = -1;
};

struct ProblemSizeStreamK final
Expand All @@ -40,9 +40,9 @@ struct ProblemSizeStreamK final
ck::index_t N = 4096;
ck::index_t K = 4096;

ck::index_t StrideA = 0;
ck::index_t StrideB = 0;
ck::index_t StrideC = 0;
ck::index_t StrideA = -1;
ck::index_t StrideB = -1;
ck::index_t StrideC = -1;

ck::index_t NumSKBlocks = -1;
};
Expand All @@ -52,9 +52,9 @@ struct ProblemSizeStreamK_universal final
ck::index_t N = 4096;
ck::index_t K = 4096;

ck::index_t StrideA = 0;
ck::index_t StrideB = 0;
ck::index_t StrideC = 0;
ck::index_t StrideA = -1;
ck::index_t StrideB = -1;
ck::index_t StrideC = -1;

ck::index_t Grid_size = -1; // defaults to max occupancy
ck::index_t Streamk_sel = 1; // defaults to 1-tile SK
Expand All @@ -66,9 +66,9 @@ struct ProblemSizeSplitK final
ck::index_t N = 4096;
ck::index_t K = 4096;

ck::index_t StrideA = 0;
ck::index_t StrideB = 0;
ck::index_t StrideC = 0;
ck::index_t StrideA = -1;
ck::index_t StrideB = -1;
ck::index_t StrideC = -1;

ck::index_t KBatch = 1;
};
Expand Down
12 changes: 6 additions & 6 deletions example/01_gemm/run_gemm_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -116,21 +116,21 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
};

auto f_get_default_stride =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(stride == 0)
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == -1)
{
// give a chance if stride is zero, return a default packed stride
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return col;
return static_cast<std::size_t>(col);
}
else
{
return row;
return static_cast<std::size_t>(row);
}
}
else
return stride;
return static_cast<std::size_t>(stride);
};

StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
Expand Down
4 changes: 2 additions & 2 deletions example/01_gemm/run_gemm_example_streamk_v2.inc
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)

auto f_get_default_stride =
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == 0)
if(stride == -1)
{
// give a chance if stride is 0, return a default packed stride
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return static_cast<std::size_t>(col);
Expand Down
12 changes: 6 additions & 6 deletions example/01_gemm/run_gemm_example_v2.inc
Original file line number Diff line number Diff line change
Expand Up @@ -115,21 +115,21 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
};

auto f_get_default_stride =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(stride == 0)
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == -1)
{
// give a chance if stride is zero, return a default packed stride
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return col;
return static_cast<std::size_t>(col);
}
else
{
return row;
return static_cast<std::size_t>(row);
}
}
else
return stride;
return static_cast<std::size_t>(stride);
};

StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
Expand Down

0 comments on commit 3f71093

Please sign in to comment.