Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
danielvegamyhre committed Jan 8, 2025
2 parents 54a3213 + 84cc74b commit 72fdc56
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -401,10 +401,13 @@ __global__ void Marlin_24(
meta_ptr[i] += m_gl_rd_delta_o;
}
// Only fetch scales if this tile starts a new group
if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
s_gl_rd += s_gl_rd_delta;
if constexpr (group_blocks != -1) {
if (pipe % (group_blocks / thread_k_blocks) == 0) {
int4 *sh_s_stage = sh_s + s_sh_stage * pipe;
if (s_sh_wr_pred)
cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
s_gl_rd += s_gl_rd_delta;
}
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
Expand All @@ -429,7 +432,7 @@ __global__ void Marlin_24(
// however, this does not seem to be a significant bottleneck, while some
// theoretically better attempts have lead to bad instruction ordering by
// the compiler and correspondingly a noticeable drop in performance.
if (group_blocks != -1) {
if constexpr (group_blocks != -1) {
int4* sh_s_stage =
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
Expand Down

0 comments on commit 72fdc56

Please sign in to comment.