-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Invoke micro kernels from CUTLASS (#596)
* [WIP] Invoke micro kernels from CUTLASS * Fix: accum should be in row-major * Multiple fixes and workarounds Fixes: - Check `beta=0` or `beta=1` in micro kernel. - Fix the testing program. Workarounds: - Temporarily add __syncthreads in micro kernels. - Temporarily set data types to float64 to avoid the issue that CUTLASS does not support RowMajor layout for float16. * Fix __syncthreads() * Adjust test case name * Fix compiler warnings * Include micro kernels as needed * Fix a memory bug in pass/simplify
- Loading branch information
Showing
24 changed files
with
1,001 additions
and
63 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
#ifndef CUTLASS_MICRO_KERNEL_PROPERTY_H | ||
#define CUTLASS_MICRO_KERNEL_PROPERTY_H | ||
|
||
#include <expr.h> | ||
#include <sub_tree.h> | ||
|
||
namespace freetensor { | ||
|
||
struct CutlassMicroKernelProperty : public ASTPart { | ||
int nWarpBatch_, nWarpM_, nWarpN_; | ||
Expr warpIdBatch_, warpIdM_, warpIdN_, laneId_; | ||
|
||
template <typename TwarpIdBatch, typename TwarpIdM, typename TwarpIdN, | ||
typename TlaneId> | ||
CutlassMicroKernelProperty(int nWarpBatch, int nWarpM, int nWarpN, | ||
TwarpIdBatch &&warpIdBatch, TwarpIdM &&warpIdM, | ||
TwarpIdN &&warpIdN, TlaneId &&laneId) | ||
: nWarpBatch_(nWarpBatch), nWarpM_(nWarpM), nWarpN_(nWarpN), | ||
warpIdBatch_(std::forward<TwarpIdM>(warpIdBatch)), | ||
warpIdM_(std::forward<TwarpIdM>(warpIdM)), | ||
warpIdN_(std::forward<TwarpIdN>(warpIdN)), | ||
laneId_(std::forward<TlaneId>(laneId)) {} | ||
|
||
void compHash() override; | ||
}; | ||
|
||
inline Ref<CutlassMicroKernelProperty> | ||
deepCopy(const Ref<CutlassMicroKernelProperty> &p) { | ||
return Ref<CutlassMicroKernelProperty>::make( | ||
p->nWarpBatch_, p->nWarpM_, p->nWarpN_, deepCopy(p->warpIdBatch_), | ||
deepCopy(p->warpIdM_), deepCopy(p->warpIdN_), deepCopy(p->laneId_)); | ||
} | ||
|
||
} // namespace freetensor | ||
|
||
#endif // CUTLASS_MICRO_KERNEL_PROPERTY_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
#ifndef FREE_TENSOR_LOWER_CUTLASS_MICRO_BLOCK_H | ||
#define FREE_TENSOR_LOWER_CUTLASS_MICRO_BLOCK_H | ||
|
||
#include <stmt.h> | ||
|
||
namespace freetensor { | ||
|
||
Stmt lowerCutlassMicroBlock(const Stmt &ast, const ID &matMulId, | ||
const ID &defIdC, | ||
const std::vector<bool> &dimsCBatch, | ||
const std::vector<bool> &dimsCM, | ||
const std::vector<bool> &dimsCN); | ||
|
||
} | ||
|
||
#endif // FREE_TENSOR_LOWER_CUTLASS_MICRO_BLOCK_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
#ifndef MICRO_KERNEL_MATMUL_CUTLASS_GEMM_H | ||
#define MICRO_KERNEL_MATMUL_CUTLASS_GEMM_H | ||
|
||
#if (defined(__CUDA_ARCH__)) // Device code | ||
|
||
#if (__CUDA_ARCH__ >= 800) | ||
#include "gemm_sm80.h" | ||
#else | ||
#error "Unsupported architecture" | ||
#endif | ||
|
||
#else // Host code | ||
|
||
// Only declaration is needed | ||
template <int M, int N, int K, int num_warp_batch, int num_warp_m, | ||
int num_warp_n, bool trans_A, bool trans_B, typename A_type, | ||
typename B_type, typename C_type> | ||
__device__ void matmul_thread(const A_type *pA, const B_type *pB, C_type *accum, | ||
int lda, int ldb, int stridea, int strideb, | ||
int stridec, double alpha, double beta, | ||
int warp_id_batch, int warp_id_m, int warp_id_n, | ||
int lane_id); | ||
|
||
#endif | ||
|
||
#endif // MICRO_KERNEL_MATMUL_CUTLASS_GEMM_H |
Oops, something went wrong.