Skip to content

Commit

Permalink
Add bf16 and int8 wmma gemms for Navi3x and Navi4x. (#1671)
Browse files Browse the repository at this point in the history
* add bf16 gemms for gfx11/gfx12

* reduce the input values in test_gemm

* add int8 wmma gemm instances for gfx11/gfx12

* add example gemm_wmma_int8

* fix bug in gemm_wmma_int8 test

* increase bf16 gemm test tolerance

* update the dates and clean-up commented-out instances
  • Loading branch information
illsilin authored Nov 18, 2024
1 parent 754adc7 commit 8aba272
Show file tree
Hide file tree
Showing 16 changed files with 896 additions and 26 deletions.
4 changes: 4 additions & 0 deletions example/01_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
add_custom_target(example_gemm_wmma)
add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
add_example_executable(example_gemm_wmma_bf16 gemm_wmma_bf16.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_bf16)
add_example_executable(example_gemm_wmma_int8 gemm_wmma_int8.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_int8)
84 changes: 84 additions & 0 deletions example/01_gemm/gemm_wmma_bf16.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp"

using ADataType = ck::bhalf_t;
using BDataType = ck::bhalf_t;
using AccDataType = float;
using CShuffleDataType = float;
using CDataType = ck::bhalf_t;

using ALayout = Row;
using BLayout = Col;
using CLayout = Row;

using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
< ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AccDataType,
CShuffleDataType,
AElementOp,
BElementOp,
CElementOp,
GemmDefault,
1, // Prefetch stage
128, // BlockSize
64, // MPerBlock
128, // NPerBlock
64, // KPerBlock
2, // K1
16, // MPerWmma
16, // NPerWmma
2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
2,
2,
true,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
2,
2,
true,
1, // C shuffle (M Repeat) Per store
1, // C shuffle (N Repeat) Per store
S<1, 32, 1, 4>,
8>;
// clang-format on

using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;

using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;

#include "run_gemm_example.inc"

int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
84 changes: 84 additions & 0 deletions example/01_gemm/gemm_wmma_int8.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp"

using ADataType = int8_t;
using BDataType = int8_t;
using AccDataType = int32_t;
using CShuffleDataType = int32_t;
using CDataType = int8_t;

using ALayout = Row;
using BLayout = Col;
using CLayout = Row;

using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
< ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AccDataType,
CShuffleDataType,
AElementOp,
BElementOp,
CElementOp,
GemmDefault,
1, // Prefetch stage
128, // BlockSize
64, // MPerBlock
128, // NPerBlock
64, // KPerBlock
2, // K1
16, // MPerWmma
16, // NPerWmma
2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
2,
2,
true,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
2,
2,
true,
1, // C shuffle (M Repeat) Per store
1, // C shuffle (N Repeat) Per store
S<1, 32, 1, 4>,
8>;
// clang-format on

using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;

using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;

#include "run_gemm_example.inc"

int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
11 changes: 6 additions & 5 deletions include/ck/utility/amd_wmma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ namespace ck {
defined(__gfx1103__) || defined(__gfx11_generic__)
#define __gfx11__
#endif

#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
#define __gfx12__
#endif

/********************************WAVE32 MODE***********************************************/

// src: fp16, dst: fp32
Expand Down Expand Up @@ -99,7 +104,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
#if defined(__gfx11__)
#if defined(__gfx11__) || defined(__gfx12__)
reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
Expand Down Expand Up @@ -261,10 +266,6 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
// gfx12
/********************************WAVE32 MODE***********************************************/

#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
#define __gfx12__
#endif

// src: fp16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32_gfx12;
Expand Down
52 changes: 52 additions & 0 deletions library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,58 @@ struct DeviceOperationInstanceFactory<
}
}
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<ADataType, ck::bhalf_t> && is_same_v<BDataType, ck::bhalf_t> &&
is_same_v<CDataType, ck::bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_bf16_bf16_bf16_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_bf16_bf16_bf16_km_nk_mn_instances(op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
is_same_v<CDataType, int8_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_int8_int8_int8_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_int8_int8_int8_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_int8_int8_int8_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_int8_int8_int8_km_nk_mn_instances(op_ptrs);
}
}
#endif
#endif

#ifdef CK_USE_XDL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,46 @@ void add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);

void add_device_gemm_wmma_bf16_bf16_bf16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);

void add_device_gemm_wmma_bf16_bf16_bf16_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);

void add_device_gemm_wmma_bf16_bf16_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);

void add_device_gemm_wmma_bf16_bf16_bf16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);

void add_device_gemm_wmma_int8_int8_int8_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);

void add_device_gemm_wmma_int8_int8_int8_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);

void add_device_gemm_wmma_int8_int8_int8_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);

void add_device_gemm_wmma_int8_int8_int8_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);

} // namespace instance
} // namespace device
} // namespace tensor_operation
Expand Down
2 changes: 1 addition & 1 deletion library/include/ck/library/utility/check_err.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ typename std::enable_if<
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-3,
double rtol = 1e-1,
double atol = 1e-3)
{
if(out.size() != ref.size())
Expand Down
33 changes: 13 additions & 20 deletions library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ set(GEMM_INSTANCES)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp
device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp
device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp
device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp)

list(APPEND GEMM_INSTANCES
device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp
device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp
Expand All @@ -21,9 +19,6 @@ list(APPEND GEMM_INSTANCES
device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp
)

list(APPEND GEMM_INSTANCES
device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp
Expand Down Expand Up @@ -78,9 +73,6 @@ list(APPEND GEMM_INSTANCES
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp
)

list(APPEND GEMM_INSTANCES
device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp
device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp
device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp
Expand All @@ -92,15 +84,11 @@ list(APPEND GEMM_INSTANCES
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp)

list(APPEND GEMM_INSTANCES
device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp)

list(APPEND GEMM_INSTANCES
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_default_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_interwave_default_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_default_instance.cpp
Expand All @@ -109,14 +97,19 @@ list(APPEND GEMM_INSTANCES
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_padded_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp)


list(APPEND GEMM_INSTANCES
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp
device_gemm_wmma_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_wmma_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_wmma_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_wmma_f16_f16_f16_km_nk_mn_instance.cpp)
device_gemm_wmma_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_wmma_bf16_bf16_bf16_mk_kn_mn_instance.cpp
device_gemm_wmma_bf16_bf16_bf16_mk_nk_mn_instance.cpp
device_gemm_wmma_bf16_bf16_bf16_km_kn_mn_instance.cpp
device_gemm_wmma_bf16_bf16_bf16_km_nk_mn_instance.cpp
device_gemm_wmma_int8_int8_int8_mk_kn_mn_instance.cpp
device_gemm_wmma_int8_int8_int8_mk_nk_mn_instance.cpp
device_gemm_wmma_int8_int8_int8_km_kn_mn_instance.cpp
device_gemm_wmma_int8_int8_int8_km_nk_mn_instance.cpp)

add_instance_library(device_gemm_instance ${GEMM_INSTANCES})

Expand Down
Loading

0 comments on commit 8aba272

Please sign in to comment.