float16 support for CuTe micro kernel #602

Merged: 8 commits, Mar 31, 2024

Changes from all commits
3 changes: 3 additions & 0 deletions runtime/micro_kernel/matmul/cutlass/gemm_sm80.h
@@ -16,6 +16,9 @@ using namespace cute;
template <typename A_type, typename B_type, typename C_type>
struct DispatchInstruction;

template <> struct DispatchInstruction<half_t, half_t, half_t> {
using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
};
template <> struct DispatchInstruction<double, double, double> {
using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
};
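The new specialization above routes half-precision inputs to the SM80 Tensor Core instruction. For context, here is a minimal sketch of how such a dispatched atom is typically turned into a tiled MMA in CuTe; the 2x2x1 warp arrangement and the example namespace are illustrative assumptions, not code from this PR.

#include <cute/atom/mma_atom.hpp>
#include <cute/atom/mma_traits_sm80.hpp>

namespace example {
using namespace cute;

// Same atom that DispatchInstruction<half_t, half_t, half_t>::MMA selects
// in gemm_sm80.h: the SM80 m16n8k16 FP16 Tensor Core MMA.
using Instr = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;

// A micro kernel would typically spread the atom over the block's warps,
// e.g. a 2 (m) x 2 (n) x 1 (k) arrangement (illustrative only):
using TiledMma =
    decltype(make_tiled_mma(Instr{}, Layout<Shape<_2, _2, _1>>{}));
} // namespace example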
227 changes: 177 additions & 50 deletions src/schedule/lower_cutlass_micro_block.cc
@@ -8,6 +8,7 @@
#include <schedule/var_reorder.h>
#include <schedule/var_split.h>
#include <schedule/var_unsqueeze.h>
#include <type/data_type.h>

namespace freetensor {

@@ -18,6 +19,7 @@
class FixTransposeAndGetPartition : public Mutator {
ID matMulId_;
int64_t nWarpBatch_ = 0, nWarpM_ = 0, nWarpN_ = 0;
DataType dtypeA_, dtypeB_, dtypeC_;

public:
FixTransposeAndGetPartition(const ID &matMulId) : matMulId_(matMulId) {}
@@ -26,6 +28,10 @@
auto nWarpM() const { return nWarpM_; }
auto nWarpN() const { return nWarpN_; }

auto dtypeA() const { return dtypeA_; }
auto dtypeB() const { return dtypeB_; }
auto dtypeC() const { return dtypeC_; }

private:
std::tuple<int, int, int> computeWarpPartition(int64_t batch, int64_t m,
int64_t n, int64_t k,
@@ -79,6 +85,10 @@
if (op->id() == matMulId_) {
ASSERT(op->backend_ == MatMulBackend::CutlassMicroBlock);

dtypeA_ = op->a_->dtype();
dtypeB_ = op->b_->dtype();
dtypeC_ = op->c_->dtype();

// C is only supported in a densely packed row-major layout in
// registers
if (!op->cIsRowMajor_) {
@@ -144,6 +154,8 @@
ID matMulId_;
int64_t nWarpBatch_ = 0, nWarpM_ = 0, nWarpN_ = 0;

DataType dtypeA_, dtypeB_, dtypeC_;

Ref<CutlassMicroKernelProperty> prop_;
bool inMicroKernel_ = false;

@@ -160,24 +172,59 @@
int nDimsCAll = op->indices_.size();
ASSERT(nDimsCAll >=
9); // See comments in `lowerCutlassMicroBlock` below
auto batchInWarpPartition =
makeEQ(op->indices_[nDimsCAll - 9], prop_->warpIdBatch_);
auto mInWarpPartition =
makeEQ(op->indices_[nDimsCAll - 4], prop_->warpIdM_);
auto nInWarpPartition =
makeEQ(op->indices_[nDimsCAll - 5], prop_->warpIdN_);
auto mInThreadPartition =
makeEQ(op->indices_[nDimsCAll - 3],
makeFloorDiv(prop_->laneId_, makeIntConst(4)));
auto nInThreadPartition =
makeEQ(op->indices_[nDimsCAll - 2],
makeMod(prop_->laneId_, makeIntConst(4)));

ret = makeIf(
makeLAnd(makeLAnd(batchInWarpPartition,
makeLAnd(mInWarpPartition, nInWarpPartition)),
makeLAnd(mInThreadPartition, nInThreadPartition)),
ret);
auto dtype = dtypeA_.base();
ASSERT(dtype == dtypeB_.base());
ASSERT(dtype == dtypeC_.base());
switch (dtype) {
case DataType::Float64: {
auto batchInWarpPartition =
makeEQ(op->indices_[nDimsCAll - 9], prop_->warpIdBatch_);
auto mInWarpPartition =
makeEQ(op->indices_[nDimsCAll - 4], prop_->warpIdM_);
auto nInWarpPartition =
makeEQ(op->indices_[nDimsCAll - 5], prop_->warpIdN_);
auto mInThreadPartition =
makeEQ(op->indices_[nDimsCAll - 3],
makeFloorDiv(prop_->laneId_, makeIntConst(4)));
auto nInThreadPartition =
makeEQ(op->indices_[nDimsCAll - 2],
makeMod(prop_->laneId_, makeIntConst(4)));
ret = makeIf(
makeLAnd(
makeLAnd(batchInWarpPartition,
makeLAnd(mInWarpPartition, nInWarpPartition)),
makeLAnd(mInThreadPartition, nInThreadPartition)),
ret);
break;
}

case DataType::Float16: {
auto batchInWarpPartition =
makeEQ(op->indices_[nDimsCAll - 10], prop_->warpIdBatch_);
auto mInWarpPartition =
makeEQ(op->indices_[nDimsCAll - 5], prop_->warpIdM_);
auto nInWarpPartition =
makeEQ(op->indices_[nDimsCAll - 6], prop_->warpIdN_);
auto mInThreadPartition = makeEQ(
op->indices_[nDimsCAll - 3],
makeFloorDiv(prop_->laneId_, makeIntConst(4))); // m threads
auto nInThreadPartition = makeEQ(
op->indices_[nDimsCAll - 1],
makeMod(prop_->laneId_, makeIntConst(4))); // n threads
ret = makeIf(
makeLAnd(
makeLAnd(batchInWarpPartition,
makeLAnd(mInWarpPartition, nInWarpPartition)),
makeLAnd(mInThreadPartition, nInThreadPartition)),
ret);
break;
}

default:
throw InvalidSchedule(FT_MSG
<< "Unsupported data types: only Float16 "
"and Float64 are supported.");
}
}
return ret;
}
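The float16 branch above mirrors the SM80 m16n8k16 accumulator fragment documented in the PTX ISA: a warp is arranged as 8 row groups by 4 column groups, and each lane owns a 2x2 pattern of elements. A small sketch of that ownership rule, written for illustration only (consistent with the PTX documentation, not code from this PR):

struct CoordInTile { int row, col; };

// Sketch: which element of a 16x8 accumulator tile a lane owns, given the
// serial "2-tile" indices. This is why the predicates test laneId / 4
// against the "m threads" dimension and laneId % 4 against "n threads".
CoordInTile ownerCoord(int laneId, int m2Tile, int n2Tile) {
    int mThread = laneId / 4;      // 8 row groups per warp
    int nThread = laneId % 4;      // 4 column groups per warp
    return {m2Tile * 8 + mThread,  // the m 2-tile selects row or row + 8
            nThread * 2 + n2Tile}; // the n 2-tile selects the even/odd column
}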
@@ -191,6 +238,13 @@
throw InvalidSchedule("Micro kernels cannot nest each other");
}

dtypeA_ = _op->a_->dtype();
dtypeB_ = _op->b_->dtype();
dtypeC_ = _op->c_->dtype();
auto dtype = dtypeA_.base();
ASSERT(dtype == dtypeB_.base());
ASSERT(dtype == dtypeC_.base());

// Here we use `threadIdx.x` for threads in a warp, and
// `threadIdx.y` for warps, because putting everything into a single
// `threadIdx.x` will make the expressions too complicated to solve.
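A minimal sketch of the thread binding this comment describes (an assumption for illustration; the actual binding is produced by the lowering passes):

// Lane index lives on threadIdx.x (32 lanes per warp); warp index lives on
// threadIdx.y (nWarpBatch * nWarpM * nWarpN warps), keeping the partition
// predicates simple affine expressions.
__device__ inline int laneId() { return threadIdx.x; }
__device__ inline int warpId() { return threadIdx.y; }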
@@ -222,14 +276,32 @@
int nDimsCAll = c->indices_.size();
ASSERT(nDimsCAll >=
9); // See comments in `lowerCutlassMicroBlock` below
c->indices_[nDimsCAll - 9] = warpIdBatch;
c->indices_[nDimsCAll - 4] = warpIdM; // m warps
c->indices_[nDimsCAll - 3] =
makeFloorDiv(laneId, makeIntConst(4)); // m threads
c->indices_[nDimsCAll - 5] = warpIdN; // n warps
c->indices_[nDimsCAll - 2] =
makeMod(laneId, makeIntConst(4)); // n threads

switch (dtype) {
case DataType::Float64:
c->indices_[nDimsCAll - 9] = warpIdBatch;
c->indices_[nDimsCAll - 4] = warpIdM; // m warps
c->indices_[nDimsCAll - 3] =
makeFloorDiv(laneId, makeIntConst(4)); // m threads
c->indices_[nDimsCAll - 5] = warpIdN; // n warps
c->indices_[nDimsCAll - 2] =
makeMod(laneId, makeIntConst(4)); // n threads
break;

case DataType::Float16:
c->indices_[nDimsCAll - 10] = warpIdBatch;
c->indices_[nDimsCAll - 5] = warpIdM; // m warps
c->indices_[nDimsCAll - 3] =
makeFloorDiv(laneId, makeIntConst(4)); // m threads
c->indices_[nDimsCAll - 6] = warpIdN; // n warps
c->indices_[nDimsCAll - 1] =
makeMod(laneId, makeIntConst(4)); // n threads
break;

default:
throw InvalidSchedule(FT_MSG
<< "Unsupported data types: only Float16 "
"and Float64 are supported.");
}
op->backend_ = MatMulBackend::CutlassMicroThread;
op->cutlassMicroKernelProperty_ = prop_;

@@ -274,10 +346,18 @@
auto nWarpBatch = fixTransposeAndGetPartition.nWarpBatch();
auto nWarpM = fixTransposeAndGetPartition.nWarpM();
auto nWarpN = fixTransposeAndGetPartition.nWarpN();
auto dtypeA = fixTransposeAndGetPartition.dtypeA();
auto dtypeB = fixTransposeAndGetPartition.dtypeB();
auto dtypeC = fixTransposeAndGetPartition.dtypeC();
auto dtype = dtypeA.base();
ASSERT(dtype == dtypeB.base());
ASSERT(dtype == dtypeC.base());

// Partition C to each thread by layout-manipulation schedules. We have
// checked we are in TryVarReorder mode in schedule/as_matmul.cc. The
// resulting layout will be [
// resulting layout will be:
//
// float64: [
// ...: other leading dims,
// -9: batch warps,
// -8: batch serial,
@@ -290,6 +370,20 @@
// -1: n 2-tiles,
// ]
//
// float16: [
// ...: other leading dims,
// -10: batch warps,
// -9: batch serial,
// -8: n 16-tiles,
// -7: m 32-tiles,
// -6: n warps
// -5: m warps,
// -4: m 2-tiles,
// -3: m threads,
// -2: n 2-tiles,
// -1: n threads,
// ]
//
// See
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
// for 8x8 partition inside warps
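To make the float16 layout concrete, the sketch below recombines the ten trailing indices into the original (batch, m, n) coordinate of C. It is inferred from the comment above and the PTX fragment layout, assuming nWarpM == nWarpN == 2 so that each warp owns a 16 x 8 accumulator tile per MMA; it is an illustration, not code from this PR.

struct Coord3 { int batch, m, n; };

// Sketch: parameters follow the trailing dimensions listed above, from -10
// (batch warps) to -1 (n threads). batchSerialLen is the extent of the
// "batch serial" dimension (an assumed parameter).
Coord3 recombineFloat16(int batchWarp, int batchSerial, int n16Tile,
                        int m32Tile, int nWarp, int mWarp, int m2Tile,
                        int mThread, int n2Tile, int nThread,
                        int batchSerialLen) {
    int batch = batchWarp * batchSerialLen + batchSerial;
    int m = m32Tile * 32 + mWarp * 16 + m2Tile * 8 + mThread; // 8 m threads
    int n = n16Tile * 16 + nWarp * 8 + nThread * 2 + n2Tile;  // 4 n threads
    return {batch, m, n};
}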
@@ -337,31 +431,64 @@
ast = varUnsqueeze(ast, defIdC, nDimsCOthers);
}

// clang-format off
ast = varSplit(
ast, defIdC, nDimsCOthers + 0, VarSplitMode::FixedSize, -1, nWarpBatch);
ast = varSplit(
ast, defIdC, nDimsCOthers + 2, VarSplitMode::FixedSize, 16, -1);
ast = varSplit(
ast, defIdC, nDimsCOthers + 3, VarSplitMode::FixedSize, -1, nWarpM);
ast = varSplit(
ast, defIdC, nDimsCOthers + 5, VarSplitMode::FixedSize, 16, -1);
ast = varSplit(
ast, defIdC, nDimsCOthers + 6, VarSplitMode::FixedSize, -1, nWarpN);
ast = varSplit(
ast, defIdC, nDimsCOthers + 7, VarSplitMode::FixedSize, 2, -1);
// vector for reordering
std::vector<int> vec;
for(int i=0; i<=nDimsCOthers+1; i++)
for (int i = 0; i <= nDimsCOthers + 1; i++)
vec.push_back(i);
vec.push_back(nDimsCOthers+5);
vec.push_back(nDimsCOthers+2);
vec.push_back(nDimsCOthers+6);
vec.push_back(nDimsCOthers+3);
vec.push_back(nDimsCOthers+4);
vec.push_back(nDimsCOthers+7);
vec.push_back(nDimsCOthers+8);
ast = varReorderImpl(ast, defIdC, vec, true);
// clang-format on
switch (dtype) {
case DataType::Float64:
ast = varSplit(ast, defIdC, nDimsCOthers + 0, VarSplitMode::FixedSize,
-1, nWarpBatch);
ast = varSplit(ast, defIdC, nDimsCOthers + 2, VarSplitMode::FixedSize,
16, -1);
ast = varSplit(ast, defIdC, nDimsCOthers + 3, VarSplitMode::FixedSize,
-1, nWarpM);
ast = varSplit(ast, defIdC, nDimsCOthers + 5, VarSplitMode::FixedSize,
16, -1);
ast = varSplit(ast, defIdC, nDimsCOthers + 6, VarSplitMode::FixedSize,
-1, nWarpN);
ast = varSplit(ast, defIdC, nDimsCOthers + 7, VarSplitMode::FixedSize,
2, -1);
vec.push_back(nDimsCOthers + 5);
vec.push_back(nDimsCOthers + 2);
vec.push_back(nDimsCOthers + 6);
vec.push_back(nDimsCOthers + 3);
vec.push_back(nDimsCOthers + 4);
vec.push_back(nDimsCOthers + 7);
vec.push_back(nDimsCOthers + 8);
ast = varReorderImpl(ast, defIdC, vec, true);
break;

case DataType::Float16:
ast = varSplit(ast, defIdC, nDimsCOthers + 0, VarSplitMode::FixedSize,
-1, nWarpBatch);
ast = varSplit(ast, defIdC, nDimsCOthers + 2, VarSplitMode::FixedSize,
32, -1);
ast = varSplit(ast, defIdC, nDimsCOthers + 3, VarSplitMode::FixedSize,
-1, nWarpM);
ast = varSplit(ast, defIdC, nDimsCOthers + 4, VarSplitMode::FixedSize,
-1, 2);
ast = varSplit(ast, defIdC, nDimsCOthers + 6, VarSplitMode::FixedSize,
16, -1);
ast = varSplit(ast, defIdC, nDimsCOthers + 7, VarSplitMode::FixedSize,
-1, nWarpN);
ast = varSplit(ast, defIdC, nDimsCOthers + 8, VarSplitMode::FixedSize,
2, -1);
vec.push_back(nDimsCOthers + 6);
vec.push_back(nDimsCOthers + 2);
vec.push_back(nDimsCOthers + 7);
vec.push_back(nDimsCOthers + 3);
vec.push_back(nDimsCOthers + 4);
vec.push_back(nDimsCOthers + 5);
vec.push_back(nDimsCOthers + 9);
vec.push_back(nDimsCOthers + 8);
ast = varReorderImpl(ast, defIdC, vec, true);
break;

default:
throw InvalidSchedule(FT_MSG << "Unsupported data types: only Float16 "
"and Float64 are supported.");
}

// Lower to CutlassMicroThread
LowerCutlassMicroBlock lowerCutlassMicroBlock{matMulId, nWarpBatch, nWarpM,
27 changes: 18 additions & 9 deletions test/70.program/test_program_with_micro_kernel.py
@@ -5,7 +5,8 @@

@pytest.mark.skipif(not ft.with_pytorch() or not ft.with_cuda(),
reason="requires PyTorch and CUDA")
def test_matmul_float64():
@pytest.mark.parametrize('dtype', ['float16', 'float64'])
def test_matmul(dtype):

M = N = K = 5000
block_n = block_m = 128
@@ -17,18 +18,18 @@ def test_matmul_float64():
with target:

@ft.transform
def matmul(a: ft.Var[(M, K), "float64"], b: ft.Var[(K, N), "float64"]):
c = ft.empty((M, N), "float64")
def matmul(a: ft.Var[(M, K), dtype], b: ft.Var[(K, N), dtype]):
c = ft.empty((M, N), dtype)
#! label: blk_m
for i in range(0, M, block_m):
#! label: blk_n
for j in range(0, N, block_n):
#! label: aa
aa = ft.empty((block_m, block_k), "float64")
aa = ft.empty((block_m, block_k), dtype)
#! label: bb
bb = ft.empty((block_k, block_n), "float64")
bb = ft.empty((block_k, block_n), dtype)
#! label: cc
cc = ft.empty((block_m, block_n), "float64")
cc = ft.empty((block_m, block_n), dtype)
#! label: zero_cc
for ii in range(block_m):
for jj in range(block_n):
@@ -86,11 +87,19 @@ def matmul(a: ft.Var[(M, K), "float64"], b: ft.Var[(K, N), "float64"]):

import torch

a_torch = torch.rand(M, K, dtype=torch.float64).cuda()
b_torch = torch.rand(K, N, dtype=torch.float64).cuda()
dtype_to_torch = {
'float16': torch.float16,
'float64': torch.float64,
}
a_torch = torch.rand(M, K, dtype=dtype_to_torch[dtype]).cuda()
b_torch = torch.rand(K, N, dtype=dtype_to_torch[dtype]).cuda()
y_std = a_torch @ b_torch
a_arr = ft.array(a_torch)
b_arr = ft.array(b_torch)
y_arr = exe(a_arr, b_arr)
y_torch = y_arr.torch()
assert torch.all(torch.isclose(y_torch, y_std))

if dtype == 'float16':
assert torch.all(torch.isclose(y_torch, y_std, rtol=2e-2))
else:
assert torch.all(torch.isclose(y_torch, y_std))