Invoke micro kernels from CUTLASS (#596)
* [WIP] Invoke micro kernels from CUTLASS

* Fix: accum should be in row-major

* Multiple fixes and workarounds

Fixes:
- Check for `beta=0` or `beta=1` in the micro kernel (see the sketch after this message).
- Fix the testing program.

Workarounds:
- Temporarily add `__syncthreads()` in micro kernels.
- Temporarily set data types to float64 to work around CUTLASS not
  supporting the RowMajor layout for float16.

* Fix __syncthreads()

* Adjust test case name

* Fix compiler warnings

* Include micro kernels as needed

* Fix a memory bug in pass/simplify
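For reference, a minimal sketch of the `beta` special-casing checked by the first fix above, assuming a hypothetical epilogue loop; the name `store_accum` and the loop shape are illustrative, not the actual kernel's:

template <typename C_type>
__device__ void store_accum(C_type *c, const C_type *accum, int n,
                            double alpha, double beta) {
    // Hypothetical epilogue: special-case beta to avoid reading C at all
    // when beta == 0, and to skip the multiply when beta == 1.
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        if (beta == 0) {
            c[i] = alpha * accum[i];
        } else if (beta == 1) {
            c[i] = alpha * accum[i] + c[i];
        } else {
            c[i] = alpha * accum[i] + beta * c[i];
        }
    }
}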
roastduck authored Mar 15, 2024
1 parent 87bfe21 commit 50909cc
Showing 24 changed files with 1,001 additions and 63 deletions.
2 changes: 2 additions & 0 deletions grammar/pb_parser.g
@@ -103,13 +103,15 @@ expr returns [Expr node]
{int ty;} (
'*' {ty = 1;}
| ('%' | MOD) {ty = 2;}
| '/' {ty = 3;} // Exact integer division. We currently use FloorDiv for it.
)
expr1=expr
{
switch (ty)
{
case 1: $node = makeMul($expr0.node, $expr1.node); break;
case 2: $node = makeMod($expr0.node, $expr1.node); break;
case 3: $node = makeFloorDiv($expr0.node, $expr1.node); break;
}
}
| '-' expr0=expr
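The `/` rule above maps exact integer division onto floor division, which agrees with true division only when the divisor exactly divides the dividend. A quick self-contained illustration (not part of the commit):

#include <cassert>

// Floor division, as makeFloorDiv computes it; C++ '/' truncates toward zero.
int floorDiv(int a, int b) {
    int q = a / b;
    return (a % b != 0 && ((a < 0) != (b < 0))) ? q - 1 : q;
}

int main() {
    assert(floorDiv(12, 4) == 3);    // exact division: matches '/'
    assert(floorDiv(-12, 4) == -3);  // exact, negative: still matches
    assert(floorDiv(-7, 2) == -4);   // non-exact: differs from truncation
}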
12 changes: 11 additions & 1 deletion include/analyze/comp_transient_bounds.h
@@ -71,7 +71,17 @@ class CompTransientBounds : public BaseClass,
auto dnf = asDNF(_cond);

if (dnf.size() != 1) {
return; // Currently we cannot handle OR
// Currently `transients_` cannot handle OR (e.g. `x < 5 || y < 3`
// yields a two-disjunct DNF), so leave the whole condition as-is in
// `conds_`. But drop it if any sub-condition is marked as `Unbound`
for (auto &&item : dnf) {
for (auto &&sub : item) {
if (sub->nodeType() == ASTNodeType::Unbound) {
return;
}
}
}
conds_.emplace_back(_cond);
return;
}

for (auto &&cond : dnf.front()) {
3 changes: 3 additions & 0 deletions include/codegen/code_gen_cuda.h
@@ -30,6 +30,7 @@ class CodeGenCUDA : public CodeGenC<CodeGenCUDAStream> {
Expr globalSize_ = makeIntConst(0);
std::unordered_set<Stmt> streamScopes_;
bool inMatmul_ = false;
std::vector<std::string> neededMicroKernels_;

public:
CodeGenCUDA(const std::vector<FuncParam> &params,
@@ -49,6 +50,8 @@ class CodeGenCUDA : public CodeGenC<CodeGenCUDAStream> {

std::string gen(const DataType &dtype) override;

const auto &neededMicroKernels() const { return neededMicroKernels_; }

private:
bool inKernel() const;

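The new `neededMicroKernels_` member backs the "Include micro kernels as needed" change: codegen records which micro kernel headers the emitted code uses, and the driver then includes exactly those. A self-contained sketch of that pattern, with all names hypothetical:

#include <string>
#include <vector>

std::vector<std::string> neededMicroKernels;  // mirrors neededMicroKernels_

// Record a needed micro kernel header once (deduplicated).
void requireMicroKernel(const std::string &header) {
    for (const auto &h : neededMicroKernels)
        if (h == header)
            return;
    neededMicroKernels.push_back(header);
}

// Emit one #include per recorded kernel at the top of the generated file.
std::string emitIncludes() {
    std::string out;
    for (const auto &h : neededMicroKernels)
        out += "#include \"micro_kernel/" + h + "\"\n";
    return out;
}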
36 changes: 36 additions & 0 deletions include/cutlass_micro_kernel_property.h
@@ -0,0 +1,36 @@
#ifndef CUTLASS_MICRO_KERNEL_PROPERTY_H
#define CUTLASS_MICRO_KERNEL_PROPERTY_H

#include <expr.h>
#include <sub_tree.h>

namespace freetensor {

struct CutlassMicroKernelProperty : public ASTPart {
int nWarpBatch_, nWarpM_, nWarpN_;
Expr warpIdBatch_, warpIdM_, warpIdN_, laneId_;

template <typename TwarpIdBatch, typename TwarpIdM, typename TwarpIdN,
typename TlaneId>
CutlassMicroKernelProperty(int nWarpBatch, int nWarpM, int nWarpN,
TwarpIdBatch &&warpIdBatch, TwarpIdM &&warpIdM,
TwarpIdN &&warpIdN, TlaneId &&laneId)
: nWarpBatch_(nWarpBatch), nWarpM_(nWarpM), nWarpN_(nWarpN),
warpIdBatch_(std::forward<TwarpIdBatch>(warpIdBatch)),
warpIdM_(std::forward<TwarpIdM>(warpIdM)),
warpIdN_(std::forward<TwarpIdN>(warpIdN)),
laneId_(std::forward<TlaneId>(laneId)) {}

void compHash() override;
};

inline Ref<CutlassMicroKernelProperty>
deepCopy(const Ref<CutlassMicroKernelProperty> &p) {
return Ref<CutlassMicroKernelProperty>::make(
p->nWarpBatch_, p->nWarpM_, p->nWarpN_, deepCopy(p->warpIdBatch_),
deepCopy(p->warpIdM_), deepCopy(p->warpIdN_), deepCopy(p->laneId_));
}

} // namespace freetensor

#endif // CUTLASS_MICRO_KERNEL_PROPERTY_H
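A hedged construction example for the new property; `Ref<...>::make` and `makeIntConst` are FreeTensor utilities used elsewhere in this commit, while the constant warp/lane IDs are placeholders (real code derives these expressions from thread indices):

#include <cutlass_micro_kernel_property.h>
#include <expr.h>

using namespace freetensor;

// Illustrative only: a 2x2 warp grid over M x N, no batch-level warp
// parallelism.
auto prop = Ref<CutlassMicroKernelProperty>::make(
    /* nWarpBatch = */ 1, /* nWarpM = */ 2, /* nWarpN = */ 2,
    /* warpIdBatch = */ makeIntConst(0),  // placeholder
    /* warpIdM = */ makeIntConst(0),      // placeholder
    /* warpIdN = */ makeIntConst(0),      // placeholder
    /* laneId = */ makeIntConst(0));      // placeholder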
3 changes: 3 additions & 0 deletions include/hash.h
@@ -26,6 +26,7 @@ class Hasher {
static size_t compHash(const Buffer &b);
static size_t compHash(const ReductionItem &r);
static size_t compHash(const ForProperty &p);
static size_t compHash(const CutlassMicroKernelProperty &p);

// stmt
static size_t compHash(const AnyNode &op);
@@ -104,6 +105,8 @@ class HashComparator {
const Ref<ReductionItem> &rhs) const;
bool operator()(const Ref<ForProperty> &lhs,
const Ref<ForProperty> &rhs) const;
bool operator()(const Ref<CutlassMicroKernelProperty> &lhs,
const Ref<CutlassMicroKernelProperty> &rhs) const;
bool operator()(const AST &lhs, const AST &rhs) const;
};

14 changes: 13 additions & 1 deletion include/mutator.h
@@ -338,8 +338,20 @@
}

virtual Stmt visit(const MatMul &op) {
Ref<CutlassMicroKernelProperty> cutlassMicroKernelProperty = nullptr;
if (op->cutlassMicroKernelProperty_.isValid()) {
cutlassMicroKernelProperty = Ref<CutlassMicroKernelProperty>::make(
op->cutlassMicroKernelProperty_->nWarpBatch_,
op->cutlassMicroKernelProperty_->nWarpM_,
op->cutlassMicroKernelProperty_->nWarpN_,
(*this)(op->cutlassMicroKernelProperty_->warpIdBatch_),
(*this)(op->cutlassMicroKernelProperty_->warpIdM_),
(*this)(op->cutlassMicroKernelProperty_->warpIdN_),
(*this)(op->cutlassMicroKernelProperty_->laneId_));
}
return makeMatMul(
op->backend_, (*this)(op->a_), (*this)(op->b_), (*this)(op->c_),
op->backend_, std::move(cutlassMicroKernelProperty),
(*this)(op->a_), (*this)(op->b_), (*this)(op->c_),
(*this)(op->alpha_), (*this)(op->beta_), (*this)(op->m_),
(*this)(op->k_), (*this)(op->n_), (*this)(op->lda_),
(*this)(op->ldb_), (*this)(op->ldc_), (*this)(op->stridea_),
3 changes: 2 additions & 1 deletion include/pass/gpu/make_sync.h
@@ -78,7 +78,7 @@ class MakeSync : public Mutator {
const std::unordered_map<ID, ThreadInfo> &loop2thread_;
std::vector<CrossThreadDep> deps_;
std::unordered_map<ID, std::pair<Stmt, bool /* isSyncWarp */>>
syncBeforeFor_, syncBeforeIf_;
syncBeforeFor_, syncBeforeIf_, syncBeforeLib_;
std::unordered_map<ID, std::vector<Stmt>> branchSplittersThen_,
branchSplittersElse_;
LoopVariExprMap variantExprs_;
@@ -188,6 +188,7 @@
Stmt visitStmt(const Stmt &op) override;
Stmt visit(const For &op) override;
Stmt visit(const If &op) override;
Stmt visit(const MatMul &op) override;
};

Stmt makeSync(const Stmt &op, const Ref<GPUTarget> &target);
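`syncBeforeLib_` extends the existing `syncBeforeFor_`/`syncBeforeIf_` bookkeeping to library calls such as `MatMul`: a barrier recorded against a statement ID is emitted just before that statement. A generic, self-contained sketch of the pattern (names hypothetical, not the pass's actual code):

#include <string>
#include <unordered_map>
#include <vector>

struct SyncPlanner {
    // stmt ID -> isSyncWarp (true: a warp-level barrier suffices)
    std::unordered_map<int, bool> syncBefore;

    void require(int stmtId, bool warpOnly) {
        auto [it, inserted] = syncBefore.try_emplace(stmtId, warpOnly);
        if (!inserted && !warpOnly)
            it->second = false;  // block-wide barrier subsumes warp-wide
    }

    std::vector<std::string> emit(int stmtId, std::string stmt) {
        std::vector<std::string> out;
        if (auto it = syncBefore.find(stmtId); it != syncBefore.end())
            out.push_back(it->second ? "__syncwarp();" : "__syncthreads();");
        out.push_back(std::move(stmt));
        return out;
    }
};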
23 changes: 21 additions & 2 deletions include/schedule/as_matmul.h
@@ -72,13 +72,32 @@ class AsMatMul : public SymbolTable<Mutator> {

AnalyzeLinear analyzeLinear_;

bool done_ = false;
ID resultId_;

// Public matching details
std::vector<bool> dimsABatch_, dimsBBatch_, dimsCBatch_, dimsAM_, dimsAK_,
dimsBK_, dimsBN_, dimsCM_, dimsCN_;
ID defIdA_, defIdB_, defIdC_;

public:
AsMatMul(const ID &loop, MatMulBackend backend)
: loop_(loop), backend_(backend) {}

bool done() const { return done_; }
bool done() const { return resultId_.isValid(); }
const ID &resultId() const { return resultId_; }

const auto &dimsABatch() const { return dimsABatch_; }
const auto &dimsBBatch() const { return dimsBBatch_; }
const auto &dimsCBatch() const { return dimsCBatch_; }
const auto &dimsAM() const { return dimsAM_; }
const auto &dimsAK() const { return dimsAK_; }
const auto &dimsBK() const { return dimsBK_; }
const auto &dimsBN() const { return dimsBN_; }
const auto &dimsCM() const { return dimsCM_; }
const auto &dimsCN() const { return dimsCN_; }
const auto &defIdA() const { return defIdA_; }
const auto &defIdB() const { return defIdB_; }
const auto &defIdC() const { return defIdC_; }

private:
const LinearExpr<int64_t> &analyzeLinear(const Expr &expr);
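A hedged sketch of how a schedule might consume the newly exposed matching details; `AsMatMul`'s interface is as declared above, while the driver function itself is illustrative:

#include <schedule/as_matmul.h>

using namespace freetensor;

// Illustrative driver: run the matcher over a loop, then read the details.
Stmt matchAsCutlassMicroBlock(const Stmt &astIn, const ID &loopId) {
    AsMatMul matcher(loopId, MatMulBackend::CutlassMicroBlock);
    Stmt ast = matcher(astIn);
    if (matcher.done()) {
        ID matMulId = matcher.resultId();  // the generated MatMul node
        ID defIdC = matcher.defIdC();      // VarDef of the C operand
        // dimsCBatch()/dimsCM()/dimsCN() flag which dimensions of C play
        // the batch / M / N roles; they match lowerCutlassMicroBlock's
        // parameters (next file).
        (void)matMulId;
        (void)defIdC;
    }
    return ast;
}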
16 changes: 16 additions & 0 deletions include/schedule/lower_cutlass_micro_block.h
@@ -0,0 +1,16 @@
#ifndef FREE_TENSOR_LOWER_CUTLASS_MICRO_BLOCK_H
#define FREE_TENSOR_LOWER_CUTLASS_MICRO_BLOCK_H

#include <stmt.h>

namespace freetensor {

Stmt lowerCutlassMicroBlock(const Stmt &ast, const ID &matMulId,
const ID &defIdC,
const std::vector<bool> &dimsCBatch,
const std::vector<bool> &dimsCM,
const std::vector<bool> &dimsCN);

} // namespace freetensor

#endif // FREE_TENSOR_LOWER_CUTLASS_MICRO_BLOCK_H
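Continuing the sketch from the previous file, the matcher's results would feed this entry point; illustrative, assuming the match succeeded:

#include <schedule/as_matmul.h>
#include <schedule/lower_cutlass_micro_block.h>

using namespace freetensor;

// Lower the matched CutlassMicroBlock MatMul toward a per-thread form,
// passing C's dimension roles recovered by AsMatMul.
Stmt lowerMatched(const Stmt &ast, const AsMatMul &matcher) {
    return lowerCutlassMicroBlock(ast, matcher.resultId(), matcher.defIdC(),
                                  matcher.dimsCBatch(), matcher.dimsCM(),
                                  matcher.dimsCN());
}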
20 changes: 14 additions & 6 deletions include/stmt.h
@@ -7,6 +7,7 @@
#include <ast.h>
#include <buffer.h>
#include <container_utils.h>
#include <cutlass_micro_kernel_property.h>
#include <except.h>
#include <for_property.h>
#include <reduce_op.h>
@@ -465,14 +466,16 @@ enum class MatMulBackend : size_t {
Mkl = 0,
Cublas,
Cutlass,
CutlassMicroBlock, // CUTLASS's micro kernel, invocable by a single
// thread-block
// ------ THE FOLLOWING BACKENDS CAN ONLY BE LOWERED TO ------
CutlassMicroThread, // CUTLASS's micro kernel, invocable by a single thread
// ----------------------------
NumBackends
};

constexpr std::array matMulBackendNames = {
"mkl",
"cublas",
"cutlass",
"mkl", "cublas", "cutlass", "cutlass-micro-block", "cutlass-micro-thread",
};
static_assert(matMulBackendNames.size() == (size_t)MatMulBackend::NumBackends);

@@ -498,6 +501,8 @@ inline MatMulBackend parseMatMulBackend(const std::string &_str) {
class MatMulNode : public StmtNode {
public:
MatMulBackend backend_;
SubTree<CutlassMicroKernelProperty, NullPolicy::Nullable>
cutlassMicroKernelProperty_ = ChildOf{this};

// c_ = alpha_ * a_ * b_ + beta_ * c_
// a_ is an m_ * k_ matrix
@@ -527,9 +532,11 @@
};
typedef Ref<MatMulNode> MatMul;
inline Stmt
makeMatMul(MatMulBackend backend, const Expr &a, const Expr &b, const Expr &c,
const Expr &alpha, const Expr &beta, const Expr &m, const Expr &k,
const Expr &n, const Expr &lda, const Expr &ldb, const Expr &ldc,
makeMatMul(MatMulBackend backend,
const Ref<CutlassMicroKernelProperty> &cutlassMicroKernelProperty,
const Expr &a, const Expr &b, const Expr &c, const Expr &alpha,
const Expr &beta, const Expr &m, const Expr &k, const Expr &n,
const Expr &lda, const Expr &ldb, const Expr &ldc,
const Expr &stridea, const Expr &strideb, const Expr &stridec,
const Expr &batchSize, bool aIsRowMajor, bool bIsRowMajor,
bool cIsRowMajor, const Stmt &equivalent,
@@ -539,6 +546,7 @@ makeMatMul(MatMulBackend backend, const Expr &a, const Expr &b, const Expr &c,
s->metadata() = metadata;
s->setId(id);
s->backend_ = backend;
s->cutlassMicroKernelProperty_ = cutlassMicroKernelProperty;
s->a_ = a;
s->b_ = b;
s->c_ = c;
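Two usage notes on this hunk. Every caller of `makeMatMul` now passes the property as the second argument, `nullptr` when absent, as the `Mutator::visit(MatMul)` change above shows. And assuming `parseMatMulBackend` matches against `matMulBackendNames`, the new names round-trip; a hedged illustration:

#include <cassert>
#include <stmt.h>

using namespace freetensor;

int main() {
    // Assumed behavior: string names map back to their enum values.
    assert(parseMatMulBackend("cutlass-micro-block") ==
           MatMulBackend::CutlassMicroBlock);
    assert(parseMatMulBackend("cutlass-micro-thread") ==
           MatMulBackend::CutlassMicroThread);
}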
26 changes: 26 additions & 0 deletions runtime/micro_kernel/matmul/cutlass/gemm.h
@@ -0,0 +1,26 @@
#ifndef MICRO_KERNEL_MATMUL_CUTLASS_GEMM_H
#define MICRO_KERNEL_MATMUL_CUTLASS_GEMM_H

#if (defined(__CUDA_ARCH__)) // Device code

#if (__CUDA_ARCH__ >= 800)
#include "gemm_sm80.h"
#else
#error "Unsupported architecture"
#endif

#else // Host code

// Only a declaration is needed on the host side; the definition is
// device-only (gemm_sm80.h)
template <int M, int N, int K, int num_warp_batch, int num_warp_m,
int num_warp_n, bool trans_A, bool trans_B, typename A_type,
typename B_type, typename C_type>
__device__ void matmul_thread(const A_type *pA, const B_type *pB, C_type *accum,
int lda, int ldb, int stridea, int strideb,
int stridec, double alpha, double beta,
int warp_id_batch, int warp_id_m, int warp_id_n,
int lane_id);

#endif

#endif // MICRO_KERNEL_MATMUL_CUTLASS_GEMM_H
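A hedged sketch of a device-side call site matching the declaration above; tile sizes, leading dimensions, and the warp-ID derivation are illustrative (real values come from codegen), with double precision as in this commit's float64 workaround:

#include "gemm.h"

// Illustrative only: one thread-block of 128 threads (4 warps) computes a
// 64x64x32 tile with a 2x2 warp grid and no batching.
__global__ void demo(const double *A, const double *B, double *accum) {
    int warp_id = threadIdx.x / 32;
    int lane_id = threadIdx.x % 32;
    int warp_id_m = warp_id / 2, warp_id_n = warp_id % 2;
    matmul_thread<64, 64, 32, /*num_warp_batch=*/1, /*num_warp_m=*/2,
                  /*num_warp_n=*/2, /*trans_A=*/false, /*trans_B=*/false,
                  double, double, double>(
        A, B, accum, /*lda=*/32, /*ldb=*/64, /*stridea=*/0, /*strideb=*/0,
        /*stridec=*/0, /*alpha=*/1.0, /*beta=*/0.0,
        /*warp_id_batch=*/0, warp_id_m, warp_id_n, lane_id);
}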