Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
Xetla GRF Mode Control (#317)
Browse files Browse the repository at this point in the history
* grf_mode ctrl

* Use 2d when tile_size_x/y=1
  • Loading branch information
DDEle authored Jul 19, 2024
1 parent d9dd484 commit 1081543
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 10 deletions.
5 changes: 5 additions & 0 deletions include/common/core/arch_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,11 @@ struct fpu_attr_t {
/// Shorthand for fpu_attr_t<arch_tag>::has_fpu, usable directly as a
/// compile-time boolean for the given architecture tag.
template <gpu_arch arch_tag>
inline constexpr bool arch_has_fpu = fpu_attr_t<arch_tag>::has_fpu;

// GRF selects the grf_mode value passed to templates such as
// register_nums_t<GRF>.
//   - default:             grf_mode::double_grf
//   - NORMAL_GRF defined:  grf_mode::normal_grf (e.g. build with -DNORMAL_GRF)
// The two definitions are mutually exclusive so that GRF is defined exactly
// once; defining it and then re-defining it with a different replacement
// list (without an intervening #undef) is a macro redefinition and is
// diagnosed by the compiler.
#ifdef NORMAL_GRF
#define GRF grf_mode::normal_grf
#else
#define GRF grf_mode::double_grf
#endif

template <grf_mode grf_num_mode>
struct register_nums_t {
static constexpr uint32_t register_nums =
Expand Down
7 changes: 5 additions & 2 deletions include/experimental/group/gemm/compute_policy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,15 @@ struct compute_policy_int4_dequantize<
static constexpr bool is_col_major_b =
quant_info_.weight_mem_layout == mem_layout::col_major;

using reg_nums_t = register_nums_t<GRF>;
static constexpr uint32_t block_size_y_a = is_col_major_b ? 8 : 16;
static constexpr uint32_t block_bytes_x_a = is_col_major_b ? 256 : 32;
static constexpr uint32_t block_bytes_x_a =
is_col_major_b ? reg_nums_t::register_nums : 32;
static constexpr uint32_t block_size_x_a =
block_bytes_x_a / sizeof(dtype_mma_a);
static constexpr uint32_t block_size_x_b = is_col_major_b ? 1 : 32;
static constexpr uint32_t block_bytes_y_b = is_col_major_b ? 256 : 32;
static constexpr uint32_t block_bytes_y_b =
is_col_major_b ? reg_nums_t::register_nums : 32;
static constexpr uint32_t block_size_y_b =
block_bytes_y_b / sizeof(dtype_mma_b);

Expand Down
12 changes: 4 additions & 8 deletions include/subgroup/tile/impl/payload_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1887,10 +1887,8 @@ struct prefetch_payload_t<
arch_tag_,
std::enable_if_t<
(arch_tag_ == gpu_arch::XeHpc) &&
(((block_size_y_ != 1 || tile_size_y_ != 1) &&
mem_layout_ == mem_layout::row_major) ||
((block_size_x_ != 1 || tile_size_x_ != 1) &&
mem_layout_ == mem_layout::col_major))>> {
(((tile_size_y_ != 1) && mem_layout_ == mem_layout::row_major) ||
((tile_size_x_ != 1) && mem_layout_ == mem_layout::col_major))>> {
using dtype = dtype_;
using mem_desc_t =
mem_desc_t<dtype_, mem_layout_, mem_space::global, alignment_>;
Expand Down Expand Up @@ -2180,10 +2178,8 @@ struct prefetch_payload_t<
num_coop_sg_,
arch_tag_,
std::enable_if_t<
((block_size_y_ == 1 && tile_size_y_ == 1) &&
mem_layout_ == mem_layout::row_major) ||
((block_size_x_ == 1 && tile_size_x_ == 1) &&
mem_layout_ == mem_layout::col_major)>> {
((tile_size_y_ == 1) && mem_layout_ == mem_layout::row_major) ||
((tile_size_x_ == 1) && mem_layout_ == mem_layout::col_major)>> {
using dtype = dtype_;
using mem_desc_t =
mem_desc_t<dtype_, mem_layout_, mem_space::global, alignment_>;
Expand Down

0 comments on commit 1081543

Please sign in to comment.