From 37a8f9e598f812cf0e86fd61bd91fb31c6a95292 Mon Sep 17 00:00:00 2001 From: akerr Date: Fri, 25 Sep 2020 10:34:46 -0700 Subject: [PATCH] CUTLASS 2.3.0 final. --- .../threadblock/predicated_tile_iterator.h | 16 ++++++++-------- include/cutlass/gemm/kernel/gemm.h | 2 +- include/cutlass/gemm/kernel/gemm_array.h | 2 +- include/cutlass/gemm/kernel/gemm_batched.h | 2 +- .../cutlass/gemm/kernel/gemm_planar_complex.h | 2 +- .../gemm/kernel/gemm_planar_complex_array.h | 2 +- include/cutlass/gemm/kernel/gemm_universal.h | 2 +- include/cutlass/gemm/kernel/sparse_gemm.h | 2 +- 8 files changed, 15 insertions(+), 15 deletions(-) diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h index e0a411e1..05af759a 100644 --- a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h @@ -372,12 +372,11 @@ public: bool guard = row_guard && mask_.predicates[column]; - cutlass::arch::global_store( - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + - column], - (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / - kElementsPerAccess], - guard); + if (guard) { + + memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] = + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]; + } } if (row + 1 < ThreadMap::Iterations::kRow) { @@ -691,8 +690,9 @@ public: bool guard = col_guard && mask_.predicates[iteration_contiguous_]; - cutlass::arch::global_store( - *frag_ptr, (void *)memory_pointer, guard); + if (guard) { + *memory_pointer = *frag_ptr; + } } /// Overrides the internal iteration index diff --git a/include/cutlass/gemm/kernel/gemm.h b/include/cutlass/gemm/kernel/gemm.h index c3aa6f8f..fc2daa97 100644 --- a/include/cutlass/gemm/kernel/gemm.h +++ b/include/cutlass/gemm/kernel/gemm.h @@ -224,7 +224,7 @@ struct Gemm { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); int lane_idx = threadIdx.x % 32; // diff --git a/include/cutlass/gemm/kernel/gemm_array.h b/include/cutlass/gemm/kernel/gemm_array.h index 8cf25fb7..1c59a53a 100644 --- a/include/cutlass/gemm/kernel/gemm_array.h +++ b/include/cutlass/gemm/kernel/gemm_array.h @@ -184,7 +184,7 @@ struct GemmArray { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/kernel/gemm_batched.h b/include/cutlass/gemm/kernel/gemm_batched.h index ac8f5a37..45ec7756 100644 --- a/include/cutlass/gemm/kernel/gemm_batched.h +++ b/include/cutlass/gemm/kernel/gemm_batched.h @@ -196,7 +196,7 @@ struct GemmBatched { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/kernel/gemm_planar_complex.h b/include/cutlass/gemm/kernel/gemm_planar_complex.h index ab888940..aede20da 100644 --- a/include/cutlass/gemm/kernel/gemm_planar_complex.h +++ b/include/cutlass/gemm/kernel/gemm_planar_complex.h @@ -512,7 +512,7 @@ public: // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/kernel/gemm_planar_complex_array.h b/include/cutlass/gemm/kernel/gemm_planar_complex_array.h index 0023bd58..e7fa89dc 100644 --- a/include/cutlass/gemm/kernel/gemm_planar_complex_array.h +++ b/include/cutlass/gemm/kernel/gemm_planar_complex_array.h @@ -441,7 +441,7 @@ public: // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); int lane_idx = threadIdx.x % 32; // diff --git a/include/cutlass/gemm/kernel/gemm_universal.h b/include/cutlass/gemm/kernel/gemm_universal.h index e6e3c97b..99ece267 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.h +++ b/include/cutlass/gemm/kernel/gemm_universal.h @@ -402,7 +402,7 @@ public: // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/kernel/sparse_gemm.h b/include/cutlass/gemm/kernel/sparse_gemm.h index 85e3839c..7db469e5 100644 --- a/include/cutlass/gemm/kernel/sparse_gemm.h +++ b/include/cutlass/gemm/kernel/sparse_gemm.h @@ -269,7 +269,7 @@ struct SparseGemm { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); int lane_idx = threadIdx.x % 32; //