diff --git a/grammar/pb_parser.g b/grammar/pb_parser.g index 6c473a094..c26a9b9c3 100644 --- a/grammar/pb_parser.g +++ b/grammar/pb_parser.g @@ -103,6 +103,7 @@ expr returns [Expr node] {int ty;} ( '*' {ty = 1;} | ('%' | MOD) {ty = 2;} + | '/' {ty = 3;} // Exact integer division. We currently use FloorDiv for it. ) expr1=expr { @@ -110,6 +111,7 @@ expr returns [Expr node] { case 1: $node = makeMul($expr0.node, $expr1.node); break; case 2: $node = makeMod($expr0.node, $expr1.node); break; + case 3: $node = makeFloorDiv($expr0.node, $expr1.node); break; } } | '-' expr0=expr diff --git a/include/analyze/comp_transient_bounds.h b/include/analyze/comp_transient_bounds.h index 7749a18b0..998740754 100644 --- a/include/analyze/comp_transient_bounds.h +++ b/include/analyze/comp_transient_bounds.h @@ -71,7 +71,17 @@ class CompTransientBounds : public BaseClass, auto dnf = asDNF(_cond); if (dnf.size() != 1) { - return; // Currently we cannot handle OR + // Currently `transients_` cannot handle OR, leave it as-is to + // `conds_`. But ignore if the condition is marked as `unbound` + for (auto &&item : dnf) { + for (auto &&sub : item) { + if (sub->nodeType() == ASTNodeType::Unbound) { + return; + } + } + } + conds_.emplace_back(_cond); + return; } for (auto &&cond : dnf.front()) { diff --git a/include/codegen/code_gen_cuda.h b/include/codegen/code_gen_cuda.h index a8d006bcb..1c6e20136 100644 --- a/include/codegen/code_gen_cuda.h +++ b/include/codegen/code_gen_cuda.h @@ -30,6 +30,7 @@ class CodeGenCUDA : public CodeGenC { Expr globalSize_ = makeIntConst(0); std::unordered_set streamScopes_; bool inMatmul_ = false; + std::vector neededMicroKernels_; public: CodeGenCUDA(const std::vector ¶ms, @@ -49,6 +50,8 @@ class CodeGenCUDA : public CodeGenC { std::string gen(const DataType &dtype) override; + const auto &neededMicroKernels() const { return neededMicroKernels_; } + private: bool inKernel() const; diff --git a/include/cutlass_micro_kernel_property.h b/include/cutlass_micro_kernel_property.h new file mode 100644 index 000000000..a922b5672 --- /dev/null +++ b/include/cutlass_micro_kernel_property.h @@ -0,0 +1,36 @@ +#ifndef CUTLASS_MICRO_KERNEL_PROPERTY_H +#define CUTLASS_MICRO_KERNEL_PROPERTY_H + +#include +#include + +namespace freetensor { + +struct CutlassMicroKernelProperty : public ASTPart { + int nWarpBatch_, nWarpM_, nWarpN_; + Expr warpIdBatch_, warpIdM_, warpIdN_, laneId_; + + template + CutlassMicroKernelProperty(int nWarpBatch, int nWarpM, int nWarpN, + TwarpIdBatch &&warpIdBatch, TwarpIdM &&warpIdM, + TwarpIdN &&warpIdN, TlaneId &&laneId) + : nWarpBatch_(nWarpBatch), nWarpM_(nWarpM), nWarpN_(nWarpN), + warpIdBatch_(std::forward(warpIdBatch)), + warpIdM_(std::forward(warpIdM)), + warpIdN_(std::forward(warpIdN)), + laneId_(std::forward(laneId)) {} + + void compHash() override; +}; + +inline Ref +deepCopy(const Ref &p) { + return Ref::make( + p->nWarpBatch_, p->nWarpM_, p->nWarpN_, deepCopy(p->warpIdBatch_), + deepCopy(p->warpIdM_), deepCopy(p->warpIdN_), deepCopy(p->laneId_)); +} + +} // namespace freetensor + +#endif // CUTLASS_MICRO_KERNEL_PROPERTY_H diff --git a/include/hash.h b/include/hash.h index 33e1c9768..50d28490a 100644 --- a/include/hash.h +++ b/include/hash.h @@ -26,6 +26,7 @@ class Hasher { static size_t compHash(const Buffer &b); static size_t compHash(const ReductionItem &r); static size_t compHash(const ForProperty &p); + static size_t compHash(const CutlassMicroKernelProperty &p); // stmt static size_t compHash(const AnyNode &op); @@ -104,6 +105,8 @@ class HashComparator { 
const Ref &rhs) const; bool operator()(const Ref &lhs, const Ref &rhs) const; + bool operator()(const Ref &lhs, + const Ref &rhs) const; bool operator()(const AST &lhs, const AST &rhs) const; }; diff --git a/include/mutator.h b/include/mutator.h index 9b24d6af4..646f8f5ef 100644 --- a/include/mutator.h +++ b/include/mutator.h @@ -338,8 +338,20 @@ class Mutator { } virtual Stmt visit(const MatMul &op) { + Ref cutlassMicroKernelProperty = nullptr; + if (op->cutlassMicroKernelProperty_.isValid()) { + cutlassMicroKernelProperty = Ref::make( + op->cutlassMicroKernelProperty_->nWarpBatch_, + op->cutlassMicroKernelProperty_->nWarpM_, + op->cutlassMicroKernelProperty_->nWarpN_, + (*this)(op->cutlassMicroKernelProperty_->warpIdBatch_), + (*this)(op->cutlassMicroKernelProperty_->warpIdM_), + (*this)(op->cutlassMicroKernelProperty_->warpIdN_), + (*this)(op->cutlassMicroKernelProperty_->laneId_)); + } return makeMatMul( - op->backend_, (*this)(op->a_), (*this)(op->b_), (*this)(op->c_), + op->backend_, std::move(cutlassMicroKernelProperty), + (*this)(op->a_), (*this)(op->b_), (*this)(op->c_), (*this)(op->alpha_), (*this)(op->beta_), (*this)(op->m_), (*this)(op->k_), (*this)(op->n_), (*this)(op->lda_), (*this)(op->ldb_), (*this)(op->ldc_), (*this)(op->stridea_), diff --git a/include/pass/gpu/make_sync.h b/include/pass/gpu/make_sync.h index bef7403b1..5984b60d4 100644 --- a/include/pass/gpu/make_sync.h +++ b/include/pass/gpu/make_sync.h @@ -78,7 +78,7 @@ class MakeSync : public Mutator { const std::unordered_map &loop2thread_; std::vector deps_; std::unordered_map> - syncBeforeFor_, syncBeforeIf_; + syncBeforeFor_, syncBeforeIf_, syncBeforeLib_; std::unordered_map> branchSplittersThen_, branchSplittersElse_; LoopVariExprMap variantExprs_; @@ -188,6 +188,7 @@ class MakeSync : public Mutator { Stmt visitStmt(const Stmt &op) override; Stmt visit(const For &op) override; Stmt visit(const If &op) override; + Stmt visit(const MatMul &op) override; }; Stmt makeSync(const Stmt &op, const Ref &target); diff --git a/include/schedule/as_matmul.h b/include/schedule/as_matmul.h index 18657fcf0..528dbe3af 100644 --- a/include/schedule/as_matmul.h +++ b/include/schedule/as_matmul.h @@ -72,13 +72,32 @@ class AsMatMul : public SymbolTable { AnalyzeLinear analyzeLinear_; - bool done_ = false; + ID resultId_; + + // Public matching details + std::vector dimsABatch_, dimsBBatch_, dimsCBatch_, dimsAM_, dimsAK_, + dimsBK_, dimsBN_, dimsCM_, dimsCN_; + ID defIdA_, defIdB_, defIdC_; public: AsMatMul(const ID &loop, MatMulBackend backend) : loop_(loop), backend_(backend) {} - bool done() const { return done_; } + bool done() const { return resultId_.isValid(); } + const ID &resultId() const { return resultId_; } + + const auto &dimsABatch() const { return dimsABatch_; } + const auto &dimsBBatch() const { return dimsBBatch_; } + const auto &dimsCBatch() const { return dimsCBatch_; } + const auto &dimsAM() const { return dimsAM_; } + const auto &dimsAK() const { return dimsAK_; } + const auto &dimsBK() const { return dimsBK_; } + const auto &dimsBN() const { return dimsBN_; } + const auto &dimsCM() const { return dimsCM_; } + const auto &dimsCN() const { return dimsCN_; } + const auto &defIdA() const { return defIdA_; } + const auto &defIdB() const { return defIdB_; } + const auto &defIdC() const { return defIdC_; } private: const LinearExpr &analyzeLinear(const Expr &expr); diff --git a/include/schedule/lower_cutlass_micro_block.h b/include/schedule/lower_cutlass_micro_block.h new file mode 100644 index 000000000..cef269ddb --- 
/dev/null +++ b/include/schedule/lower_cutlass_micro_block.h @@ -0,0 +1,16 @@ +#ifndef FREE_TENSOR_LOWER_CUTLASS_MICRO_BLOCK_H +#define FREE_TENSOR_LOWER_CUTLASS_MICRO_BLOCK_H + +#include + +namespace freetensor { + +Stmt lowerCutlassMicroBlock(const Stmt &ast, const ID &matMulId, + const ID &defIdC, + const std::vector &dimsCBatch, + const std::vector &dimsCM, + const std::vector &dimsCN); + +} + +#endif // FREE_TENSOR_LOWER_CUTLASS_MICRO_BLOCK_H diff --git a/include/stmt.h b/include/stmt.h index f320f8e0b..024465fdd 100644 --- a/include/stmt.h +++ b/include/stmt.h @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -465,14 +466,16 @@ enum class MatMulBackend : size_t { Mkl = 0, Cublas, Cutlass, + CutlassMicroBlock, // CUTLASS's micro kernel, invocable by a single + // thread-block + // ------ THE FOLLOWING BACKENDS CAN ONLY BE LOWERED TO ------ + CutlassMicroThread, // CUTLASS's micro kernel, invocable by a single thread // ---------------------------- NumBackends }; constexpr std::array matMulBackendNames = { - "mkl", - "cublas", - "cutlass", + "mkl", "cublas", "cutlass", "cutlass-micro-block", "cutlass-micro-thread", }; static_assert(baseDataTypeNames.size() == (size_t)BaseDataType::NumTypes); @@ -498,6 +501,8 @@ inline MatMulBackend parseMatMulBackend(const std::string &_str) { class MatMulNode : public StmtNode { public: MatMulBackend backend_; + SubTree + cutlassMicroKernelProperty_ = ChildOf{this}; // c_ = alpha_ * a_ * b_ + beta_ * c_ // a_ is an m_ * k_ matrix @@ -527,9 +532,11 @@ class MatMulNode : public StmtNode { }; typedef Ref MatMul; inline Stmt -makeMatMul(MatMulBackend backend, const Expr &a, const Expr &b, const Expr &c, - const Expr &alpha, const Expr &beta, const Expr &m, const Expr &k, - const Expr &n, const Expr &lda, const Expr &ldb, const Expr &ldc, +makeMatMul(MatMulBackend backend, + const Ref &cutlassMicroKernelProperty, + const Expr &a, const Expr &b, const Expr &c, const Expr &alpha, + const Expr &beta, const Expr &m, const Expr &k, const Expr &n, + const Expr &lda, const Expr &ldb, const Expr &ldc, const Expr &stridea, const Expr &strideb, const Expr &stridec, const Expr &batchSize, bool aIsRowMajor, bool bIsRowMajor, bool cIsRowMajor, const Stmt &equivalent, @@ -539,6 +546,7 @@ makeMatMul(MatMulBackend backend, const Expr &a, const Expr &b, const Expr &c, s->metadata() = metadata; s->setId(id); s->backend_ = backend; + s->cutlassMicroKernelProperty_ = cutlassMicroKernelProperty; s->a_ = a; s->b_ = b; s->c_ = c; diff --git a/runtime/micro_kernel/matmul/cutlass/gemm.h b/runtime/micro_kernel/matmul/cutlass/gemm.h new file mode 100644 index 000000000..c977f5101 --- /dev/null +++ b/runtime/micro_kernel/matmul/cutlass/gemm.h @@ -0,0 +1,26 @@ +#ifndef MICRO_KERNEL_MATMUL_CUTLASS_GEMM_H +#define MICRO_KERNEL_MATMUL_CUTLASS_GEMM_H + +#if (defined(__CUDA_ARCH__)) // Device code + +#if (__CUDA_ARCH__ >= 800) +#include "gemm_sm80.h" +#else +#error "Unsupported architecture" +#endif + +#else // Host code + +// Only declaration is needed +template +__device__ void matmul_thread(const A_type *pA, const B_type *pB, C_type *accum, + int lda, int ldb, int stridea, int strideb, + int stridec, double alpha, double beta, + int warp_id_batch, int warp_id_m, int warp_id_n, + int lane_id); + +#endif + +#endif // MICRO_KERNEL_MATMUL_CUTLASS_GEMM_H diff --git a/runtime/micro_kernel/matmul/cutlass/gemm_sm80.h b/runtime/micro_kernel/matmul/cutlass/gemm_sm80.h new file mode 100644 index 000000000..9f780274a --- /dev/null +++ 
b/runtime/micro_kernel/matmul/cutlass/gemm_sm80.h @@ -0,0 +1,149 @@ +/** + * This file is borrowed from + * https://github.com/nox-410/tvm.tl/blob/tl/src/tl/tl_templates/gemm_sm80.h + * under Apache Lincense, and modified for use. + */ + +#ifndef MICRO_KERNEL_MATMUL_CUTLASS_GEMM_SM80_H +#define MICRO_KERNEL_MATMUL_CUTLASS_GEMM_SM80_H + +#include +#include +#include + +using cutlass::gemm::GemmShape; + +template +struct DispatchInstruction; + +template <> +struct DispatchInstruction { + using Shape = GemmShape<16, 8, 16>; +}; +template <> +struct DispatchInstruction { + using Shape = GemmShape<16, 8, 16>; +}; +template <> +struct DispatchInstruction { + using Shape = GemmShape<16, 8, 16>; +}; +template <> +struct DispatchInstruction { + using Shape = GemmShape<16, 8, 8>; +}; +template <> struct DispatchInstruction { + using Shape = GemmShape<8, 8, 4>; +}; +template <> struct DispatchInstruction { + using Shape = GemmShape<16, 8, 32>; +}; + +template struct DispatchSharedMemoryLayout; + +template <> struct DispatchSharedMemoryLayout { + using Layout = cutlass::layout::ColumnMajor; +}; +template <> struct DispatchSharedMemoryLayout { + using Layout = cutlass::layout::RowMajor; +}; + +template +class GemmTensorOp { + public: + using A_type = + typename std::conditional::value, + cutlass::tfloat32_t, A_type_raw>::type; + using B_type = + typename std::conditional::value, + cutlass::tfloat32_t, A_type_raw>::type; + using C_type = C_type_raw; + using InstructionShape = + typename DispatchInstruction::Shape; + using SMemLayoutA = typename DispatchSharedMemoryLayout::Layout; + using SMemLayoutB = typename DispatchSharedMemoryLayout::Layout; + + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma< + InstructionShape, 32, A_type, cutlass::layout::RowMajor, B_type, + cutlass::layout::ColumnMajor, C_type, cutlass::layout::RowMajor, + cutlass::arch::OpMultiplyAdd>, + cutlass::MatrixShape<1, 1>>; + + static_assert(Shape::kM % num_warp_m == 0); + static_assert(Shape::kN % num_warp_n == 0); + + using MmaWarp = typename cutlass::gemm::warp::MmaTensorOp< + GemmShape, + A_type, SMemLayoutA, B_type, SMemLayoutB, C_type, + cutlass::layout::RowMajor, Policy, 1, + true /* accumulate in row major */>; + + using TensorRefA = typename MmaWarp::IteratorA::TensorRef; + using TensorRefB = typename MmaWarp::IteratorB::TensorRef; + using FragmentA = typename MmaWarp::FragmentA; + using FragmentB = typename MmaWarp::FragmentB; + using FragmentC = typename MmaWarp::FragmentC; + using IteratorA = typename MmaWarp::IteratorA; + using IteratorB = typename MmaWarp::IteratorB; + + static_assert(Shape::kK % InstructionShape::kK == 0); + static int constexpr kKgroups = Shape::kK / InstructionShape::kK; + + static CUTLASS_DEVICE void body(const A_type_raw *pA, const B_type_raw *pB, + FragmentC &accum, int lda, int ldb, + double alpha, double beta, + const int warp_idx_m, const int warp_idx_n, + const int lane_id) { + MmaWarp mma_op; + FragmentA frag_A; + FragmentB frag_B; + const TensorRefA ref_A((A_type *)pA, lda); + const TensorRefB ref_B((B_type *)pB, ldb); + IteratorA iter_A(ref_A, lane_id); + IteratorB iter_B(ref_B, lane_id); + iter_A.add_tile_offset({warp_idx_m, 0}); + iter_B.add_tile_offset({0, warp_idx_n}); + + // TODO: Check all cases of alpha and beta + // TODO: Static checking of alpha and beta + if (beta == 0) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentC::kElements; i++) { + accum[i] = 0; + } + } else { + assert(beta == 1); + } + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < 
kKgroups; ++k) { + iter_A.load(frag_A); + iter_B.load(frag_B); + ++iter_A; + ++iter_B; + mma_op(accum, frag_A, frag_B, accum); + } + } +}; + +template +CUTLASS_DEVICE void matmul_thread(const A_type *pA, const B_type *pB, + C_type *accum, int lda, int ldb, int stridea, + int strideb, int stridec, double alpha, + double beta, int warp_id_batch, int warp_id_m, + int warp_id_n, int lane_id) { + using MMA = GemmTensorOp, num_warp_m, num_warp_n, + trans_A, trans_B, A_type, B_type, C_type>; + using FragmentC = typename MMA::FragmentC; + MMA::body(pA + warp_id_batch * stridea, pB + warp_id_batch * strideb, + *(FragmentC *)(accum /* no thread offset */), lda, ldb, alpha, + beta, warp_id_m, warp_id_n, lane_id); +} + +#endif // MICRO_KERNEL_MATMUL_CUTLASS_GEMM_SM80_H diff --git a/src/analyze/comp_unique_bounds_pb.cc b/src/analyze/comp_unique_bounds_pb.cc index 19f34cb38..424c50a9e 100644 --- a/src/analyze/comp_unique_bounds_pb.cc +++ b/src/analyze/comp_unique_bounds_pb.cc @@ -85,10 +85,10 @@ std::tuple CompUniqueBoundsPB::Bound::lowerUpperDiffExpr() const { PBSet l = bound_.hasLowerBound(0) ? lexmin(bound_) : PBSet(); PBSet u = bound_.hasUpperBound(0) ? lexmax(bound_) : PBSet(); - PBSet diff = - l.isValid() && u.isValid() - ? apply(cartesianProduct(u, l), PBMap(*ctx_, "{[u, l] -> [u - l]}")) - : PBSet(); + PBSet diff = l.isValid() && u.isValid() + ? coalesce(apply(cartesianProduct(u, l), + PBMap(*ctx_, "{[u, l] -> [u - l]}"))) + : PBSet(); return {l.isValid() ? translateBoundFunc(*ctx_, l, *demangleMap_) : nullptr, u.isValid() ? translateBoundFunc(*ctx_, u, *demangleMap_) : nullptr, diff.isValid() ? translateBoundFunc(*ctx_, diff, *demangleMap_) diff --git a/src/codegen/code_gen_cuda.cc b/src/codegen/code_gen_cuda.cc index 4c98343bd..e5e27b710 100644 --- a/src/codegen/code_gen_cuda.cc +++ b/src/codegen/code_gen_cuda.cc @@ -1010,6 +1010,65 @@ void CodeGenCUDA::visit(const MatMul &op) { break; } + case MatMulBackend::CutlassMicroThread: { + if (!thisOpInKernel) { + throw InvalidProgram( + "A MatMul's micro kernel can only be called inside a kernel"); + } + + ASSERT(op->cIsRowMajor_); + + auto &&prop = op->cutlassMicroKernelProperty_; + + makeIndent(); + os() << "matmul_thread<"; + (*this)(m); + os() << ", "; + (*this)(n); + os() << ", "; + (*this)(k); + os() << ", " << prop->nWarpBatch_ << ", " << prop->nWarpM_ << ", " + << prop->nWarpN_ << ", " << (transA ? "true" : "false") << ", " + << (transB ? 
"true" : "false") << ", " + << genCUTLASSType(a->dtype()) << ", " << genCUTLASSType(b->dtype()) + << ", " << genCUTLASSType(c->dtype()) << ">((const " + << genCUTLASSType(a->dtype()) << "*)&("; + (*this)(a); + os() << "), (const " << genCUTLASSType(b->dtype()) << "*)&("; + (*this)(b); + os() << "), (" << genCUTLASSType(c->dtype()) << "*)&("; + (*this)(c); + os() << "), "; + (*this)(lda); + os() << ", "; + (*this)(ldb); + os() << ", "; + (*this)(stridea); + os() << ", "; + (*this)(strideb); + os() << ", "; + (*this)(stridec); + os() << ", "; + (*this)(op->alpha_); + os() << ", "; + (*this)(op->beta_); + os() << ", "; + (*this)(prop->warpIdBatch_); + os() << ", "; + (*this)(prop->warpIdM_); + os() << ", "; + (*this)(prop->warpIdN_); + os() << ", "; + (*this)(prop->laneId_); + os() << ");" << std::endl; + + neededMicroKernels_.emplace_back("matmul/cutlass/gemm.h"); + break; + } + + case MatMulBackend::CutlassMicroBlock: + ERROR("CutlassMicroBlock should be lowered before codegen"); + default: inMatmul_ = false; throw InvalidProgram("MatMul backend " + @@ -1033,17 +1092,21 @@ NativeCode codeGenCUDA(const Func &func, const Ref &_target) { visitor(op); visitor.endBlock(); - const char *header = R"~~~( + std::string header = R"~~~( #include extern __shared__ uint8_t __shmem[]; extern "C" { )~~~"; - const char *tailer = R"~~~( + std::string tailer = R"~~~( } )~~~"; + for (auto &&item : visitor.neededMicroKernels()) { + header = "#include \"micro_kernel/" + item + "\"\n" + header; + } + auto body = visitor.toString([&](const CodeGenCUDA::Stream &stream) { if (stream.name_ == "default") { std::string s = diff --git a/src/cutlass_micro_kernel_property.cc b/src/cutlass_micro_kernel_property.cc new file mode 100644 index 000000000..31c79a389 --- /dev/null +++ b/src/cutlass_micro_kernel_property.cc @@ -0,0 +1,8 @@ +#include +#include + +namespace freetensor { + +void CutlassMicroKernelProperty::compHash() { hash_ = Hasher::compHash(*this); } + +} // namespace freetensor diff --git a/src/hash.cc b/src/hash.cc index 30ccb205e..8febeac10 100644 --- a/src/hash.cc +++ b/src/hash.cc @@ -1,8 +1,27 @@ +#include + #include #include namespace freetensor { +template +static std::optional trivialCompare(const Ref &lhs, + const Ref &rhs) { + if (lhs == rhs) { // alias or nullptr + return true; + } + if (lhs.isValid() != rhs.isValid()) { + return false; + } + + if (lhs->hash() != rhs->hash()) { + return false; + } + + return std::nullopt; +} + size_t Hasher::compHash(const Tensor &t) { size_t h = (-1 * K1 + B1) % P; for (auto &&dim : t.shape()) { @@ -48,6 +67,18 @@ size_t Hasher::compHash(const ForProperty &p) { return (h * K3 + B3) % P; } +size_t Hasher::compHash(const CutlassMicroKernelProperty &p) { + size_t h = (-1 * K1 + B1) % P; + h = ((h + std::hash{}(p.nWarpBatch_)) * K2 + B2) % P; + h = ((h + std::hash{}(p.nWarpM_)) * K2 + B2) % P; + h = ((h + std::hash{}(p.nWarpN_)) * K2 + B2) % P; + h = ((h + p.warpIdBatch_->hash()) * K2 + B2) % P; + h = ((h + p.warpIdM_->hash()) * K2 + B2) % P; + h = ((h + p.warpIdN_->hash()) * K2 + B2) % P; + h = ((h + p.laneId_->hash()) * K2 + B2) % P; + return (h * K3 + B3) % P; +} + size_t Hasher::compHash(const AnyNode &op) { size_t h = ((size_t)op.nodeType() * K1 + B1) % P; return (h * K3 + B3) % P; @@ -151,6 +182,9 @@ size_t Hasher::compHash(const EvalNode &op) { size_t Hasher::compHash(const MatMulNode &op) { size_t h = ((size_t)op.nodeType() * K1 + B1) % P; h = ((h + std::hash()(op.backend_)) * K2 + B2) % P; + if (op.cutlassMicroKernelProperty_.isValid()) { + h = ((h + 
op.cutlassMicroKernelProperty_->hash()) * K2 + B2) % P; + } h = ((h + op.equivalent_->hash()) * K2 + B2) % P; return (h * K3 + B3) % P; } @@ -416,6 +450,10 @@ bool HashComparator::compare(const MatMul &lhs, const MatMul &rhs) const { if (lhs->backend_ != rhs->backend_) { return false; } + if (!(*this)(lhs->cutlassMicroKernelProperty_, + rhs->cutlassMicroKernelProperty_)) { + return false; + } return (*this)(lhs->equivalent_, rhs->equivalent_); } @@ -536,6 +574,10 @@ bool HashComparator::compare(const LoadAtVersion &lhs, bool HashComparator::operator()(const Ref &lhs, const Ref &rhs) const { + if (auto &&flag = trivialCompare(lhs, rhs); flag.has_value()) { + return *flag; + } + if (lhs->shape().size() != rhs->shape().size()) { return false; } @@ -552,6 +594,10 @@ bool HashComparator::operator()(const Ref &lhs, bool HashComparator::operator()(const Ref &lhs, const Ref &rhs) const { + if (auto &&flag = trivialCompare(lhs, rhs); flag.has_value()) { + return *flag; + } + if (!(*this)(lhs->tensor(), rhs->tensor())) { return false; } @@ -566,6 +612,10 @@ bool HashComparator::operator()(const Ref &lhs, bool HashComparator::operator()(const Ref &lhs, const Ref &rhs) const { + if (auto &&flag = trivialCompare(lhs, rhs); flag.has_value()) { + return *flag; + } + if (lhs->op_ != rhs->op_) { return false; } @@ -593,6 +643,10 @@ bool HashComparator::operator()(const Ref &lhs, bool HashComparator::operator()(const Ref &lhs, const Ref &rhs) const { + if (auto &&flag = trivialCompare(lhs, rhs); flag.has_value()) { + return *flag; + } + if (lhs->parallel_ != rhs->parallel_) { return false; } @@ -624,17 +678,41 @@ bool HashComparator::operator()(const Ref &lhs, return true; } -bool HashComparator::operator()(const AST &lhs, const AST &rhs) const { - if (lhs == rhs) { // alias or nullptr - return true; +bool HashComparator::operator()( + const Ref &lhs, + const Ref &rhs) const { + if (auto &&flag = trivialCompare(lhs, rhs); flag.has_value()) { + return *flag; } - if (lhs.isValid() != rhs.isValid()) { + + if (lhs->nWarpBatch_ != rhs->nWarpBatch_) { return false; } - - if (lhs->hash() != rhs->hash()) { + if (lhs->nWarpM_ != rhs->nWarpM_) { return false; } + if (lhs->nWarpN_ != rhs->nWarpN_) { + return false; + } + if (!(*this)(lhs->warpIdBatch_, rhs->warpIdBatch_)) { + return false; + } + if (!(*this)(lhs->warpIdM_, rhs->warpIdM_)) { + return false; + } + if (!(*this)(lhs->warpIdN_, rhs->warpIdN_)) { + return false; + } + if (!(*this)(lhs->laneId_, rhs->laneId_)) { + return false; + } + return true; +} + +bool HashComparator::operator()(const AST &lhs, const AST &rhs) const { + if (auto &&flag = trivialCompare(lhs, rhs); flag.has_value()) { + return *flag; + } if (lhs->nodeType() != rhs->nodeType()) { return false; diff --git a/src/pass/gpu/make_sync.cc b/src/pass/gpu/make_sync.cc index f456a2ebd..8aab9dea4 100644 --- a/src/pass/gpu/make_sync.cc +++ b/src/pass/gpu/make_sync.cc @@ -265,6 +265,8 @@ Stmt MakeSync::visitStmt(const Stmt &op) { // case, where we need the `then` case AND the `else` case to // sync on ONE sync point whereToInsert = ctx; + } else if (ctx->nodeType() == ASTNodeType::MatMul) { + whereToInsert = ctx; } } if (!whereToInsert.isValid()) { @@ -276,13 +278,19 @@ Stmt MakeSync::visitStmt(const Stmt &op) { syncBeforeFor_.at(whereToInsert->id()).second)) { syncBeforeFor_[whereToInsert->id()] = {sync, !needSyncThreads}; } - } else { - ASSERT(whereToInsert->nodeType() == ASTNodeType::If); + } else if (whereToInsert->nodeType() == ASTNodeType::If) { if (!syncBeforeIf_.count(whereToInsert->id()) || 
(needSyncThreads && syncBeforeIf_.at(whereToInsert->id()).second)) { syncBeforeIf_[whereToInsert->id()] = {sync, !needSyncThreads}; } + } else { + ASSERT(whereToInsert->nodeType() == ASTNodeType::MatMul); + if (!syncBeforeLib_.count(whereToInsert->id()) || + (needSyncThreads && + syncBeforeLib_.at(whereToInsert->id()).second)) { + syncBeforeLib_[whereToInsert->id()] = {sync, !needSyncThreads}; + } } for (CrossThreadDep &dep : deps_) { @@ -335,8 +343,8 @@ Stmt MakeSync::visit(const For &_op) { } } } - if (syncBeforeFor_.count(op->id())) { - auto &&[sync, isSyncWarp] = syncBeforeFor_.at(op->id()); + if (auto it = syncBeforeFor_.find(op->id()); it != syncBeforeFor_.end()) { + auto &&[sync, isSyncWarp] = it->second; markSyncForSplitting(_op, sync, isSyncWarp); return makeStmtSeq({sync, op}); } @@ -398,8 +406,8 @@ Stmt MakeSync::visit(const If &op) { op->metadata(), op->id(), op->debugBlame()); } - if (syncBeforeIf_.count(op->id())) { - auto &&[sync, isSyncWarp] = syncBeforeIf_.at(op->id()); + if (auto it = syncBeforeIf_.find(op->id()); it != syncBeforeIf_.end()) { + auto &&[sync, isSyncWarp] = it->second; markSyncForSplitting(op, sync, isSyncWarp); return makeStmtSeq({sync, ret}); } @@ -407,6 +415,15 @@ Stmt MakeSync::visit(const If &op) { return ret; } +Stmt MakeSync::visit(const MatMul &op) { + auto ret = BaseClass::visit(op); + if (auto it = syncBeforeLib_.find(op->id()); it != syncBeforeLib_.end()) { + auto &&[sync, isSyncWarp] = it->second; + return makeStmtSeq({sync, ret}); + } + return ret; +} + static Stmt doMakeSync(const Stmt &_op, const Ref &target) { auto op = constFold(_op); diff --git a/src/pass/simplify.cc b/src/pass/simplify.cc index 30d62a716..d019cba7d 100644 --- a/src/pass/simplify.cc +++ b/src/pass/simplify.cc @@ -820,7 +820,7 @@ Expr SimplifyPass::visit(const IfExpr &_op) { elseLin.coeff_[j].a_->hash() < thenLin.coeff_[i].a_->hash()) { j++; } - if (thenLin.coeff_[i].k_ == elseLin.coeff_[j].k_ && + if (j < n && thenLin.coeff_[i].k_ == elseLin.coeff_[j].k_ && HashComparator{}(thenLin.coeff_[i].a_, elseLin.coeff_[j].a_)) { common.coeff_.emplace_back(thenLin.coeff_[i]); thenLin.coeff_[i].k_ = elseLin.coeff_[j].k_ = 0; diff --git a/src/schedule/as_matmul.cc b/src/schedule/as_matmul.cc index 5ef6ba2e4..61abadee4 100644 --- a/src/schedule/as_matmul.cc +++ b/src/schedule/as_matmul.cc @@ -5,6 +5,7 @@ #include #include #include +#include namespace freetensor { @@ -202,14 +203,16 @@ Stmt AsMatMul::visit(const For &op) { } else { beta = makeIntConst(1); } - ret = makeMatMul(backend_, a_, b_, c_, alpha, beta, m_, k_, n_, lda_, - ldb_, ldc_, stridea_, strideb_, stridec_, batchSize_, - aIsRowMajor_, bIsRowMajor_, cIsRowMajor_, ret); + + ret = makeMatMul(backend_, nullptr, a_, b_, c_, alpha, beta, m_, k_, n_, + lda_, ldb_, ldc_, stridea_, strideb_, stridec_, + batchSize_, aIsRowMajor_, bIsRowMajor_, cIsRowMajor_, + ret, makeMetadata("as_matmul", op)); + resultId_ = ret->id(); for (auto &&def : innerDefs_) { ret = makeVarDef(def->name_, def->buffer_, def->viewOf_, ret, def->pinned_, def->metadata(), def->id()); } - done_ = true; return ret; } else { ASSERT(!outerDefs_.count(op->iter_)); @@ -291,50 +294,52 @@ Stmt AsMatMul::visit(const ReduceTo &_op) { nAxes[i] = !usedByA[i] && usedByB[i] && usedByC[i]; } - ID idA = def(loadA->var_)->id(); - ID idB = def(loadB->var_)->id(); - ID idC = def(op->var_)->id(); + defIdA_ = def(loadA->var_)->id(); + defIdB_ = def(loadB->var_)->id(); + defIdC_ = def(op->var_)->id(); - checkSameOrderOrRetry(idA, orderA, batchAxes, idB, orderB, batchAxes, + 
checkSameOrderOrRetry(defIdA_, orderA, batchAxes, defIdB_, orderB, + batchAxes, "Order of each indices in the batch axis should " "be the same in each matrices"); - checkSameOrderOrRetry(idA, orderA, batchAxes, idC, orderC, batchAxes, + checkSameOrderOrRetry(defIdA_, orderA, batchAxes, defIdC_, orderC, + batchAxes, "Order of each indices in the batch axis should " "be the same in each matrices"); - checkSameOrderOrRetry(idA, orderA, mAxes, idC, orderC, mAxes, + checkSameOrderOrRetry(defIdA_, orderA, mAxes, defIdC_, orderC, mAxes, "Order of each indices in the m axis should be " "the same in each matrices"); - checkSameOrderOrRetry(idA, orderA, kAxes, idB, orderB, kAxes, + checkSameOrderOrRetry(defIdA_, orderA, kAxes, defIdB_, orderB, kAxes, "Order of each indices in the k axis should be " "the same in each matrices"); - checkSameOrderOrRetry(idB, orderB, nAxes, idC, orderC, nAxes, + checkSameOrderOrRetry(defIdB_, orderB, nAxes, defIdC_, orderC, nAxes, "Order of each indices in the n axis should be " "the same in each matrices"); if (foundInit_) { checkSameOrderNoRetry( - idC, orderInit_, batchAxes, idC, orderC, batchAxes, + defIdC_, orderInit_, batchAxes, defIdC_, orderC, batchAxes, "Order of each indices in the batch axis should be the same in " "initialization and reduction"); checkSameOrderNoRetry( - idC, orderInit_, mAxes, idC, orderC, mAxes, + defIdC_, orderInit_, mAxes, defIdC_, orderC, mAxes, "Order of each indices in the m axis should be the same in " "initialization and reduction"); checkSameOrderNoRetry( - idC, orderInit_, nAxes, idC, orderC, nAxes, + defIdC_, orderInit_, nAxes, defIdC_, orderC, nAxes, "Order of each indices in the n axis should be the same in " "initialization and reduction"); } // Find out which TENSOR DIMENSIONS are used - std::vector dimsABatch = findDimsUsed(loadA, batchAxes); - std::vector dimsBBatch = findDimsUsed(loadB, batchAxes); - std::vector dimsCBatch = findDimsUsed(op, batchAxes); - std::vector dimsAM = findDimsUsed(loadA, mAxes); - std::vector dimsAK = findDimsUsed(loadA, kAxes); - std::vector dimsBK = findDimsUsed(loadB, kAxes); - std::vector dimsBN = findDimsUsed(loadB, nAxes); - std::vector dimsCM = findDimsUsed(op, mAxes); - std::vector dimsCN = findDimsUsed(op, nAxes); + auto &dimsABatch = dimsABatch_ = findDimsUsed(loadA, batchAxes); + auto &dimsBBatch = dimsBBatch_ = findDimsUsed(loadB, batchAxes); + auto &dimsCBatch = dimsCBatch_ = findDimsUsed(op, batchAxes); + auto &dimsAM = dimsAM_ = findDimsUsed(loadA, mAxes); + auto &dimsAK = dimsAK_ = findDimsUsed(loadA, kAxes); + auto &dimsBK = dimsBK_ = findDimsUsed(loadB, kAxes); + auto &dimsBN = dimsBN_ = findDimsUsed(loadB, nAxes); + auto &dimsCM = dimsCM_ = findDimsUsed(op, mAxes); + auto &dimsCN = dimsCN_ = findDimsUsed(op, nAxes); Expr strideAM, strideAK, strideBK, strideBN, strideCM, strideCN; std::tie(batchSize_, stridea_) = findLenAndStride(loadA, dimsABatch); @@ -354,7 +359,7 @@ Stmt AsMatMul::visit(const ReduceTo &_op) { lda_ = strideAK; } else { retryReorderingBack( - idA, dimsAK, + defIdA_, dimsAK, "Either m or k dimension of a should be 1-strided"); } if (isIntConst1(strideBN)) { @@ -365,7 +370,7 @@ Stmt AsMatMul::visit(const ReduceTo &_op) { ldb_ = strideBN; } else { retryReorderingBack( - idB, dimsBN, + defIdB_, dimsBN, "Either k or n dimension of b should be 1-strided"); } if (isIntConst1(strideCN)) { @@ -376,7 +381,7 @@ Stmt AsMatMul::visit(const ReduceTo &_op) { ldc_ = strideCN; } else { retryReorderingBack( - idC, dimsCN, + defIdC_, dimsCN, "Either m or n dimension of c should be 
1-strided"); } @@ -401,17 +406,17 @@ Stmt AsMatMul::visit(const ReduceTo &_op) { stridec_ = makeMul(ldc_, cIsRowMajor_ ? m_ : n_); } else { if (!inOrder(dimsABatch, dimsAM) || !inOrder(dimsABatch, dimsAK)) { - retryReorderingFront(idA, dimsABatch, + retryReorderingFront(defIdA_, dimsABatch, "BLAS requires batch dimensions to be out " "of matrix dimensions in A"); } if (!inOrder(dimsBBatch, dimsBK) || !inOrder(dimsABatch, dimsBN)) { - retryReorderingFront(idB, dimsBBatch, + retryReorderingFront(defIdB_, dimsBBatch, "BLAS requires batch dimensions to be out " "of matrix dimensions in B"); } if (!inOrder(dimsCBatch, dimsCM) || !inOrder(dimsABatch, dimsCN)) { - retryReorderingFront(idC, dimsCBatch, + retryReorderingFront(defIdC_, dimsCBatch, "BLAS requires batch dimensions to be out " "of matrix dimensions in C"); } @@ -442,11 +447,25 @@ Stmt asMatMul(const Stmt &_ast, const ID &loop, MatMulBackend backend) { if (!mutator.done()) { throw InvalidSchedule(FT_MSG << loop << " not found"); } + + if (backend == MatMulBackend::CutlassMicroBlock) { + ast = lowerCutlassMicroBlock(ast, mutator.resultId(), mutator.defIdC(), + mutator.dimsCBatch(), mutator.dimsCM(), + mutator.dimsCN()); + } + return ast; } void Schedule::asMatMul(const ID &loop, AsMatMulMode mode, const Ref &target, MatMulBackend backend) { + if (backend == MatMulBackend::CutlassMicroBlock && + mode != AsMatMulMode::TryVarReorder) { + throw InvalidSchedule( + ast(), + FT_MSG << "cutlass_micro_block backend of as_matmul requires " + "TryVarReorder mode"); + } beginTransaction(); while (true) { auto log = appendLog( diff --git a/src/schedule/lower_cutlass_micro_block.cc b/src/schedule/lower_cutlass_micro_block.cc new file mode 100644 index 000000000..c54c14189 --- /dev/null +++ b/src/schedule/lower_cutlass_micro_block.cc @@ -0,0 +1,362 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace freetensor { + +namespace { + +bool isPowerOfTwo(int x) { return (x & (x - 1)) == 0; } + +class FixTransposeAndGetPartition : public Mutator { + ID matMulId_; + int64_t nWarpBatch_ = 0, nWarpM_ = 0, nWarpN_ = 0; + + public: + FixTransposeAndGetPartition(const ID &matMulId) : matMulId_(matMulId) {} + + auto nWarpBatch() const { return nWarpBatch_; } + auto nWarpM() const { return nWarpM_; } + auto nWarpN() const { return nWarpN_; } + + private: + std::tuple computeWarpPartition(int64_t batch, int64_t m, + int64_t n, int64_t k, + int nWarp) { + // Try to achieve the following goal in priority: + // + // 1. There should not be wasted warps, which means `nWarpM` and + // `nWarpN` should divide `nWarp`. + // 2. Use as more warps for the batch dimension as possible. + // 3. `m / nWarpM` and `n / nWarpN` should be as close as possible, to + // make the reuse in registers more efficient. 
+ + int nWarpBatch = std::gcd(nWarp, batch); + nWarp /= nWarpBatch; + + int nWarpM = 1, nWarpN = 1; + if (isPowerOfTwo(nWarp)) { + for (int i = 1; i < nWarp; i <<= 1) { + bool mDivisible = m % 2 == 0; + bool nDivisible = n % 2 == 0; + if (mDivisible && nDivisible) { + if (m / nWarpM > n / nWarpN) { + nWarpM *= 2; + } else { + nWarpN *= 2; + } + } else if (mDivisible) { + nWarpM *= 2; + } else if (nDivisible) { + nWarpN *= 2; + } else { + throw InvalidSchedule( + "Cannot compute warp partition for m = " + + std::to_string(m) + ", n = " + std::to_string(n) + + ", nWarp = " + std::to_string(nWarp)); + } + } + } else { + ASSERT(false); + } + + return {nWarpBatch, nWarpM, nWarpN}; + } + + protected: + Stmt visit(const MatMul &_op) override { + auto __op = Mutator::visit(_op); + ASSERT(__op->nodeType() == ASTNodeType::MatMul); + auto op = __op.as(); + + if (op->id() == matMulId_) { + ASSERT(op->backend_ == MatMulBackend::CutlassMicroBlock); + + // C is only supported in a densely packed row-major layout in + // registers + if (!op->cIsRowMajor_) { + op->aIsRowMajor_ = !op->aIsRowMajor_; + op->bIsRowMajor_ = !op->bIsRowMajor_; + op->cIsRowMajor_ = true; + std::swap(op->aIsRowMajor_, op->bIsRowMajor_); + std::swap(op->a_, op->b_); + std::swap(op->lda_, op->ldb_); + std::swap(op->stridea_, op->strideb_); + std::swap(op->n_, op->m_); + } + + // For a single `MatMul`, the `nWarp` parameter affects the performance, + // but the effect is limited. However, when there are multiple + // `MatMul`s, or when there is user-threaded code in the same + // kernel, it is critical to make `nWarp` of each of them + // consistent, to avoid wasting warps. TODO: find a way to adjust + // `nWarp` across different `MatMul`s. + const int nWarp = 4; // 128 threads + + int64_t batch, m, n, k; + if (op->batchSize_->nodeType() == ASTNodeType::IntConst) { + batch = op->batchSize_.as()->val_; + } else { + throw InvalidSchedule( + "Dynamic size of `batchSize` is not " + "supported for CutlassMicroBlock backend"); + } + if (op->m_->nodeType() == ASTNodeType::IntConst) { + m = op->m_.as()->val_; + } else { + throw InvalidSchedule( + "Dynamic size of `m` is not supported for " + "CutlassMicroBlock backend"); + } + if (op->n_->nodeType() == ASTNodeType::IntConst) { + n = op->n_.as()->val_; + } else { + throw InvalidSchedule( + "Dynamic size of `n` is not supported for " + "CutlassMicroBlock backend"); + } + if (op->k_->nodeType() == ASTNodeType::IntConst) { + k = op->k_.as()->val_; + } else { + throw InvalidSchedule( + "Dynamic size of `k` is not supported for " + "CutlassMicroBlock backend"); + } + + std::tie(nWarpBatch_, nWarpM_, nWarpN_) = + computeWarpPartition(batch, m, n, k, nWarp); + } + + return op; + } +}; + +class LowerCutlassMicroBlock : public SymbolTable { + typedef SymbolTable BaseClass; + + ID matMulId_; + int64_t nWarpBatch_ = 0, nWarpM_ = 0, nWarpN_ = 0; + + Ref prop_; + bool inMicroKernel_ = false; + + public: + LowerCutlassMicroBlock(const ID &matMulId, int64_t nWarpBatch, + int64_t nWarpM, int64_t nWarpN) + : matMulId_(matMulId), nWarpBatch_(nWarpBatch), nWarpM_(nWarpM), + nWarpN_(nWarpN) {} + + private: + template Stmt guardWriteByPartition(const T &op) { + auto ret = BaseClass::visit(op); + if (inMicroKernel_) { + int nDimsCAll = op->indices_.size(); + ASSERT(nDimsCAll >= + 9); // See comments in `lowerCutlassMicroBlock` below + auto batchInWarpPartition = + makeEQ(op->indices_[nDimsCAll - 9], prop_->warpIdBatch_); + auto mInWarpPartition = + makeEQ(op->indices_[nDimsCAll - 7], prop_->warpIdM_); + auto 
nInWarpPartition = + makeEQ(op->indices_[nDimsCAll - 4], prop_->warpIdN_); + auto mInThreadPartition = + makeEQ(op->indices_[nDimsCAll - 5], + makeFloorDiv(prop_->laneId_, makeIntConst(4))); + auto nInThreadPartition = + makeEQ(op->indices_[nDimsCAll - 2], + makeMod(prop_->laneId_, makeIntConst(4))); + + ret = makeIf( + makeLAnd(makeLAnd(batchInWarpPartition, + makeLAnd(mInWarpPartition, nInWarpPartition)), + makeLAnd(mInThreadPartition, nInThreadPartition)), + ret); + } + return ret; + } + + protected: + using BaseClass::visit; + + Stmt visit(const MatMul &_op) override { + if (_op->id() == matMulId_) { + if (inMicroKernel_) { + throw InvalidSchedule("Micro kernels cannot be nested in each other"); + } + + // Here we use `threadIdx.x` for threads in a warp, and + // `threadIdx.y` for warps, because putting everything into a single + // `threadIdx.x` would make the expressions too complicated to solve. + // However, this brings a challenge when fusing parts of a program + // with different thread mappings. We need to come up with a + // better way to parallelize other parts of the program according to + // the thread mapping here. (TODO) + Expr warpId = makeVar(".matmul.threadIdx.y"); + Expr laneId = makeVar(".matmul.threadIdx.x"); + Expr warpIdBatch = + makeFloorDiv(warpId, makeIntConst(nWarpN_ * nWarpM_)); + Expr warpIdM = makeMod(makeFloorDiv(warpId, makeIntConst(nWarpN_)), + makeIntConst(nWarpM_)); + Expr warpIdN = makeMod(warpId, makeIntConst(nWarpN_)); + + prop_ = Ref::make( + nWarpBatch_, nWarpM_, nWarpN_, warpIdBatch, warpIdM, warpIdN, + laneId); + + inMicroKernel_ = true; + auto __op = BaseClass::visit(_op); + ASSERT(__op->nodeType() == ASTNodeType::MatMul); + auto op = __op.as(); + inMicroKernel_ = false; + + // Point the `c_` pointer to each thread's starting address + ASSERT(op->c_->nodeType() == ASTNodeType::Load); + auto c = op->c_.as(); + int nDimsCAll = c->indices_.size(); + ASSERT(nDimsCAll >= + 9); // See comments in `lowerCutlassMicroBlock` below + c->indices_[nDimsCAll - 9] = warpIdBatch; + c->indices_[nDimsCAll - 7] = warpIdM; + c->indices_[nDimsCAll - 5] = makeFloorDiv(laneId, makeIntConst(4)); + c->indices_[nDimsCAll - 4] = warpIdN; + c->indices_[nDimsCAll - 2] = makeMod(laneId, makeIntConst(4)); + + op->backend_ = MatMulBackend::CutlassMicroThread; + op->cutlassMicroKernelProperty_ = prop_; + + auto metadata = std::move(op->metadata()); + op->metadata() = nullptr; + + const int warpSize = 32; + Stmt ret = op; + ret = makeFor(".matmul.threadIdx.x", makeIntConst(0), + makeIntConst(warpSize), makeIntConst(1), + makeIntConst(warpSize), + Ref::make()->withParallel(threadIdxX), + std::move(ret)); + ret = makeFor(".matmul.threadIdx.y", makeIntConst(0), + makeIntConst(nWarpBatch_ * nWarpM_ * nWarpN_), + makeIntConst(1), + makeIntConst(nWarpBatch_ * nWarpM_ * nWarpN_), + Ref::make()->withParallel(threadIdxY), + std::move(ret), std::move(metadata)); + return ret; + } else { + return BaseClass::visit(_op); + } + } + + Stmt visit(const Store &op) override { return guardWriteByPartition(op); } + Stmt visit(const ReduceTo &op) override { + return guardWriteByPartition(op); + } +}; + +} // Anonymous namespace + +Stmt lowerCutlassMicroBlock(const Stmt &_ast, const ID &matMulId, + const ID &defIdC, + const std::vector &dimsCBatch, + const std::vector &dimsCM, + const std::vector &dimsCN) { + // Get partition info + FixTransposeAndGetPartition fixTransposeAndGetPartition{matMulId}; + auto ast = fixTransposeAndGetPartition(_ast); + auto nWarpBatch = 
fixTransposeAndGetPartition.nWarpBatch(); + auto nWarpM = fixTransposeAndGetPartition.nWarpM(); + auto nWarpN = fixTransposeAndGetPartition.nWarpN(); + + // Partition C to each threads by layout-manipulation schedules. We have + // checked we are in TryVarReorder mode in schedule/as_matmul.cc. The + // resulting layout will be [ + // ...: other leading dims, + // -9: batch warps, + // -8: batch serial, + // -7: m warps, + // -6: m 8-tiles, + // -5: m threads, + // -4: n warps, + // -3: n 8-tiles, + // -2: n threads, + // -1: n 2-tiles + // ] + // + // See + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html + // for 8x8 partition inside warps + int nDimsCBatch = std::ranges::count(dimsCBatch, true); + int nDimsCM = std::ranges::count(dimsCM, true); + int nDimsCN = std::ranges::count(dimsCN, true); + if (!std::all_of(dimsCN.end() - nDimsCN, dimsCN.end(), + [](bool b) { return b; })) { + throw InvalidSchedule( + FT_MSG << "Invalid C layout for cutlass_micro_block backend"); + } + if (!std::all_of(dimsCM.end() - nDimsCN - nDimsCM, dimsCM.end() - nDimsCN, + [](bool b) { return b; })) { + throw InvalidSchedule( + FT_MSG << "Invalid C layout for cutlass_micro_block backend"); + } + if (!std::all_of(dimsCBatch.end() - nDimsCN - nDimsCM - nDimsCBatch, + dimsCBatch.end() - nDimsCN - nDimsCM, + [](bool b) { return b; })) { + throw InvalidSchedule( + FT_MSG << "Invalid C layout for cutlass_micro_block backend"); + } + int nDimsCAll = (int)dimsCBatch.size(); + int nDimsCOthers = nDimsCAll - nDimsCBatch - nDimsCM - nDimsCN; + if (nDimsCN > 1) { + for (int i = nDimsCN - 2; i >= 0; i--) { + ast = + varMerge(ast, defIdC, nDimsCOthers + nDimsCBatch + nDimsCM + i); + } + } else if (nDimsCN == 0) { + ast = varUnsqueeze(ast, defIdC, nDimsCOthers + nDimsCBatch + nDimsCM); + } + if (nDimsCM > 1) { + for (int i = nDimsCM - 2; i >= 0; i--) { + ast = varMerge(ast, defIdC, nDimsCOthers + nDimsCBatch + i); + } + } else if (nDimsCM == 0) { + ast = varUnsqueeze(ast, defIdC, nDimsCOthers + nDimsCBatch); + } + if (nDimsCBatch > 1) { + for (int i = nDimsCBatch - 2; i >= 0; i--) { + ast = varMerge(ast, defIdC, nDimsCOthers + i); + } + } else if (nDimsCBatch == 0) { + ast = varUnsqueeze(ast, defIdC, nDimsCOthers); + } + // clang-format off + ast = varSplit( + ast, defIdC, nDimsCOthers + 0, VarSplitMode::FixedSize, -1, nWarpBatch); + ast = varSplit( + ast, defIdC, nDimsCOthers + 2, VarSplitMode::FixedSize, -1, nWarpM); + ast = varSplit( + ast, defIdC, nDimsCOthers + 3, VarSplitMode::FixedSize, 8, -1); + ast = varSplit( + ast, defIdC, nDimsCOthers + 5, VarSplitMode::FixedSize, -1, nWarpN); + ast = varSplit( + ast, defIdC, nDimsCOthers + 6, VarSplitMode::FixedSize, 8, -1); + ast = varSplit( + ast, defIdC, nDimsCOthers + 7, VarSplitMode::FixedSize, 2, -1); + // clang-format on + + // Lower to CutlassMicroThread + LowerCutlassMicroBlock lowerCutlassMicroBlock{matMulId, nWarpBatch, nWarpM, + nWarpN}; + ast = lowerCutlassMicroBlock(ast); + + // Simplify the equivalent_ tree to help following passes + ast = shrinkFor(ast, matMulId, true, true); + + return ast; +} + +} // namespace freetensor diff --git a/src/schedule/parallelize_as.cc b/src/schedule/parallelize_as.cc index 6dd570196..7d868b8f1 100644 --- a/src/schedule/parallelize_as.cc +++ b/src/schedule/parallelize_as.cc @@ -276,7 +276,15 @@ Stmt parallelizeAs(const Stmt &_ast, const ID &nest, const ID &reference, ast = shrinkFor(ast, nest, true, unordered); for (auto &&[id, scope] : views::zip(adder.newScopeIds(), orderedScopes)) { - ast = parallelize(ast, id, 
scope->property_->parallel_, true); + try { + ast = parallelize(ast, id, scope->property_->parallel_, true); + } catch (const InvalidSchedule &e) { + throw InvalidSchedule( + FT_MSG << "Failed to create a new parallel scope " << id + << " in loop nest " << nest << " as " << scope->id() + << " in the reference loop nest " << reference << ": " + << e.what()); + } } return ast; diff --git a/src/schedule/var_merge.cc b/src/schedule/var_merge.cc index 42de2f006..93c89544c 100644 --- a/src/schedule/var_merge.cc +++ b/src/schedule/var_merge.cc @@ -1,3 +1,4 @@ +#include #include #include @@ -71,7 +72,7 @@ Stmt varMerge(const Stmt &_ast, const ID &def, int dim) { if (!mutator.found()) { throw InvalidSchedule(FT_MSG << def << " not found"); } - return ast; + return constFold(ast); } void Schedule::varMerge(const ID &def, int dim) { diff --git a/src/schedule/var_split.cc b/src/schedule/var_split.cc index 922f3bc7d..4dee7eaf2 100644 --- a/src/schedule/var_split.cc +++ b/src/schedule/var_split.cc @@ -1,3 +1,4 @@ +#include #include #include @@ -89,7 +90,7 @@ Stmt varSplit(const Stmt &_ast, const ID &def, int dim, VarSplitMode mode, if (!mutator.found()) { throw InvalidSchedule(FT_MSG << def << " not found"); } - return ast; + return constFold(ast); } void Schedule::varSplit(const ID &def, int dim, VarSplitMode mode, int factor, diff --git a/test/70.program/test_program_with_micro_kernel.py b/test/70.program/test_program_with_micro_kernel.py new file mode 100644 index 000000000..b0825e793 --- /dev/null +++ b/test/70.program/test_program_with_micro_kernel.py @@ -0,0 +1,96 @@ +import pytest + +import freetensor as ft + + +@pytest.mark.skipif(not ft.with_pytorch() or not ft.with_cuda(), + reason="requires PyTorch and CUDA") +def test_matmul_float64(): + + M = N = K = 5000 + block_n = block_m = 128 + block_k = 32 + n_warps = 4 + + device = ft.GPU() + target = device.target() + with target: + + @ft.transform + def matmul(a: ft.Var[(M, K), "float64"], b: ft.Var[(K, N), "float64"]): + c = ft.empty((M, N), "float64") + #! label: blk_m + for i in range(0, M, block_m): + #! label: blk_n + for j in range(0, N, block_n): + #! label: aa + aa = ft.empty((block_m, block_k), "float64") + #! label: bb + bb = ft.empty((block_k, block_n), "float64") + #! label: cc + cc = ft.empty((block_m, block_n), "float64") + #! label: zero_cc + for ii in range(block_m): + for jj in range(block_n): + cc[ii, jj] = 0 + for k in range(0, K, block_k): + #! label: load_aa + for ii in range(block_m): + for kk in range(block_k): + if i + ii < M and k + kk < K: + aa[ii, kk] = a[i + ii, k + kk] + else: + aa[ii, kk] = 0 + #! label: load_bb + for kk in range(block_k): + for jj in range(block_n): + if k + kk < K and j + jj < N: + bb[kk, jj] = b[k + kk, j + jj] + else: + bb[kk, jj] = 0 + #! label: micro_kernel + for ii in range(block_m): + for jj in range(block_n): + for kk in range(block_k): + cc[ii, jj] += aa[ii, kk] * bb[kk, jj] + #! label: flush_cc + for ii in range(block_m): + for jj in range(block_n): + # TODO: Can we avoid using `unbound`? 
+ if ft.unbound(i + ii < M and j + jj < N): + c[i + ii, j + jj] = cc[ii, jj] + return c + + s = ft.Schedule(matmul, verbose=2) + s.parallelize("blk_m", "blockIdx.y") + s.parallelize("blk_n", "blockIdx.x") + s.as_matmul("micro_kernel", + target=target, + backend="cutlass-micro-block", + mode=ft.AsMatMulMode.TryVarReorder) + load_aa_warp, load_aa_thr = s.split( + s.split(s.merge("load_aa", "<-load_aa"), n_warps * 32)[1], 32) + s.parallelize(load_aa_warp, "threadIdx.y") + s.parallelize(load_aa_thr, "threadIdx.x") + load_bb_warp, load_bb_thr = s.split( + s.split(s.merge("load_bb", "<-load_bb"), n_warps * 32)[1], 32) + s.parallelize(load_bb_warp, "threadIdx.y") + s.parallelize(load_bb_thr, "threadIdx.x") + s.parallelize_as("zero_cc", "$as_matmul{micro_kernel}", "cc") + s.parallelize_as("flush_cc", "$as_matmul{micro_kernel}", "cc") + s.set_mem_type("aa", "gpu/shared") + s.set_mem_type("bb", "gpu/shared") + s.set_mem_type("cc", "gpu/local") + scheduled = s.func() + exe = ft.optimize(scheduled, verbose=2) + + import torch + + a_torch = torch.rand(M, K, dtype=torch.float64).cuda() + b_torch = torch.rand(K, N, dtype=torch.float64).cuda() + y_std = a_torch @ b_torch + a_arr = ft.array(a_torch) + b_arr = ft.array(b_torch) + y_arr = exe(a_arr, b_arr) + y_torch = y_arr.torch() + assert torch.all(torch.isclose(y_torch, y_std))
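For reference, a minimal standalone sketch (not part of the patch) of the C-ownership mapping that `lowerCutlassMicroBlock` sets up for the test above. It mirrors the checks in `guardWriteByPartition` and the layout comment in lower_cutlass_micro_block.cc, and assumes the warp partition for this 128x128x32 micro kernel resolves to nWarpBatch = 1 and nWarpM = nWarpN = 2; the helper name `owner_of` is invented for illustration.

def owner_of(ii, jj, n_warp_m=2, n_warp_n=2):
    # cc is reshaped to [batch warps, batch serial, m warps, m 8-tiles, m threads,
    #                    n warps, n 8-tiles, n threads, n 2-tiles]
    # = [1, 1, 2, 8, 8, 2, 8, 4, 2] for one 128x128 tile.
    m_warp, rest = divmod(ii, 128 // n_warp_m)  # m warps
    m_tile, m_thread = divmod(rest, 8)          # m 8-tiles, m threads (= lane_id // 4)
    n_warp, rest = divmod(jj, 128 // n_warp_n)  # n warps
    n_tile, rest = divmod(rest, 8)              # n 8-tiles
    n_thread, n_pair = divmod(rest, 2)          # n threads (= lane_id % 4), n 2-tiles
    warp_id = m_warp * n_warp_n + n_warp        # ".matmul.threadIdx.y" (batch warp is 0)
    lane_id = m_thread * 4 + n_thread           # ".matmul.threadIdx.x"
    return warp_id, lane_id

owner_of(0, 0)  # (0, 0); each of the 4 * 32 threads owns 8 * 8 * 2 = 128 elements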