fp16 support for CuTe micro kernel
cerf-volantWang committed Mar 31, 2024
1 parent e50bd9b commit a897c8e
Showing 3 changed files with 3 additions and 68 deletions.
36 changes: 2 additions & 34 deletions runtime/micro_kernel/matmul/cutlass/gemm_sm80.h
@@ -16,51 +16,19 @@ using namespace cute;
template <typename A_type, typename B_type, typename C_type>
struct DispatchInstruction;

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
template <> struct DispatchInstruction<half_t, half_t, half_t> {
using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
};
template <> struct DispatchInstruction<double, double, double> {
using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
};
#endif

template <int Bits, int N, int K, bool K_inner, typename Enable = void>
struct OperandTraits {
// Primary template, use padded layout and default copy
static constexpr int stride = K_inner ? K : N;
static constexpr int padded =
stride % (256 / Bits) == 0 ? stride + 128 / Bits : stride;
using Layout = typename std::conditional<
K_inner, Layout<Shape<Int<N>, Int<K>>, Shape<Int<padded>, _1>>,
Layout<Shape<Int<N>, Int<K>>, Shape<_1, Int<padded>>>>::type;
using Copy = DefaultCopy;
};
template <int N, int K>
struct OperandTraits<16, N, K, true,
typename std::enable_if<K % 16 == 0>::type> {
using Layout = Layout<Shape<Int<N>, Int<K>>, Stride<Int<K>, _1>>;
using Copy = DefaultCopy;
};

template <int N, int K>
struct OperandTraits<16, N, K, false,
typename std::enable_if<N % 16 == 0>::type> {
using Layout = Layout<Shape<Int<N>, Int<K>>, Stride<_1, Int<N>>>;
using Copy = DefaultCopy;
};

template <int N, int K>
struct OperandTraits<64, N, K, true,
typename std::enable_if<K % 16 == 0>::type> {
using Layout = Layout<Shape<Int<N>, Int<K>>, Stride<Int<K>, _1>>;
using Copy = DefaultCopy;
};

template <int N, int K>
struct OperandTraits<64, N, K, false,
typename std::enable_if<N % 16 == 0>::type> {
using Layout = Layout<Shape<Int<N>, Int<K>>, Stride<_1, Int<N>>>;
K_inner, Layout<Shape<Int<N>, Int<K>>, Shape<Int<K>, _1>>,
Layout<Shape<Int<N>, Int<K>>, Shape<_1, Int<N>>>>::type;
using Copy = DefaultCopy;
};

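Note on the gemm_sm80.h hunk above: with the per-dtype OperandTraits specializations deleted, the primary template now yields a plain row-major layout when K is the inner dimension and a plain column-major layout otherwise. Below is a minimal host-side sketch, not part of the commit, assuming CUTLASS's header-only CuTe library is on the include path; the 16x32 tile shape is made up for illustration.

```cpp
// Sketch only: the two unpadded layouts the simplified OperandTraits
// primary template produces. Requires CUTLASS (CuTe headers).
#include <cute/tensor.hpp>
using namespace cute;

int main() {
  // K-inner operand (row-major): an N=16 x K=32 tile with stride (K, 1)
  auto kInner = Layout<Shape<Int<16>, Int<32>>, Stride<Int<32>, _1>>{};
  // N-inner operand (column-major): the same tile with stride (1, N)
  auto nInner = Layout<Shape<Int<16>, Int<32>>, Stride<_1, Int<16>>>{};
  print(kInner); print("\n"); // prints (_16,_32):(_32,_1)
  print(nInner); print("\n"); // prints (_16,_32):(_1,_16)
  return 0;
}
```

CuTe's Shape and Stride are aliases of the same tuple template, so the hunk's `Shape<Int<K>, _1>` in stride position is equivalent to the conventional `Stride<Int<K>, _1>` spelling used here.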
33 changes: 0 additions & 33 deletions src/schedule/lower_cutlass_micro_block.cc
@@ -165,7 +165,6 @@ class LowerCutlassMicroBlock : public SymbolTable<Mutator> {
int nDimsCAll = op->indices_.size();
ASSERT(nDimsCAll >=
9); // See comments in `lowerCutlassMicroBlock` below
<<<<<<< HEAD
switch (DType) {
case BaseDataType::Float64: {
auto batchInWarpPartition =
@@ -216,26 +215,6 @@ class LowerCutlassMicroBlock : public SymbolTable<Mutator> {
break;
}
}
=======
auto batchInWarpPartition =
makeEQ(op->indices_[nDimsCAll - 9], prop_->warpIdBatch_);
auto mInWarpPartition =
makeEQ(op->indices_[nDimsCAll - 4], prop_->warpIdM_);
auto nInWarpPartition =
makeEQ(op->indices_[nDimsCAll - 5], prop_->warpIdN_);
auto mInThreadPartition =
makeEQ(op->indices_[nDimsCAll - 3],
makeFloorDiv(prop_->laneId_, makeIntConst(4)));
auto nInThreadPartition =
makeEQ(op->indices_[nDimsCAll - 2],
makeMod(prop_->laneId_, makeIntConst(4)));

ret = makeIf(
makeLAnd(makeLAnd(batchInWarpPartition,
makeLAnd(mInWarpPartition, nInWarpPartition)),
makeLAnd(mInThreadPartition, nInThreadPartition)),
ret);
>>>>>>> master
}
return ret;
}
@@ -280,17 +259,6 @@ class LowerCutlassMicroBlock : public SymbolTable<Mutator> {
int nDimsCAll = c->indices_.size();
ASSERT(nDimsCAll >=
9); // See comments in `lowerCutlassMicroBlock` below
<<<<<<< HEAD
=======
c->indices_[nDimsCAll - 9] = warpIdBatch;
c->indices_[nDimsCAll - 4] = warpIdM; // m warps
c->indices_[nDimsCAll - 3] =
makeFloorDiv(laneId, makeIntConst(4)); // m threads
c->indices_[nDimsCAll - 5] = warpIdN; // n warps
c->indices_[nDimsCAll - 2] =
makeMod(laneId, makeIntConst(4)); // n threads
>>>>>>> master

switch (DType) {
case BaseDataType::Float64: {
c->indices_[nDimsCAll - 9] = warpIdBatch;
@@ -504,7 +472,6 @@ Stmt lowerCutlassMicroBlock(const Stmt &_ast, const ID &matMulId,
}
}


// Lower to CutlassMicroThread
LowerCutlassMicroBlock lowerCutlassMicroBlock{matMulId, nWarpBatch, nWarpM,
nWarpN};
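Note on the lower_cutlass_micro_block.cc hunks above (which resolve leftover merge-conflict markers in favor of the dtype switch): the `makeFloorDiv(laneId_, 4)` / `makeMod(laneId_, 4)` guards encode the per-lane ownership of the SM80 m16n8 accumulator fragment. A standalone sketch, plain C++ and not from the repository, of the same arithmetic:

```cpp
// Sketch only: lane L of a 32-thread warp covers the accumulator slot whose
// m-thread coordinate is L / 4 and whose n-thread coordinate is L % 4,
// matching makeFloorDiv(laneId_, makeIntConst(4)) and
// makeMod(laneId_, makeIntConst(4)) in the pass.
#include <cstdio>

int main() {
  for (int lane = 0; lane < 32; ++lane) {
    int mThread = lane / 4; // 8 groups of 4 lanes along m
    int nThread = lane % 4; // 4 lanes within a group along n
    std::printf("lane %2d -> (m=%d, n=%d)\n", lane, mThread, nThread);
  }
  return 0;
}
```

The fp64 path uses the SM80_8x8x4 atom, whose fragment layout differs, which is why the pass switches on the dtype instead of applying one mapping unconditionally.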
2 changes: 1 addition & 1 deletion test/70.program/test_program_with_micro_kernel.py
@@ -93,4 +93,4 @@ def matmul(a: ft.Var[(M, K), "float16"], b: ft.Var[(K, N), "float16"]):
b_arr = ft.array(b_torch)
y_arr = exe(a_arr, b_arr)
y_torch = y_arr.torch()
assert torch.all(torch.isclose(y_torch, y_std, rtol = 2e-2))
assert torch.all(torch.isclose(y_torch, y_std, rtol=2e-2))
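The test change is whitespace only (`rtol = 2e-2` becomes the PEP 8-style `rtol=2e-2`). As a side note on the tolerance itself, here is a standalone sketch, plain C++ with hypothetical values, of the relative-error check the assertion performs; the formula mirrors torch.isclose's documented `|a - b| <= atol + rtol * |b|`:

```cpp
// Sketch only: fp16 carries ~11 significand bits, so a long GEMM reduction
// can plausibly accumulate on the order of 1% relative error, which the
// 2e-2 rtol used by the test absorbs.
#include <cmath>
#include <cstdio>

// Mirrors torch.isclose: |a - ref| <= atol + rtol * |ref|
bool isClose(float a, float ref, float rtol = 2e-2f, float atol = 1e-8f) {
  return std::fabs(a - ref) <= atol + rtol * std::fabs(ref);
}

int main() {
  std::printf("%d\n", (int)isClose(100.8f, 100.0f)); // 0.8% off -> 1 (close)
  std::printf("%d\n", (int)isClose(103.0f, 100.0f)); // 3% off   -> 0 (not close)
  return 0;
}
```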
