Replace 0x1f with 0xffffffff in __shfl_sync (#1097)
This fixes compatibility with H100 and resolves #1094
This commit is contained in:
parent
6407bcdf0a
commit
8783c41851
@ -653,7 +653,7 @@ struct B2bGemm {
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
// Construct iterators to accumulator scale/bias vector
|
||||
|
@ -331,7 +331,7 @@ class gen_Kernel:
|
||||
operator_code += " " + helper.var_idx("FusedAddBiasEpilogue", i ) + helper.var_idx(" epilogue_", i ) + ";\n"
|
||||
|
||||
|
||||
operator_code += " " + "int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);\n"
|
||||
operator_code += " " + "int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);\n"
|
||||
operator_code += " " + "int lane_idx = threadIdx.x % 32;\n"
|
||||
|
||||
for i in range (self.b2bnum - 1):
|
||||
|
@ -364,7 +364,7 @@ struct DualGemm {
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
|
Loading…
Reference in New Issue
Block a user