diff --git a/csrc/flash_attn/src/kernel_traits.h b/csrc/flash_attn/src/kernel_traits.h index 3468e4b..d1a9513 100644 --- a/csrc/flash_attn/src/kernel_traits.h +++ b/csrc/flash_attn/src/kernel_traits.h @@ -91,17 +91,20 @@ struct Flash_fwd_kernel_traits : public Base { SmemLayoutAtomQ{}, Shape, Int>{})); + // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 + using SmemLayoutAtomVtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; using SmemLayoutAtomVtransposed = decltype( - composition(Swizzle{}, - // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 - Layout, Int>, - Stride<_1, Int>>{})); + composition(Swizzle{}, SmemLayoutAtomVtransposedNoSwizzle{})); using SmemLayoutVtransposed = decltype(tile_to_shape( SmemLayoutAtomVtransposed{}, Shape, Int>{})); // Maybe the VtransposeNoSwizzle just needs to have the right shape // And the strides don't matter? - using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); + using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomVtransposedNoSwizzle{}, + Shape, Int>{})); + // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); using SmemLayoutAtomO = decltype( composition(Swizzle{}, @@ -223,16 +226,19 @@ struct Flash_bwd_kernel_traits : public Base { SmemLayoutAtomKV{}, make_shape(Int{}, Int{}))); + using SmemLayoutAtomKtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; using SmemLayoutAtomKtransposed = decltype( - composition(Swizzle{}, - Layout, Int>, - Stride<_1, Int>>{})); + composition(Swizzle{}, SmemLayoutAtomKtransposedNoSwizzle{})); using SmemLayoutKtransposed = decltype(tile_to_shape( SmemLayoutAtomKtransposed{}, make_shape(Int{}, Int{}))); // Maybe the KtransposeNoSwizzle just needs to have the right shape // And the strides don't matter? - using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn()); + using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomKtransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn()); // TODO: generalize to other values of kBlockN // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 @@ -250,24 +256,30 @@ struct Flash_bwd_kernel_traits : public Base { using SmemLayoutPdS = decltype(tile_to_shape( SmemLayoutAtomPdS{}, make_shape(Int{}, Int{}))); + using SmemLayoutAtomPdStransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; using SmemLayoutAtomPdStransposed = decltype( - composition(Swizzle{}, - Layout, Int>, - Stride<_1, Int>>{})); + composition(Swizzle{}, SmemLayoutAtomPdStransposedNoSwizzle{})); using SmemLayoutPdStransposed = decltype(tile_to_shape( SmemLayoutAtomPdStransposed{}, make_shape(Int{}, Int{}))); - using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn()); + using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomPdStransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn()); using SmemCopyAtomPdS = Copy_Atom; + using SmemLayoutAtomQdOtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; using SmemLayoutAtomQdOtransposed = decltype( - composition(Swizzle{}, - Layout, Int>, - Stride<_1, Int>>{})); + composition(Swizzle{}, SmemLayoutAtomQdOtransposedNoSwizzle{})); using SmemLayoutQdOtransposed = decltype(tile_to_shape( SmemLayoutAtomQdOtransposed{}, make_shape(Int{}, Int{}))); - using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn()); + using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomQdOtransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn()); using SmemLayoutAtomdKV = decltype( composition(Swizzle{}, diff --git a/csrc/flash_attn/src/utils.h b/csrc/flash_attn/src/utils.h index 8183882..eddd8f8 100644 --- a/csrc/flash_attn/src/utils.h +++ b/csrc/flash_attn/src/utils.h @@ -228,7 +228,10 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { static_assert(decltype(size<0>(acc_layout))::value == 4); static_assert(decltype(rank(acc_layout))::value == 3); auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) - return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + // TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting + // "int_tuple.hpp(74): error: conversion to inaccessible base class" + // return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l))); }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -244,9 +247,13 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { static_assert(mma_shape_K == 8 || mma_shape_K == 16); constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2; auto l = logical_divide(rowcol_layout, Shape>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2))) - return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), - get<0, 1>(l), - get<1, 1, 1>(l)); + // TD [2023-08-13]: Same error as above on Cutlass 3.2 + // return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), + // get<0, 1>(l), + // get<1, 1, 1>(l)); + return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))), + get<1>(get<0>(l)), + get<1>(get<1>(get<1>(l)))); }; ////////////////////////////////////////////////////////////////////////////////////////////////////