Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Invoke micro kernels from CUTLASS #596

Merged
9 commits merged
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions grammar/pb_parser.g
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,15 @@ expr returns [Expr node]
{int ty;} (
'*' {ty = 1;}
| ('%' | MOD) {ty = 2;}
| '/' {ty = 3;} // Exact integer division. We currently use FloorDiv for it.
)
expr1=expr
{
switch (ty)
{
case 1: $node = makeMul($expr0.node, $expr1.node); break;
case 2: $node = makeMod($expr0.node, $expr1.node); break;
case 3: $node = makeFloorDiv($expr0.node, $expr1.node); break;
}
}
| '-' expr0=expr
Expand Down
12 changes: 11 additions & 1 deletion include/analyze/comp_transient_bounds.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,17 @@ class CompTransientBounds : public BaseClass,
auto dnf = asDNF(_cond);

if (dnf.size() != 1) {
return; // Currently we cannot handle OR
// Currently `transients_` cannot handle OR, leave it as-is to
// `conds_`. But ignore if the condition is marked as `unbound`
for (auto &&item : dnf) {
for (auto &&sub : item) {
if (sub->nodeType() == ASTNodeType::Unbound) {
return;
}
}
}
conds_.emplace_back(_cond);
return;
}

for (auto &&cond : dnf.front()) {
Expand Down
3 changes: 3 additions & 0 deletions include/codegen/code_gen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class CodeGenCUDA : public CodeGenC<CodeGenCUDAStream> {
Expr globalSize_ = makeIntConst(0);
std::unordered_set<Stmt> streamScopes_;
bool inMatmul_ = false;
std::vector<std::string> neededMicroKernels_;

public:
CodeGenCUDA(const std::vector<FuncParam> &params,
Expand All @@ -49,6 +50,8 @@ class CodeGenCUDA : public CodeGenC<CodeGenCUDAStream> {

std::string gen(const DataType &dtype) override;

const auto &neededMicroKernels() const { return neededMicroKernels_; }

private:
bool inKernel() const;

Expand Down
36 changes: 36 additions & 0 deletions include/cutlass_micro_kernel_property.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#ifndef CUTLASS_MICRO_KERNEL_PROPERTY_H
#define CUTLASS_MICRO_KERNEL_PROPERTY_H

#include <utility> // std::forward

#include <expr.h>
#include <sub_tree.h>

namespace freetensor {

/**
 * Extra properties attached to a `MatMul` node that is being lowered to a
 * CUTLASS micro kernel.
 *
 * `nWarpBatch_`, `nWarpM_` and `nWarpN_` are the warp counts along the batch,
 * M and N dimensions. `warpIdBatch_`, `warpIdM_`, `warpIdN_` and `laneId_`
 * are expressions computing the current warp/lane coordinates (they are AST
 * sub-trees, so they participate in hashing and deep copy).
 */
struct CutlassMicroKernelProperty : public ASTPart {
    int nWarpBatch_, nWarpM_, nWarpN_;
    Expr warpIdBatch_, warpIdM_, warpIdN_, laneId_;

    template <typename TwarpIdBatch, typename TwarpIdM, typename TwarpIdN,
              typename TlaneId>
    CutlassMicroKernelProperty(int nWarpBatch, int nWarpM, int nWarpN,
                               TwarpIdBatch &&warpIdBatch, TwarpIdM &&warpIdM,
                               TwarpIdN &&warpIdN, TlaneId &&laneId)
        : nWarpBatch_(nWarpBatch), nWarpM_(nWarpM), nWarpN_(nWarpN),
          // FIX: forward `warpIdBatch` with its own template parameter.
          // The original forwarded with `TwarpIdM`, which applies the wrong
          // value category when the two deduced types differ.
          warpIdBatch_(std::forward<TwarpIdBatch>(warpIdBatch)),
          warpIdM_(std::forward<TwarpIdM>(warpIdM)),
          warpIdN_(std::forward<TwarpIdN>(warpIdN)),
          laneId_(std::forward<TlaneId>(laneId)) {}

    void compHash() override;
};

/// Deep-copy a property node, recursively deep-copying its `Expr` sub-trees.
inline Ref<CutlassMicroKernelProperty>
deepCopy(const Ref<CutlassMicroKernelProperty> &p) {
    return Ref<CutlassMicroKernelProperty>::make(
        p->nWarpBatch_, p->nWarpM_, p->nWarpN_, deepCopy(p->warpIdBatch_),
        deepCopy(p->warpIdM_), deepCopy(p->warpIdN_), deepCopy(p->laneId_));
}

} // namespace freetensor

#endif // CUTLASS_MICRO_KERNEL_PROPERTY_H
3 changes: 3 additions & 0 deletions include/hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class Hasher {
static size_t compHash(const Buffer &b);
static size_t compHash(const ReductionItem &r);
static size_t compHash(const ForProperty &p);
static size_t compHash(const CutlassMicroKernelProperty &p);

// stmt
static size_t compHash(const AnyNode &op);
Expand Down Expand Up @@ -104,6 +105,8 @@ class HashComparator {
const Ref<ReductionItem> &rhs) const;
bool operator()(const Ref<ForProperty> &lhs,
const Ref<ForProperty> &rhs) const;
bool operator()(const Ref<CutlassMicroKernelProperty> &lhs,
const Ref<CutlassMicroKernelProperty> &rhs) const;
bool operator()(const AST &lhs, const AST &rhs) const;
};

Expand Down
14 changes: 13 additions & 1 deletion include/mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,20 @@ class Mutator {
}

virtual Stmt visit(const MatMul &op) {
Ref<CutlassMicroKernelProperty> cutlassMicroKernelProperty = nullptr;
if (op->cutlassMicroKernelProperty_.isValid()) {
cutlassMicroKernelProperty = Ref<CutlassMicroKernelProperty>::make(
op->cutlassMicroKernelProperty_->nWarpBatch_,
op->cutlassMicroKernelProperty_->nWarpM_,
op->cutlassMicroKernelProperty_->nWarpN_,
(*this)(op->cutlassMicroKernelProperty_->warpIdBatch_),
(*this)(op->cutlassMicroKernelProperty_->warpIdM_),
(*this)(op->cutlassMicroKernelProperty_->warpIdN_),
(*this)(op->cutlassMicroKernelProperty_->laneId_));
}
return makeMatMul(
op->backend_, (*this)(op->a_), (*this)(op->b_), (*this)(op->c_),
op->backend_, std::move(cutlassMicroKernelProperty),
(*this)(op->a_), (*this)(op->b_), (*this)(op->c_),
(*this)(op->alpha_), (*this)(op->beta_), (*this)(op->m_),
(*this)(op->k_), (*this)(op->n_), (*this)(op->lda_),
(*this)(op->ldb_), (*this)(op->ldc_), (*this)(op->stridea_),
Expand Down
3 changes: 2 additions & 1 deletion include/pass/gpu/make_sync.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class MakeSync : public Mutator {
const std::unordered_map<ID, ThreadInfo> &loop2thread_;
std::vector<CrossThreadDep> deps_;
std::unordered_map<ID, std::pair<Stmt, bool /* isSyncWarp */>>
syncBeforeFor_, syncBeforeIf_;
syncBeforeFor_, syncBeforeIf_, syncBeforeLib_;
std::unordered_map<ID, std::vector<Stmt>> branchSplittersThen_,
branchSplittersElse_;
LoopVariExprMap variantExprs_;
Expand Down Expand Up @@ -188,6 +188,7 @@ class MakeSync : public Mutator {
Stmt visitStmt(const Stmt &op) override;
Stmt visit(const For &op) override;
Stmt visit(const If &op) override;
Stmt visit(const MatMul &op) override;
};

Stmt makeSync(const Stmt &op, const Ref<GPUTarget> &target);
Expand Down
23 changes: 21 additions & 2 deletions include/schedule/as_matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,32 @@ class AsMatMul : public SymbolTable<Mutator> {

AnalyzeLinear analyzeLinear_;

bool done_ = false;
ID resultId_;

// Public matching details
std::vector<bool> dimsABatch_, dimsBBatch_, dimsCBatch_, dimsAM_, dimsAK_,
dimsBK_, dimsBN_, dimsCM_, dimsCN_;
ID defIdA_, defIdB_, defIdC_;

public:
AsMatMul(const ID &loop, MatMulBackend backend)
: loop_(loop), backend_(backend) {}

bool done() const { return done_; }
bool done() const { return resultId_.isValid(); }
const ID &resultId() const { return resultId_; }

const auto &dimsABatch() const { return dimsABatch_; }
const auto &dimsBBatch() const { return dimsBBatch_; }
const auto &dimsCBatch() const { return dimsCBatch_; }
const auto &dimsAM() const { return dimsAM_; }
const auto &dimsAK() const { return dimsAK_; }
const auto &dimsBK() const { return dimsBK_; }
const auto &dimsBN() const { return dimsBN_; }
const auto &dimsCM() const { return dimsCM_; }
const auto &dimsCN() const { return dimsCN_; }
const auto &defIdA() const { return defIdA_; }
const auto &defIdB() const { return defIdB_; }
const auto &defIdC() const { return defIdC_; }

private:
const LinearExpr<int64_t> &analyzeLinear(const Expr &expr);
Expand Down
16 changes: 16 additions & 0 deletions include/schedule/lower_cutlass_micro_block.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#ifndef FREE_TENSOR_LOWER_CUTLASS_MICRO_BLOCK_H
#define FREE_TENSOR_LOWER_CUTLASS_MICRO_BLOCK_H

#include <vector> // FIX: include-what-you-use; std::vector was only reachable
                  // transitively before

#include <stmt.h>

namespace freetensor {

/**
 * Lower a block-level CUTLASS micro-kernel `MatMul` in `ast`.
 *
 * @param ast : AST to transform
 * @param matMulId : ID of the `MatMul` statement to lower
 * @param defIdC : ID of the `VarDef` of the output tensor C
 * @param dimsCBatch, dimsCM, dimsCN : per-dimension flags of C marking which
 *     dimensions are batch / M / N dimensions, respectively (as reported by
 *     the `AsMatMul` matching pass)
 * @return : the transformed AST
 */
Stmt lowerCutlassMicroBlock(const Stmt &ast, const ID &matMulId,
                            const ID &defIdC,
                            const std::vector<bool> &dimsCBatch,
                            const std::vector<bool> &dimsCM,
                            const std::vector<bool> &dimsCN);

} // namespace freetensor

#endif // FREE_TENSOR_LOWER_CUTLASS_MICRO_BLOCK_H
20 changes: 14 additions & 6 deletions include/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <ast.h>
#include <buffer.h>
#include <container_utils.h>
#include <cutlass_micro_kernel_property.h>
#include <except.h>
#include <for_property.h>
#include <reduce_op.h>
Expand Down Expand Up @@ -465,14 +466,16 @@ enum class MatMulBackend : size_t {
Mkl = 0,
Cublas,
Cutlass,
CutlassMicroBlock, // CUTLASS's micro kernel, invocable by a single
// thread-block
// ------ THE FOLLOWING BACKENDS CAN ONLY BE LOWERED TO ------
CutlassMicroThread, // CUTLASS's micro kernel, invocable by a single thread
// ----------------------------
NumBackends
};

constexpr std::array matMulBackendNames = {
"mkl",
"cublas",
"cutlass",
"mkl", "cublas", "cutlass", "cutlass-micro-block", "cutlass-micro-thread",
};
static_assert(baseDataTypeNames.size() == (size_t)BaseDataType::NumTypes);

Expand All @@ -498,6 +501,8 @@ inline MatMulBackend parseMatMulBackend(const std::string &_str) {
class MatMulNode : public StmtNode {
public:
MatMulBackend backend_;
SubTree<CutlassMicroKernelProperty, NullPolicy::Nullable>
cutlassMicroKernelProperty_ = ChildOf{this};

// c_ = alpha_ * a_ * b_ + beta_ * c_
// a_ is an m_ * k_ matrix
Expand Down Expand Up @@ -527,9 +532,11 @@ class MatMulNode : public StmtNode {
};
typedef Ref<MatMulNode> MatMul;
inline Stmt
makeMatMul(MatMulBackend backend, const Expr &a, const Expr &b, const Expr &c,
const Expr &alpha, const Expr &beta, const Expr &m, const Expr &k,
const Expr &n, const Expr &lda, const Expr &ldb, const Expr &ldc,
makeMatMul(MatMulBackend backend,
const Ref<CutlassMicroKernelProperty> &cutlassMicroKernelProperty,
const Expr &a, const Expr &b, const Expr &c, const Expr &alpha,
const Expr &beta, const Expr &m, const Expr &k, const Expr &n,
const Expr &lda, const Expr &ldb, const Expr &ldc,
const Expr &stridea, const Expr &strideb, const Expr &stridec,
const Expr &batchSize, bool aIsRowMajor, bool bIsRowMajor,
bool cIsRowMajor, const Stmt &equivalent,
Expand All @@ -539,6 +546,7 @@ makeMatMul(MatMulBackend backend, const Expr &a, const Expr &b, const Expr &c,
s->metadata() = metadata;
s->setId(id);
s->backend_ = backend;
s->cutlassMicroKernelProperty_ = cutlassMicroKernelProperty;
s->a_ = a;
s->b_ = b;
s->c_ = c;
Expand Down
26 changes: 26 additions & 0 deletions runtime/micro_kernel/matmul/cutlass/gemm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef MICRO_KERNEL_MATMUL_CUTLASS_GEMM_H
#define MICRO_KERNEL_MATMUL_CUTLASS_GEMM_H

// Architecture dispatch for the CUTLASS micro-kernel GEMM. When compiled for
// the device, pull in the implementation matching __CUDA_ARCH__; when compiled
// for the host, only a declaration is visible so host code referencing the
// symbol still parses.

#if (defined(__CUDA_ARCH__)) // Device code

#if (__CUDA_ARCH__ >= 800)
#include "gemm_sm80.h" // sm_80 (Ampere) and newer implementation
#else
#error "Unsupported architecture"
#endif

#else // Host code

// Only declaration is needed
// Per-thread entry of the micro kernel. Template parameters: tile sizes
// M/N/K, warp counts along batch/m/n, transpose flags, and element types.
// NOTE(review): `accum` has no leading-dimension parameter, so it is
// presumably a per-thread accumulator fragment rather than a global matrix —
// confirm against gemm_sm80.h.
template <int M, int N, int K, int num_warp_batch, int num_warp_m,
          int num_warp_n, bool trans_A, bool trans_B, typename A_type,
          typename B_type, typename C_type>
__device__ void matmul_thread(const A_type *pA, const B_type *pB, C_type *accum,
                              int lda, int ldb, int stridea, int strideb,
                              int stridec, double alpha, double beta,
                              int warp_id_batch, int warp_id_m, int warp_id_n,
                              int lane_id);

#endif

#endif // MICRO_KERNEL_MATMUL_CUTLASS_GEMM_H
Loading
Loading