Update cutlass to v3.4.0
This commit is contained in:
parent
395e5a0dba
commit
8f4d82cf5e
@ -1 +1 @@
|
||||
Subproject commit a75b4ac483166189a45290783cb0a18af5ff0ea5
|
||||
Subproject commit 751eb9a8859ac36bfc77551f9e4a957c31a5a8b1
|
||||
@ -32,11 +32,12 @@ CUTE_HOST_DEVICE
|
||||
auto
|
||||
make_tiled_copy_B_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
|
||||
TiledMMA const& tiled_mma) {
|
||||
using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
|
||||
constexpr int TileShape_N = decltype(tiled_mma.template tile_size_mnk<1>())::value;
|
||||
constexpr int TileShape_K = decltype(tiled_mma.template tile_size_mnk<2>())::value;
|
||||
using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
|
||||
constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value;
|
||||
// Divide by 2 because right now we always use 2 for the ValLayout
|
||||
constexpr int kNWarpsN = decltype(size<1>(TileShape_MNK{}))::value / AtomShape_N / 2;
|
||||
constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2;
|
||||
constexpr int MMAStride_N = MMA_N * AtomShape_N * 2;
|
||||
// This gives the correct layout, idk why.
|
||||
// auto t = make_tile(Layout<Shape<Shape<_8, _2>, _2>,
|
||||
@ -45,7 +46,7 @@ make_tiled_copy_B_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
|
||||
// Stride<_1, _64, _8> >{},
|
||||
auto t = make_tile(Layout<Shape<Int<AtomShape_N>, Int<kNWarpsN>, _2>, // (8, 2, 2) or (8, 4, 2)
|
||||
Stride<_1, Int<MMAStride_N>, _8> >{}, // (1, 64, 8) or (1, 32, 8)
|
||||
make_layout(size<2>(TileShape_MNK{})));
|
||||
make_layout(Int<TileShape_K>{}));
|
||||
// if (cute::thread0()) {printf("make_tiled_copy_B_warpcontiguousN "); print(t); printf("\n"); }
|
||||
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutB_TV(), t);
|
||||
}
|
||||
@ -59,13 +60,14 @@ CUTE_HOST_DEVICE
|
||||
auto
|
||||
make_tiled_copy_C_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
|
||||
TiledMMA const& tiled_mma) {
|
||||
using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
|
||||
constexpr int TileShape_M = decltype(tiled_mma.template tile_size_mnk<0>())::value;
|
||||
constexpr int TileShape_N = decltype(tiled_mma.template tile_size_mnk<1>())::value;
|
||||
using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
|
||||
constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value;
|
||||
// Divide by 2 because right now we always use 2 for the ValLayout
|
||||
constexpr int kNWarpsN = decltype(size<1>(TileShape_MNK{}))::value / AtomShape_N / 2;
|
||||
constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2;
|
||||
constexpr int MMAStride_N = MMA_N * AtomShape_N * 2;
|
||||
auto t = make_tile(make_layout(size<0>(TileShape_MNK{})),
|
||||
auto t = make_tile(make_layout(Int<TileShape_M>{}),
|
||||
Layout<Shape<Int<AtomShape_N>, Int<kNWarpsN>, _2>, // (8, 2, 2) or (8, 4, 2)
|
||||
Stride<_1, Int<MMAStride_N>, _8> >{}); // (1, 64, 8) or (1, 32, 8)
|
||||
// if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousN "); print(t); printf("\n"); }
|
||||
@ -90,8 +92,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
|
||||
constexpr int kBlockM = Kernel_traits::kBlockM;
|
||||
constexpr int kBlockN = Kernel_traits::kBlockN;
|
||||
constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
||||
// constexpr int kNWarps = Kernel_traits::kNWarps;
|
||||
constexpr int MMA_N_SdP = kBlockN / decltype(size<1>(typename Kernel_traits::TiledMmaSdP::TiledShape_MNK{}))::value;
|
||||
constexpr int MMA_N_SdP = kBlockN / decltype(typename Kernel_traits::TiledMmaSdP{}.template tile_size_mnk<1>())::value;
|
||||
constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP;
|
||||
constexpr bool Double_buffer = !Kernel_traits::No_double_buffer;
|
||||
|
||||
|
||||
@ -41,7 +41,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
|
||||
constexpr int kBlockN = Kernel_traits::kBlockN;
|
||||
constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
||||
constexpr int kNWarps = Kernel_traits::kNWarps;
|
||||
constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
|
||||
|
||||
auto seed_offset = at::cuda::philox::unpack(params.philox_args);
|
||||
flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t,
|
||||
|
||||
@ -32,10 +32,8 @@ struct Flash_kernel_traits {
|
||||
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
|
||||
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
|
||||
>;
|
||||
using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
|
||||
#else
|
||||
using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
|
||||
using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
|
||||
#endif
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
|
||||
@ -76,7 +74,7 @@ struct Flash_fwd_kernel_traits : public Base {
|
||||
using TiledMma = TiledMMA<
|
||||
typename Base::MMA_Atom_Arch,
|
||||
Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group
|
||||
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
|
||||
Tile<Int<16 * kNWarps>, _16, _16>>;
|
||||
|
||||
using SmemLayoutAtomQ = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
@ -197,17 +195,17 @@ struct Flash_bwd_kernel_traits : public Base {
|
||||
using TiledMmaSdP = TiledMMA<
|
||||
typename Base::MMA_Atom_Arch,
|
||||
Layout<Shape<Int<AtomLayoutMSdP>, Int<kNWarps / AtomLayoutMSdP>, _1>>,
|
||||
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
|
||||
Tile<Int<16 * AtomLayoutMSdP>, Int<16 * kNWarps / AtomLayoutMSdP>, _16>>;
|
||||
|
||||
using TiledMmadKV = TiledMMA<
|
||||
typename Base::MMA_Atom_Arch,
|
||||
Layout<Shape<Int<AtomLayoutNdKV>, Int<kNWarps / AtomLayoutNdKV>, _1>>,
|
||||
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
|
||||
Tile<Int<16 * AtomLayoutNdKV>, Int<16 * kNWarps / AtomLayoutNdKV>, _16>>;
|
||||
|
||||
using TiledMmadQ = TiledMMA<
|
||||
typename Base::MMA_Atom_Arch,
|
||||
Layout<Shape<Int<AtomLayoutMdQ>, Int<kNWarps / AtomLayoutMdQ>, _1>>, // 2x4x1 or 4x2x1 thread group
|
||||
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
|
||||
Tile<Int<16 * AtomLayoutMdQ>, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>;
|
||||
|
||||
using SmemLayoutAtomQdO = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
|
||||
Loading…
Reference in New Issue
Block a user