-
Notifications
You must be signed in to change notification settings - Fork 130
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add bf16 and int8 wmma gemms for Navi3x and Navi4x. (#1671)
* add bf16 gemms for gfx11/gfx12 * reduce the input values in test_gemm * add int8 wmma gemm instances for gfx11/gfx12 * add example gemm_wmma_int8 * fix bug in gemm_wmma_int8 test * increase bf16 gemm test tolerance * update the dates and clean-up commented-out instances
- Loading branch information
Showing
16 changed files
with
896 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
// SPDX-License-Identifier: MIT | ||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. | ||
|
||
#include "common.hpp" | ||
|
||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp" | ||
|
||
using ADataType = ck::bhalf_t; | ||
using BDataType = ck::bhalf_t; | ||
using AccDataType = float; | ||
using CShuffleDataType = float; | ||
using CDataType = ck::bhalf_t; | ||
|
||
using ALayout = Row; | ||
using BLayout = Col; | ||
using CLayout = Row; | ||
|
||
using AElementOp = PassThrough; | ||
using BElementOp = PassThrough; | ||
using CElementOp = PassThrough; | ||
|
||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; | ||
|
||
// clang-format off | ||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle | ||
< ALayout, | ||
BLayout, | ||
CLayout, | ||
ADataType, | ||
BDataType, | ||
CDataType, | ||
AccDataType, | ||
CShuffleDataType, | ||
AElementOp, | ||
BElementOp, | ||
CElementOp, | ||
GemmDefault, | ||
1, // Prefetch stage | ||
128, // BlockSize | ||
64, // MPerBlock | ||
128, // NPerBlock | ||
64, // KPerBlock | ||
2, // K1 | ||
16, // MPerWmma | ||
16, // NPerWmma | ||
2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave | ||
4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave | ||
S<4, 32, 1>, | ||
S<1, 0, 2>, | ||
S<1, 0, 2>, | ||
2, | ||
2, | ||
2, | ||
true, | ||
S<4, 32, 1>, | ||
S<1, 0, 2>, | ||
S<1, 0, 2>, | ||
2, | ||
2, | ||
2, | ||
true, | ||
1, // C shuffle (M Repeat) Per store | ||
1, // C shuffle (N Repeat) Per store | ||
S<1, 32, 1, 4>, | ||
8>; | ||
// clang-format on | ||
|
||
using ReferenceGemmInstance = ck::tensor_operation::host:: | ||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>; | ||
|
||
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout, | ||
BLayout, | ||
CLayout, | ||
ADataType, | ||
BDataType, | ||
CDataType, | ||
AccDataType, | ||
AElementOp, | ||
BElementOp, | ||
CElementOp>; | ||
|
||
#include "run_gemm_example.inc" | ||
|
||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
// SPDX-License-Identifier: MIT | ||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. | ||
|
||
#include "common.hpp" | ||
|
||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp" | ||
|
||
using ADataType = int8_t; | ||
using BDataType = int8_t; | ||
using AccDataType = int32_t; | ||
using CShuffleDataType = int32_t; | ||
using CDataType = int8_t; | ||
|
||
using ALayout = Row; | ||
using BLayout = Col; | ||
using CLayout = Row; | ||
|
||
using AElementOp = PassThrough; | ||
using BElementOp = PassThrough; | ||
using CElementOp = PassThrough; | ||
|
||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; | ||
|
||
// clang-format off | ||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle | ||
< ALayout, | ||
BLayout, | ||
CLayout, | ||
ADataType, | ||
BDataType, | ||
CDataType, | ||
AccDataType, | ||
CShuffleDataType, | ||
AElementOp, | ||
BElementOp, | ||
CElementOp, | ||
GemmDefault, | ||
1, // Prefetch stage | ||
128, // BlockSize | ||
64, // MPerBlock | ||
128, // NPerBlock | ||
64, // KPerBlock | ||
2, // K1 | ||
16, // MPerWmma | ||
16, // NPerWmma | ||
2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave | ||
4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave | ||
S<4, 32, 1>, | ||
S<1, 0, 2>, | ||
S<1, 0, 2>, | ||
2, | ||
2, | ||
2, | ||
true, | ||
S<4, 32, 1>, | ||
S<1, 0, 2>, | ||
S<1, 0, 2>, | ||
2, | ||
2, | ||
2, | ||
true, | ||
1, // C shuffle (M Repeat) Per store | ||
1, // C shuffle (N Repeat) Per store | ||
S<1, 32, 1, 4>, | ||
8>; | ||
// clang-format on | ||
|
||
using ReferenceGemmInstance = ck::tensor_operation::host:: | ||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>; | ||
|
||
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout, | ||
BLayout, | ||
CLayout, | ||
ADataType, | ||
BDataType, | ||
CDataType, | ||
AccDataType, | ||
AElementOp, | ||
BElementOp, | ||
CElementOp>; | ||
|
||
#include "run_gemm_example.inc" | ||
|
||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.