Replace 0x1f with 0xffffffff in __shfl_sync (#1097)

This fixes compatibility with H100 and resolves #1094.
This commit is contained in:
Vadim Markovtsev 2023-09-19 01:58:19 +02:00 committed by GitHub
parent 6407bcdf0a
commit 8783c41851
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 3 additions and 3 deletions

View File

@@ -653,7 +653,7 @@ struct B2bGemm {
 // Broadcast the warp_id computed by lane 0 to ensure dependent code
 // is compiled as warp-uniform.
-int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
+int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
 int lane_idx = threadIdx.x % 32;
 // Construct iterators to accumulator scale/bias vector

View File

@@ -331,7 +331,7 @@ class gen_Kernel:
 operator_code += " " + helper.var_idx("FusedAddBiasEpilogue", i ) + helper.var_idx(" epilogue_", i ) + ";\n"
-operator_code += " " + "int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);\n"
+operator_code += " " + "int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);\n"
 operator_code += " " + "int lane_idx = threadIdx.x % 32;\n"
 for i in range (self.b2bnum - 1):

View File

@@ -364,7 +364,7 @@ struct DualGemm {
 // Broadcast the warp_id computed by lane 0 to ensure dependent code
 // is compiled as warp-uniform.
-int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
+int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
 int lane_idx = threadIdx.x % 32;
 //