Support Turing mma instructions

Tri Dao 2022-06-02 16:29:54 -07:00
parent 050873327e
commit 2712aa4c8d
5 changed files with 82 additions and 18 deletions

View File

@@ -117,7 +117,8 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
         c10::optional<at::Generator> gen_) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
-    TORCH_CHECK(dprops->major == 8 && dprops->minor >= 0);
+    bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
+    TORCH_CHECK((dprops->major == 8 && dprops->minor >= 0) || is_sm75);
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     bool is_dropout = p_dropout > 0.0;
     Launch_params<Fused_multihead_attention_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);
@@ -143,7 +144,7 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
     TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
     // int base_N = head_size == 16 ? 512 : (head_size == 128 ? 128 : 256);
-    int base_N = head_size == 128 ? 128 : 256;
+    int base_N = (head_size == 128 || (is_sm75 && head_size == 64)) ? 128 : 256;
     // int base_N = 256;
     int seq_len = 512;
     if( max_seq_len <= 128 ) {
@@ -236,7 +237,8 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
         c10::optional<at::Generator> gen_
 ) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
-    TORCH_CHECK(dprops->major == 8 && dprops->minor >= 0);
+    bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
+    TORCH_CHECK((dprops->major == 8 && dprops->minor >= 0) || is_sm75);
     auto launch = &run_fmha_dgrad_fp16_sm80;
     bool is_dropout = p_dropout > 0.0;
@@ -268,7 +270,7 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
     TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
     // int base_N = head_size == 16 ? 512 : (head_size == 128 ? 128 : 256);
-    int base_N = head_size == 128 ? 128 : 256;
+    int base_N = (head_size == 128 || (is_sm75 && head_size == 64)) ? 128 : 256;
     int seq_len = 512;
     if( max_seq_len <= 128 ) {
         seq_len = 128;
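
The mha_fwd and mha_bwd changes above make the same two adjustments: the arch check now accepts SM75 (Turing) alongside SM8x, and head_size 64 on SM75 drops to the narrower 128-column tile that head_size 128 already uses. A condensed sketch of that dispatch under a hypothetical helper name (pick_base_N); the shared-memory rationale is an assumption, not something stated in the commit:

    // Sketch only: mirrors the checks added in mha_fwd/mha_bwd above.
    #include <cassert>
    #include <cuda_runtime.h>

    inline int pick_base_N(const cudaDeviceProp &dprops, int head_size) {
        const bool is_sm75 = dprops.major == 7 && dprops.minor == 5;  // Turing
        const bool is_sm8x = dprops.major == 8;                       // Ampere
        assert(is_sm8x || is_sm75);  // corresponds to the TORCH_CHECK above
        // Assumption: Turing's smaller shared memory is why head_size 64 also
        // takes the 128-column tile there.
        return (head_size == 128 || (is_sm75 && head_size == 64)) ? 128 : 256;
    }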

View File

@@ -257,7 +257,15 @@ inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N])
 template<typename Acc, typename A, typename B, int M, int N>
 inline __device__ void gemm_cl(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
     using Shape = cutlass::gemm::GemmShape<16 * M, 16 * N, 16>;
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
+#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
+    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
+#else
+    using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
+    // TD [2022-06-02] We don't support Volta (SM70) yet.
+    assert(0);
+#endif
     using Element = cutlass::half_t;
     using ElementC = float;
     using LayoutA = cutlass::layout::RowMajor;
@@ -267,19 +275,65 @@ inline __device__ void gemm_cl(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N
         Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
         cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd, 1, true>::Type;
-    using FragmentA = typename WarpMma::FragmentA;
-    using FragmentB = typename WarpMma::FragmentB;
+    constexpr int kIters = Shape::kK / InstructionShape::kK;
+    // using FragmentA = typename WarpMma::FragmentA;
+    // using FragmentB = typename WarpMma::FragmentB;
+    using FragmentA = typename WarpMma::ArchMmaOperator::FragmentA;
+    using FragmentB = typename WarpMma::ArchMmaOperator::FragmentB;
     using FragmentC = typename WarpMma::FragmentC;
-    static_assert(FragmentA::kStorageElements == M * a[0].NUM_REGS);
-    static_assert(FragmentB::kStorageElements == N * b[0].NUM_REGS);
+    // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y) == 0) {
+    //     printf("FragmentA::kStorageElements = %d\n", FragmentA::kStorageElements);
+    //     printf("Archmma::FragmentA::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentA::kStorageElements);
+    //     printf("FragmentB::kStorageElements = %d\n", FragmentB::kStorageElements);
+    //     printf("Archmma::FragmentB::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentB::kStorageElements);
+    //     printf("FragmentC::kStorageElements = %d\n", FragmentC::kStorageElements);
+    // }
+    // static_assert(FragmentA::kStorageElements == M * a[0].NUM_REGS);
+    // static_assert(FragmentB::kStorageElements == N * b[0].NUM_REGS);
+    static_assert(FragmentA::kStorageElements * kIters == a[0].NUM_REGS);
+    static_assert(FragmentB::kStorageElements * kIters * 16 / InstructionShape::kN == b[0].NUM_REGS);
     static_assert(FragmentC::kStorageElements == M * N * acc[0][0].NUM_REGS);
-    const FragmentA a_cl = reinterpret_cast<const FragmentA (&)>(a);
-    const FragmentB b_cl = reinterpret_cast<const FragmentB (&)>(b);
+    // const FragmentA a_cl = reinterpret_cast<const FragmentA (&)>(a);
+    // const FragmentB b_cl = reinterpret_cast<const FragmentB (&)>(b);
     FragmentC c_cl = reinterpret_cast<FragmentC (&)>(acc);
+    FragmentA a_cl[kIters][M];
+    FragmentA b_cl[kIters][N];
+    constexpr int kRegs = InstructionShape::kK == 16 ? 4 : 2;
+    #pragma unroll
+    for (int iter = 0; iter < kIters; iter++) {
+        #pragma unroll
+        for (int mi = 0; mi < M; mi++) {
+            uint32_t *a_ptr = a_cl[iter][mi].raw_data();
+            #pragma unroll
+            for (int ki = 0; ki < kRegs; ki++) {
+                a_ptr[ki] = a[mi].regs_[iter * kRegs + ki];
+            }
+        }
+    }
+    #pragma unroll
+    for (int iter = 0; iter < kIters; iter++) {
+        #pragma unroll
+        for (int ni = 0; ni < N; ni++) {
+            uint32_t *b_ptr = b_cl[iter][ni].raw_data();
+            #pragma unroll
+            for (int ki = 0; ki < kRegs; ki++) {
+                // b_ptr[ki] = b[ni].regs_[iter * kRegs + ki];
+                // TD [2022-06-02] For some reason the order for frag_b is different.
+                b_ptr[ki] = b[ni].regs_[InstructionShape::kK == 16 ? iter * kRegs + ki : ki * kRegs + iter];
+            }
+        }
+    }
     WarpMma mma_op;
-    mma_op(c_cl, a_cl, b_cl, c_cl);
+    // mma_op(c_cl, a_cl, b_cl, c_cl);
+    #pragma unroll
+    for (int iter = 0; iter < kIters; iter++) {
+        mma_op(c_cl, reinterpret_cast<const typename WarpMma::FragmentA (&)>(a_cl[iter]),
+               reinterpret_cast<const typename WarpMma::FragmentB (&)>(b_cl[iter]), c_cl);
+    }
     // The modified c_cl is not copied back into acc, idk why
     #pragma unroll
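
The repacking above is driven by two constants: kIters = Shape::kK / InstructionShape::kK (the warp tile's K is fixed at 16, so this is 1 for SM80's m16n8k16 mma and 2 for SM75's m16n8k8), and kRegs, the number of 32-bit A registers each instruction consumes (4 vs 2). A minimal compile-time restatement of just those numbers, under a hypothetical struct name (TileSplit), with values taken from the constants in the diff:

    // Sketch: how the K=16 warp tile splits into per-instruction fragments.
    template<int kInstrK>
    struct TileSplit {
        static constexpr int kIters = 16 / kInstrK;           // inner k-steps per warp tile
        static constexpr int kRegs  = kInstrK == 16 ? 4 : 2;  // 32-bit A registers per mma
    };
    static_assert(TileSplit<16>::kIters == 1 && TileSplit<16>::kRegs == 4, "SM80: m16n8k16");
    static_assert(TileSplit<8>::kIters == 2 && TileSplit<8>::kRegs == 2, "SM75: m16n8k8");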

View File

@@ -88,6 +88,9 @@ void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &para
             } else if (dprops->major == 8 && dprops->minor > 0) {
                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>;
                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+            } else if (dprops->major == 7 && dprops->minor == 5) {
+                using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
             }
         }
     } else if (params.d == 128) {

View File

@@ -105,12 +105,15 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l
         if( launch_params.params.s == 128 ) {
             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-        } else if( launch_params.params.s == 256 ) {
-            using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
-            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-        } else {
-            using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
-            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        } else if( launch_params.params.s >= 256 ) {
+            auto dprops = at::cuda::getCurrentDeviceProperties();
+            if (dprops->major == 8 && dprops->minor >= 0) {
+                using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+            } else if (dprops->major == 7 && dprops->minor == 5) {
+                using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+            }
         }
     } else if (launch_params.params.d == 128) {
         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
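
The fprop and dgrad launchers above now make the same runtime choice for sequences of 256 and up: keep the 256-row FMHA_kernel_traits on SM8x and fall back to the 128-row variant on SM75. A small sketch of that rule under a hypothetical helper name (pick_tile_rows); reading the first FMHA_kernel_traits argument as the sequence-tile size is an assumption suggested by the s == 128 / s >= 256 branches:

    // Sketch only: the rule both launchers above implement inline.
    #include <cassert>
    #include <cuda_runtime.h>

    inline int pick_tile_rows(const cudaDeviceProp &dprops) {
        if (dprops.major == 8) return 256;                       // Ampere: FMHA_kernel_traits<256, ...>
        if (dprops.major == 7 && dprops.minor == 5) return 128;  // Turing: FMHA_kernel_traits<128, ...>
        assert(false && "unsupported compute capability");
        return 0;
    }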

View File

@@ -107,7 +107,9 @@ raise_if_cuda_home_none("flash_attn")
 cc_flag = []
 _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
 if int(bare_metal_major) < 11:
-    raise RuntimeError("--flashattn only supported on SM80+")
+    raise RuntimeError("FlashAttention is only supported on CUDA 11")
+cc_flag.append("-gencode")
+cc_flag.append("arch=compute_75,code=sm_75")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_80,code=sm_80")