[Bug Fix] Illegal memory access, FP8 Llama 3.1 405b (#6852)
This commit is contained in:
parent
981b0d5673
commit
55712941e5
@ -328,20 +328,36 @@ struct Sm90ColOrScalarBroadcast {
|
|||||||
return EmptyProducerLoadCallbacks{};
|
return EmptyProducerLoadCallbacks{};
|
||||||
}
|
}
|
||||||
|
|
||||||
template<class GTensor, class RTensor>
|
template<class GTensor, class RTensor, class CTensor, class ProblemShape>
|
||||||
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
|
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
|
||||||
CUTLASS_DEVICE
|
CUTLASS_DEVICE
|
||||||
ConsumerStoreCallbacks(GTensor&& tCgCol, RTensor&& tCrCol, Params const& params)
|
ConsumerStoreCallbacks(
|
||||||
: tCgCol(cute::forward<GTensor>(tCgCol)),
|
GTensor&& tCgCol,
|
||||||
|
RTensor&& tCrCol,
|
||||||
|
CTensor&& tCcCol,
|
||||||
|
ProblemShape problem_shape,
|
||||||
|
Params const& params
|
||||||
|
):
|
||||||
|
tCgCol(cute::forward<GTensor>(tCgCol)),
|
||||||
tCrCol(cute::forward<RTensor>(tCrCol)),
|
tCrCol(cute::forward<RTensor>(tCrCol)),
|
||||||
|
tCcCol(cute::forward<CTensor>(tCcCol)),
|
||||||
|
m(get<0>(problem_shape)),
|
||||||
params(params) {}
|
params(params) {}
|
||||||
|
|
||||||
GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||||
RTensor tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
RTensor tCrCol;
|
||||||
|
CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||||
Params const& params;
|
Params const& params;
|
||||||
|
int m;
|
||||||
|
|
||||||
CUTLASS_DEVICE void
|
CUTLASS_DEVICE void
|
||||||
begin() {
|
begin() {
|
||||||
|
Tensor pred = make_tensor<bool>(shape(tCgCol));
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size(pred); ++i) {
|
||||||
|
pred(i) = get<0>(tCcCol(i)) < m;
|
||||||
|
}
|
||||||
|
|
||||||
if (!params.col_broadcast) {
|
if (!params.col_broadcast) {
|
||||||
fill(tCrCol, *(params.ptr_col));
|
fill(tCrCol, *(params.ptr_col));
|
||||||
return;
|
return;
|
||||||
@ -349,7 +365,7 @@ struct Sm90ColOrScalarBroadcast {
|
|||||||
|
|
||||||
// Filter so we don't issue redundant copies over stride-0 modes
|
// Filter so we don't issue redundant copies over stride-0 modes
|
||||||
// (only works if 0-strides are in same location, which is by construction)
|
// (only works if 0-strides are in same location, which is by construction)
|
||||||
copy_aligned(filter(tCgCol), filter(tCrCol));
|
copy_if(pred, filter(tCgCol), filter(tCrCol));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename ElementAccumulator, int FragmentSize>
|
template <typename ElementAccumulator, int FragmentSize>
|
||||||
@ -381,8 +397,20 @@ struct Sm90ColOrScalarBroadcast {
|
|||||||
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||||
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||||
|
|
||||||
return ConsumerStoreCallbacks<decltype(tCgCol), decltype(tCrCol)>(
|
// Generate an identity tensor matching the shape of the global tensor and
|
||||||
cute::move(tCgCol), cute::move(tCrCol), params);
|
// partition the same way, this will be used to generate the predicate
|
||||||
|
// tensor for loading
|
||||||
|
Tensor cCol = make_identity_tensor(mCol.shape());
|
||||||
|
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||||
|
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||||
|
|
||||||
|
return ConsumerStoreCallbacks(
|
||||||
|
cute::move(tCgCol),
|
||||||
|
cute::move(tCrCol),
|
||||||
|
cute::move(tCcCol),
|
||||||
|
args.problem_shape_mnkl,
|
||||||
|
params
|
||||||
|
);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user