diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp
index 3d608f3..8c329a8 100644
--- a/csrc/flash_attn/fmha_api.cpp
+++ b/csrc/flash_attn/fmha_api.cpp
@@ -117,7 +117,8 @@ mha_fwd(const at::Tensor &qkv,         // total x num_heads x 3 x head_size, tot
                c10::optional<at::Generator> gen_) {

     auto dprops = at::cuda::getCurrentDeviceProperties();
-    TORCH_CHECK(dprops->major == 8 && dprops->minor >= 0);
+    bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
+    TORCH_CHECK((dprops->major == 8 && dprops->minor >= 0) || is_sm75);
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     bool is_dropout = p_dropout > 0.0;
     Launch_params<FMHA_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);
@@ -143,7 +144,7 @@ mha_fwd(const at::Tensor &qkv,         // total x num_heads x 3 x head_size, tot
     TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);

     // int base_N = head_size == 16 ? 512 : (head_size == 128 ? 128 : 256);
-    int base_N = head_size == 128 ? 128 : 256;
+    int base_N = (head_size == 128 || (is_sm75 && head_size == 64)) ? 128 : 256;
     // int base_N = 256;
     int seq_len = 512;
     if( max_seq_len <= 128 ) {
@@ -236,7 +237,8 @@ mha_bwd(const at::Tensor &dout,  // total x num_heads, x head_size
                c10::optional<at::Generator> gen_
 ) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
-    TORCH_CHECK(dprops->major == 8 && dprops->minor >= 0);
+    bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
+    TORCH_CHECK((dprops->major == 8 && dprops->minor >= 0) || is_sm75);
     auto launch = &run_fmha_dgrad_fp16_sm80;

     bool is_dropout = p_dropout > 0.0;
@@ -268,7 +270,7 @@ mha_bwd(const at::Tensor &dout,  // total x num_heads, x head_size
     TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);

     // int base_N = head_size == 16 ? 512 : (head_size == 128 ? 128 : 256);
-    int base_N = head_size == 128 ? 128 : 256;
+    int base_N = (head_size == 128 || (is_sm75 && head_size == 64)) ? 128 : 256;
     int seq_len = 512;
     if( max_seq_len <= 128 ) {
         seq_len = 128;
diff --git a/csrc/flash_attn/src/fmha/gemm.h b/csrc/flash_attn/src/fmha/gemm.h
index 9721458..22ef11a 100644
--- a/csrc/flash_attn/src/fmha/gemm.h
+++ b/csrc/flash_attn/src/fmha/gemm.h
@@ -257,7 +257,15 @@ inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N])
 template<typename Acc, typename A, typename B, int M, int N>
 inline __device__ void gemm_cl(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
     using Shape = cutlass::gemm::GemmShape<16 * M, 16 * N, 16>;
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
+#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
+    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
+#else
+    using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
+    // TD [2022-06-02] We don't support Volta (SM70) yet.
+    assert(0);
+#endif
     using Element = cutlass::half_t;
     using ElementC = float;
     using LayoutA = cutlass::layout::RowMajor;
@@ -267,19 +275,65 @@ inline __device__ void gemm_cl(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N
         Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
         cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd, 1, true>::Type;

-    using FragmentA = typename WarpMma::FragmentA;
-    using FragmentB = typename WarpMma::FragmentB;
+    constexpr int kIters = Shape::kK / InstructionShape::kK;
+    // using FragmentA = typename WarpMma::FragmentA;
+    // using FragmentB = typename WarpMma::FragmentB;
+    using FragmentA = typename WarpMma::ArchMmaOperator::FragmentA;
+    using FragmentB = typename WarpMma::ArchMmaOperator::FragmentB;
     using FragmentC = typename WarpMma::FragmentC;
-    static_assert(FragmentA::kStorageElements == M * a[0].NUM_REGS);
-    static_assert(FragmentB::kStorageElements == N * b[0].NUM_REGS);
+    // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y) == 0) {
+    //     printf("FragmentA::kStorageElements = %d\n", FragmentA::kStorageElements);
+    //     printf("Archmma::FragmentA::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentA::kStorageElements);
+    //     printf("FragmentB::kStorageElements = %d\n", FragmentB::kStorageElements);
+    //     printf("Archmma::FragmentB::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentB::kStorageElements);
+    //     printf("FragmentC::kStorageElements = %d\n", FragmentC::kStorageElements);
+    //     printf("Archmma::FragmentC::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentC::kStorageElements);
+    // }
+
+    // static_assert(FragmentA::kStorageElements == M * a[0].NUM_REGS);
+    // static_assert(FragmentB::kStorageElements == N * b[0].NUM_REGS);
+    static_assert(FragmentA::kStorageElements * kIters == a[0].NUM_REGS);
+    static_assert(FragmentB::kStorageElements * kIters * 16 / InstructionShape::kN == b[0].NUM_REGS);
     static_assert(FragmentC::kStorageElements == M * N * acc[0][0].NUM_REGS);
-    const FragmentA a_cl = reinterpret_cast<const FragmentA &>(a);
-    const FragmentB b_cl = reinterpret_cast<const FragmentB &>(b);
+    // const FragmentA a_cl = reinterpret_cast<const FragmentA &>(a);
+    // const FragmentB b_cl = reinterpret_cast<const FragmentB &>(b);
     FragmentC c_cl = reinterpret_cast<FragmentC &>(acc);
+
+    FragmentA a_cl[kIters][M];
+    FragmentA b_cl[kIters][N];
+    constexpr int kRegs = InstructionShape::kK == 16 ? 4 : 2;
+    #pragma unroll
+    for (int iter = 0; iter < kIters; iter++) {
+        #pragma unroll
+        for (int mi = 0; mi < M; mi++) {
+            uint32_t *a_ptr = a_cl[iter][mi].raw_data();
+            #pragma unroll
+            for (int ki = 0; ki < kRegs; ki++) {
+                a_ptr[ki] = a[mi].regs_[iter * kRegs + ki];
+            }
+        }
+    }
+    #pragma unroll
+    for (int iter = 0; iter < kIters; iter++) {
+        #pragma unroll
+        for (int ni = 0; ni < N; ni++) {
+            uint32_t *b_ptr = b_cl[iter][ni].raw_data();
+            #pragma unroll
+            for (int ki = 0; ki < kRegs; ki++) {
+                // b_ptr[ki] = b[ni].regs_[iter * kRegs + ki];
+                // TD [2022-06-02] For some reason the order for frag_b is different.
+                b_ptr[ki] = b[ni].regs_[InstructionShape::kK == 16 ? iter * kRegs + ki : ki * kRegs + iter];
+            }
+        }
+    }
     WarpMma mma_op;
-    mma_op(c_cl, a_cl, b_cl, c_cl);
+    // mma_op(c_cl, a_cl, b_cl, c_cl);
+    #pragma unroll
+    for (int iter = 0; iter < kIters; iter++) {
+        mma_op(c_cl, reinterpret_cast<typename WarpMma::FragmentA &>(a_cl[iter]),
+               reinterpret_cast<typename WarpMma::FragmentB &>(b_cl[iter]), c_cl);
+    }
     // The modified c_cl is not copied back into acc, idk why
     #pragma unroll
diff --git a/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu b/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
index e3e2cdc..778482f 100644
--- a/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
+++ b/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
@@ -88,6 +88,9 @@ void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &params
             } else if (dprops->major == 8 && dprops->minor > 0) {
                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>;
                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+            } else if (dprops->major == 7 && dprops->minor == 5) {
+                using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
             }
         }
     } else if (params.d == 128) {
diff --git a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
index 1f734c2..fcc3ecd 100644
--- a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
+++ b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
@@ -105,12 +105,15 @@ void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &l
         if( launch_params.params.s == 128 ) {
             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-        } else if( launch_params.params.s == 256 ) {
-            using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
-            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-        } else {
-            using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
-            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        } else if( launch_params.params.s >= 256 ) {
+            auto dprops = at::cuda::getCurrentDeviceProperties();
+            if (dprops->major == 8 && dprops->minor >= 0) {
+                using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+            } else if (dprops->major == 7 && dprops->minor == 5) {
+                using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+            }
         }
     } else if (launch_params.params.d == 128) {
         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
diff --git a/setup.py b/setup.py
index a661fa6..68baf9c 100644
--- a/setup.py
+++ b/setup.py
@@ -107,7 +107,9 @@ raise_if_cuda_home_none("flash_attn")
 cc_flag = []
 _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
 if int(bare_metal_major) < 11:
-    raise RuntimeError("--flashattn only supported on SM80+")
+    raise RuntimeError("FlashAttention is only supported on CUDA 11")
+cc_flag.append("-gencode")
+cc_flag.append("arch=compute_75,code=sm_75")
+cc_flag.append("-gencode")
+cc_flag.append("arch=compute_80,code=sm_80")
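
Note (not part of the patch): the host-side sketch below only restates the compute-capability check and base_N tile choice that the diff adds to mha_fwd and mha_bwd; pick_base_n is an illustrative name, not a function in the repo.

// Minimal sketch, assuming a standard PyTorch CUDA extension build.
// pick_base_n is hypothetical; it mirrors the logic added in fmha_api.cpp.
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

static int pick_base_n(int head_size) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm75 = dprops->major == 7 && dprops->minor == 5;   // Turing
    bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;   // Ampere
    TORCH_CHECK(is_sm8x || is_sm75, "FlashAttention only supports Turing and Ampere GPUs");
    // Turing has less shared memory per SM than Ampere, so the head_size == 64
    // path also falls back to a 128-column tile there, matching the
    // FMHA_kernel_traits<128, 64, ...> instantiations added in the .cu launchers.
    return (head_size == 128 || (is_sm75 && head_size == 64)) ? 128 : 256;
}

On the device side, the same split shows up in gemm_cl: Turing's fp16 tensor-core MMA is m16n8k8 rather than m16n8k16, so a 16-deep K tile takes kIters = 2 instruction-level iterations of kRegs = 2 registers each, which is why the fragments are repacked per iteration before calling the CUTLASS warp MMA.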