From 8783c41851cd3582490e04e69e0cd756a8c1db7f Mon Sep 17 00:00:00 2001 From: Vadim Markovtsev Date: Tue, 19 Sep 2023 01:58:19 +0200 Subject: [PATCH] Replace 0x1f with 0xffffffff in __shfl_sync (#1097) This fixes compatibility with H100 and resolves #1094 --- examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h | 2 +- examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_kernel.py | 2 +- examples/45_dual_gemm/kernel/dual_gemm.h | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h b/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h index a6d2a8a1..fcc484ea 100644 --- a/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h +++ b/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h @@ -653,7 +653,7 @@ struct B2bGemm { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); int lane_idx = threadIdx.x % 32; // Construct iterators to accumulator scale/bias vector diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_kernel.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_kernel.py index 5fe51200..a640fc60 100644 --- a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_kernel.py +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_kernel.py @@ -331,7 +331,7 @@ class gen_Kernel: operator_code += " " + helper.var_idx("FusedAddBiasEpilogue", i ) + helper.var_idx(" epilogue_", i ) + ";\n" - operator_code += " " + "int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);\n" + operator_code += " " + "int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);\n" operator_code += " " + "int lane_idx = threadIdx.x % 32;\n" for i in range (self.b2bnum - 1): diff --git a/examples/45_dual_gemm/kernel/dual_gemm.h b/examples/45_dual_gemm/kernel/dual_gemm.h index f0ad97db..bd3c438f 100644 --- a/examples/45_dual_gemm/kernel/dual_gemm.h +++ b/examples/45_dual_gemm/kernel/dual_gemm.h @@ -364,7 +364,7 @@ struct DualGemm { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); int lane_idx = threadIdx.x % 32; //