commit
c2b80ad4e4
@ -372,12 +372,11 @@ public:
|
||||
|
||||
bool guard = row_guard && mask_.predicates[column];
|
||||
|
||||
cutlass::arch::global_store<AccessType, sizeof(AccessType)>(
|
||||
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn +
|
||||
column],
|
||||
(void *)&memory_pointer[column * ThreadMap::Delta::kColumn /
|
||||
kElementsPerAccess],
|
||||
guard);
|
||||
if (guard) {
|
||||
|
||||
memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] =
|
||||
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column];
|
||||
}
|
||||
}
|
||||
|
||||
if (row + 1 < ThreadMap::Iterations::kRow) {
|
||||
@ -691,8 +690,9 @@ public:
|
||||
|
||||
bool guard = col_guard && mask_.predicates[iteration_contiguous_];
|
||||
|
||||
cutlass::arch::global_store<AccessType, sizeof(AccessType)>(
|
||||
*frag_ptr, (void *)memory_pointer, guard);
|
||||
if (guard) {
|
||||
*memory_pointer = *frag_ptr;
|
||||
}
|
||||
}
|
||||
|
||||
/// Overrides the internal iteration index
|
||||
|
@ -224,7 +224,7 @@ struct Gemm {
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
|
@ -184,7 +184,7 @@ struct GemmArray {
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
|
@ -196,7 +196,7 @@ struct GemmBatched {
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
|
@ -512,7 +512,7 @@ public:
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
|
@ -441,7 +441,7 @@ public:
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
|
@ -402,7 +402,7 @@ public:
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
|
@ -269,7 +269,7 @@ struct SparseGemm {
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
|
Loading…
Reference in New Issue
Block a user