[Kernel] Fix marlin divide-by-zero warnings (#6904)

Tyler Michael Smith 2024-07-29 21:26:07 -04:00 committed by GitHub
parent 4fbf4aa128
commit 61a97c32f6
3 changed files with 58 additions and 39 deletions
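
For context on the warning itself: group_blocks and thread_k_blocks are compile-time template parameters of these kernels, and group_blocks == -1 (channelwise scales) makes group_blocks / thread_k_blocks fold to 0. A plain runtime `if` therefore still leaves a `% 0` in the instantiation for the compiler to warn about, even though the branch can never execute. Switching to `if constexpr` discards the branch entirely for those instantiations, and the new static_assert records the precondition the division relies on. A minimal standalone sketch of the pattern follows (illustrative names only, not the vLLM kernels; compiles with any C++17 compiler, e.g. g++ or nvcc):

// Sketch, not the vLLM code: a constant branch around a division by a
// template parameter.  Instantiating the first variant with
// group_blocks == -1 reproduces the class of divide-by-zero warning this
// commit removes; the second variant compiles clean.

// Variant 1: plain `if`.  With group_blocks == -1,
// group_blocks / thread_k_blocks folds to 0, so the compiler still sees
// `pipe % 0` in the dead branch and can warn about it.
template <int group_blocks, int thread_k_blocks>
int scale_fetch_runtime_guard(int pipe) {
  if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
    return 1;  // "fetch scales"
  }
  return 0;
}

// Variant 2: `if constexpr`.  The branch is discarded at compile time for
// group_blocks == -1, so the division is never instantiated.  The
// static_assert documents the precondition the division relies on,
// mirroring the commit; it is only checked when the branch is kept.
template <int group_blocks, int thread_k_blocks>
int scale_fetch_constexpr_guard(int pipe) {
  if constexpr (group_blocks != -1) {
    static_assert(group_blocks >= thread_k_blocks,
                  "smaller groups would need different indexing");
    if (pipe % (group_blocks / thread_k_blocks) == 0) {
      return 1;
    }
  }
  return 0;
}

// Explicit instantiations with group_blocks == -1 (channelwise scales):
// only the first variant leaves a `% 0` for the compiler to see.
template int scale_fetch_runtime_guard<-1, 16>(int);
template int scale_fetch_constexpr_guard<-1, 16>(int);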


@@ -1128,9 +1128,11 @@ __global__ void Marlin(
   };
 
   auto fetch_zp_to_registers = [&](int k, int full_pipe) {
-    if constexpr (!has_zp) {
-      return;
-    }
+    if constexpr (has_zp) {
+      // This code does not handle group_blocks == 0,
+      // which signifies act_order.
+      // has_zp implies AWQ, which doesn't have act_order,
+      static_assert(group_blocks != 0);
 
-    int pipe = full_pipe % stages;
+      int pipe = full_pipe % stages;
 
@@ -1168,6 +1170,7 @@ __global__ void Marlin(
               (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
         }
       }
+    }
   };
 
   // Execute the actual tensor core matmul of a sub-tile.
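
One subtlety in the hunks above: the old early `return` under `if constexpr (!has_zp)` prevents the rest of the lambda from running, but it does not prevent the rest from being instantiated, so any division by group_blocks later in the body is still compiled, and warned about, for group_blocks == 0 (act_order) instantiations. Wrapping the whole body in `if constexpr (has_zp)` discards that code for the !has_zp instantiations, and the static_assert is safe because discarded statements are never instantiated. A small sketch of the difference, with made-up names rather than the actual fetch_zp_to_registers body:

// Sketch, not the vLLM code: early return vs. wrapping the body.
template <bool has_zp, int group_blocks>
int fetch_zp_early_return(int pipe) {
  if constexpr (!has_zp) {
    return 0;
  }
  // Still instantiated even when has_zp == false; with group_blocks == 0
  // the compiler sees a division by a constant zero and can warn.
  return pipe / group_blocks;
}

template <bool has_zp, int group_blocks>
int fetch_zp_wrapped(int pipe) {
  if constexpr (has_zp) {
    // Discarded entirely when has_zp == false, so group_blocks == 0
    // instantiations never see the division.  When has_zp == true the
    // division is real, so assert the precondition (as the commit does).
    static_assert(group_blocks != 0);
    return pipe / group_blocks;
  }
  return 0;
}

// has_zp == false with group_blocks == 0 (act_order): only the first
// variant instantiates the division by zero.
template int fetch_zp_early_return<false, 0>(int);
template int fetch_zp_wrapped<false, 0>(int);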


@@ -452,12 +452,17 @@ __global__ void Marlin(
         B_ptr[i] += b_gl_rd_delta_o;
       }
       // Only fetch scales if this tile starts a new group
-      if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
-        int4* sh_s_stage = sh_s + s_sh_stage * pipe;
-        if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
-        s_gl_rd += s_gl_rd_delta;
-      }
+      if constexpr (group_blocks != -1) {
+        // This assumes group_blocks >= thread_k_blocks
+        // and would need to be modified to support smaller groups.
+        static_assert(group_blocks >= thread_k_blocks);
+        if (pipe % (group_blocks / thread_k_blocks) == 0) {
+          int4* sh_s_stage = sh_s + s_sh_stage * pipe;
+          if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
+          s_gl_rd += s_gl_rd_delta;
+        }
+      }
     }
     // Insert a fence even when we are winding down the pipeline to ensure that
     // waiting is also correct at this point.
     cp_async_fence();
@@ -480,7 +485,10 @@ __global__ void Marlin(
     // however, this does not seem to be a significant bottleneck, while some
     // theoretically better attempts have lead to bad instruction ordering by
     // the compiler and correspondingly a noticeable drop in performance.
-    if (group_blocks != -1) {
+    if constexpr (group_blocks != -1) {
+      // This assumes group_blocks >= thread_k_blocks
+      // and would need to be modified to support smaller groups.
+      static_assert(group_blocks >= thread_k_blocks);
       int4* sh_s_stage =
           sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
                                (pipe / (group_blocks / thread_k_blocks)));


@@ -404,12 +404,17 @@ __global__ void Marlin_24(
         meta_ptr[i] += m_gl_rd_delta_o;
       }
       // Only fetch scales if this tile starts a new group
-      if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
-        int4* sh_s_stage = sh_s + s_sh_stage * pipe;
-        if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
-        s_gl_rd += s_gl_rd_delta;
-      }
+      if constexpr (group_blocks != -1) {
+        // This assumes group_blocks >= thread_k_blocks
+        // and would need to be modified to support smaller groups.
+        static_assert(group_blocks >= thread_k_blocks);
+        if (pipe % (group_blocks / thread_k_blocks) == 0) {
+          int4* sh_s_stage = sh_s + s_sh_stage * pipe;
+          if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
+          s_gl_rd += s_gl_rd_delta;
+        }
+      }
     }
     // Insert a fence even when we are winding down the pipeline to ensure that
     // waiting is also correct at this point.
     cp_async_fence();
@@ -432,7 +437,10 @@ __global__ void Marlin_24(
     // however, this does not seem to be a significant bottleneck, while some
     // theoretically better attempts have lead to bad instruction ordering by
     // the compiler and correspondingly a noticeable drop in performance.
-    if (group_blocks != -1) {
+    if constexpr (group_blocks != -1) {
+      // This assumes group_blocks >= thread_k_blocks
+      // and would need to be modified to support smaller groups.
+      static_assert(group_blocks >= thread_k_blocks);
       int4* sh_s_stage =
           sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
                                (pipe / (group_blocks / thread_k_blocks)));