[Kernel] Fix marlin divide-by-zero warnings (#6904)
parent 4fbf4aa128
commit 61a97c32f6
@@ -1128,44 +1128,47 @@ __global__ void Marlin(
   };

   auto fetch_zp_to_registers = [&](int k, int full_pipe) {
-    if constexpr (!has_zp) {
-      return;
-    }
+    if constexpr (has_zp) {
+      // This code does not handle group_blocks == 0,
+      // which signifies act_order.
+      // has_zp implies AWQ, which doesn't have act_order,
+      static_assert(group_blocks != 0);

       int pipe = full_pipe % stages;

       if constexpr (group_blocks == -1) {
         for (int i = 0; i < num_ints_per_thread; i++) {
           frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
         }

       } else if constexpr (group_blocks >= thread_k_blocks) {
         int4* sh_zp_stage =
             sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
                                    (pipe / (group_blocks / thread_k_blocks)));
         for (int i = 0; i < num_ints_per_thread; i++) {
           frag_qzp[k % 2][i] =
               (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
         }
       } else {
         int warp_id = threadIdx.x / 32;
         int n_warps = thread_n_blocks / 4;

         int warp_row = warp_id / n_warps;

         int cur_k = warp_row * 16;
         cur_k += k_iter_size * (k % b_sh_wr_iters);

         int k_blocks = cur_k / 16;
         int cur_group_id = k_blocks / group_blocks;

         int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;

         sh_zp_stage += cur_group_id * zp_sh_stride;

         for (int i = 0; i < num_ints_per_thread; i++) {
           frag_qzp[k % 2][i] =
               (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
         }
       }
+    }
   };
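Why the early return was not enough: statements that follow a return inside an if constexpr branch are still instantiated, so the divisions involving group_blocks further down the lambda (group_blocks / thread_k_blocks, k_blocks / group_blocks) were compiled even for group_blocks == 0 (act_order) instantiations, where the divisor is a compile-time zero. Wrapping the whole body in if constexpr (has_zp) discards it for those instantiations, and the static_assert records that zero points imply a non-zero group size. A minimal sketch of the pattern, not taken from the vLLM tree (HAS_ZP and GB are illustrative stand-ins for the has_zp and group_blocks template parameters):

// Minimal sketch, not vLLM code.
template <bool HAS_ZP, int GB>
int old_style(int k_blocks) {
  if constexpr (!HAS_ZP) {
    return 0;  // early return, but the statement below is still instantiated
  }
  return k_blocks / GB;  // for <false, 0> this divides by a constant zero
}

template <bool HAS_ZP, int GB>
int new_style(int k_blocks) {
  if constexpr (HAS_ZP) {
    static_assert(GB != 0);  // zero points imply a non-zero group size
    return k_blocks / GB;    // not instantiated when HAS_ZP is false
  }
  return 0;
}

Compiling old_style<false, 0> typically draws a division-by-zero warning from gcc or nvcc even though the division is unreachable at run time; new_style<false, 0> compiles quietly, because the discarded branch of an if constexpr inside a template is never instantiated.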
@@ -452,10 +452,15 @@ __global__ void Marlin(
         B_ptr[i] += b_gl_rd_delta_o;
       }
       // Only fetch scales if this tile starts a new group
-      if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
-        int4* sh_s_stage = sh_s + s_sh_stage * pipe;
-        if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
-        s_gl_rd += s_gl_rd_delta;
+      if constexpr (group_blocks != -1) {
+        // This assumes group_blocks >= thread_k_blocks
+        // and would need to be modified to support smaller groups.
+        static_assert(group_blocks >= thread_k_blocks);
+        if (pipe % (group_blocks / thread_k_blocks) == 0) {
+          int4* sh_s_stage = sh_s + s_sh_stage * pipe;
+          if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
+          s_gl_rd += s_gl_rd_delta;
+        }
       }
     }
     // Insert a fence even when we are winding down the pipeline to ensure that
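The change above follows the same pattern: group_blocks and thread_k_blocks are non-type template parameters, so in the group_blocks == -1 instantiation the divisor group_blocks / thread_k_blocks folds to the constant 0 (-1 / 16 truncates toward zero), and the modulo is flagged even though the runtime group_blocks != -1 guard makes it unreachable. Lifting the guard into if constexpr discards the expression for that instantiation, and the static_assert spells out the remaining assumption that a group spans at least one k-tile. A sketch of the before/after shape, with illustrative names rather than the kernel's real signature:

// Sketch only; GROUP_BLOCKS and THREAD_K_BLOCKS mirror the template parameters.
template <int GROUP_BLOCKS, int THREAD_K_BLOCKS>
bool starts_new_group_old(int pipe) {
  // For <-1, 16>, GROUP_BLOCKS / THREAD_K_BLOCKS folds to the constant 0,
  // so this is a modulo by zero as far as the compiler is concerned,
  // regardless of the short-circuit guard.
  return GROUP_BLOCKS != -1 &&
         pipe % (GROUP_BLOCKS / THREAD_K_BLOCKS) == 0;
}

template <int GROUP_BLOCKS, int THREAD_K_BLOCKS>
bool starts_new_group_new(int pipe) {
  if constexpr (GROUP_BLOCKS != -1) {
    static_assert(GROUP_BLOCKS >= THREAD_K_BLOCKS);  // divisor is at least 1
    return pipe % (GROUP_BLOCKS / THREAD_K_BLOCKS) == 0;
  }
  return false;  // GROUP_BLOCKS == -1: no k-groups, so never a new group
}

The runtime behavior is unchanged, since the old guard already resolved to a constant in each instantiation; what changes is that the dead modulo is no longer compiled, so it can no longer be warned about.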
@@ -480,7 +485,10 @@ __global__ void Marlin(
     // however, this does not seem to be a significant bottleneck, while some
     // theoretically better attempts have lead to bad instruction ordering by
     // the compiler and correspondingly a noticeable drop in performance.
-    if (group_blocks != -1) {
+    if constexpr (group_blocks != -1) {
+      // This assumes group_blocks >= thread_k_blocks
+      // and would need to be modified to support smaller groups.
+      static_assert(group_blocks >= thread_k_blocks);
       int4* sh_s_stage =
           sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
                                (pipe / (group_blocks / thread_k_blocks)));
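Besides silencing the warning, the added static_assert turns a latent misconfiguration into a hard compile error: if the kernel were ever instantiated with 0 < group_blocks < thread_k_blocks, the index arithmetic above would divide by group_blocks / thread_k_blocks == 0. A small sketch of that effect, using a hypothetical helper rather than the kernel itself:

// Hypothetical helper, not part of the kernel, showing the effect of the
// assert on the shared-memory stage offset computed above.
template <int GROUP_BLOCKS, int THREAD_K_BLOCKS>
int scale_stage_index(int pipe) {
  if constexpr (GROUP_BLOCKS != -1) {
    static_assert(GROUP_BLOCKS >= THREAD_K_BLOCKS,
                  "groups smaller than a k-tile are not supported");
    constexpr int groups_per_stage = GROUP_BLOCKS / THREAD_K_BLOCKS;  // >= 1
    return groups_per_stage * (pipe / groups_per_stage);
  }
  return 0;
}

// scale_stage_index<128, 16>(pipe) compiles and divides by 8;
// scale_stage_index<8, 16>(pipe) now fails with the assert message instead of
// compiling a division by zero.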
@@ -404,10 +404,15 @@ __global__ void Marlin_24(
         meta_ptr[i] += m_gl_rd_delta_o;
       }
       // Only fetch scales if this tile starts a new group
-      if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
-        int4* sh_s_stage = sh_s + s_sh_stage * pipe;
-        if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
-        s_gl_rd += s_gl_rd_delta;
+      if constexpr (group_blocks != -1) {
+        // This assumes group_blocks >= thread_k_blocks
+        // and would need to be modified to support smaller groups.
+        static_assert(group_blocks >= thread_k_blocks);
+        if (pipe % (group_blocks / thread_k_blocks) == 0) {
+          int4* sh_s_stage = sh_s + s_sh_stage * pipe;
+          if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
+          s_gl_rd += s_gl_rd_delta;
+        }
       }
     }
     // Insert a fence even when we are winding down the pipeline to ensure that
@@ -432,7 +437,10 @@ __global__ void Marlin_24(
     // however, this does not seem to be a significant bottleneck, while some
     // theoretically better attempts have lead to bad instruction ordering by
     // the compiler and correspondingly a noticeable drop in performance.
-    if (group_blocks != -1) {
+    if constexpr (group_blocks != -1) {
+      // This assumes group_blocks >= thread_k_blocks
+      // and would need to be modified to support smaller groups.
+      static_assert(group_blocks >= thread_k_blocks);
       int4* sh_s_stage =
           sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
                                (pipe / (group_blocks / thread_k_blocks)));