From 0d854692c60249ed82bcddd1859bd9326eb3eaf1 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sun, 5 Jun 2022 22:30:09 -0700
Subject: [PATCH] Implement fwd for head dim 128

---
 csrc/flash_attn/fmha_api.cpp                  |  3 +-
 csrc/flash_attn/src/fmha/smem_tile.h          | 63 +++++++++++++++++--
 .../src/fmha_fprop_fp16_kernel.sm80.cu        | 34 +++++++++-
 csrc/flash_attn/src/fmha_fprop_kernel_1xN.h   | 13 +++-
 4 files changed, 104 insertions(+), 9 deletions(-)

diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp
index ddb97d8..7f30238 100644
--- a/csrc/flash_attn/fmha_api.cpp
+++ b/csrc/flash_attn/fmha_api.cpp
@@ -118,6 +118,7 @@ mha_fwd(const at::Tensor &qkv,         // total x num_heads x 3 x head_size, tot
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
+    bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
     TORCH_CHECK((dprops->major == 8 && dprops->minor >= 0) || is_sm75);
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     bool is_dropout = p_dropout > 0.0;
@@ -144,7 +145,7 @@ mha_fwd(const at::Tensor &qkv,         // total x num_heads x 3 x head_size, tot
     TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
 
     // int base_N = head_size == 16 ? 512 : (head_size == 128 ? 128 : 256);
-    int base_N = (head_size == 128 || (is_sm75 && head_size == 64 && is_dropout)) ? 128 : 256;
+    int base_N = ((head_size == 128 && (is_dropout || !is_sm80)) || (is_sm75 && head_size == 64 && is_dropout)) ? 128 : 256;
    // int base_N = 256;
     int seq_len = 512;
     if( max_seq_len <= 128 ) {
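Note on the base_N change above: before this patch, head_size 128 always got base_N = 128; the new condition lets d=128 use the wider 256 block on SM80 when dropout is off. A minimal host-side mirror of the rule, for readability only (the function name is hypothetical; the logic is copied verbatim from the hunk):

    // Sketch: which base_N mha_fwd picks after this patch.
    // d=128 gets base_N=256 only on SM80 without dropout; with dropout or on
    // other archs it stays at 128, as does SM75 d=64 with dropout.
    int choose_base_N(int head_size, bool is_dropout, bool is_sm75, bool is_sm80) {
        return ((head_size == 128 && (is_dropout || !is_sm80)) ||
                (is_sm75 && head_size == 64 && is_dropout)) ? 128 : 256;
    }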
diff --git a/csrc/flash_attn/src/fmha/smem_tile.h b/csrc/flash_attn/src/fmha/smem_tile.h
index 5e67a34..ad6930c 100644
--- a/csrc/flash_attn/src/fmha/smem_tile.h
+++ b/csrc/flash_attn/src/fmha/smem_tile.h
@@ -1054,6 +1054,14 @@ struct Smem_tile_o {
         constexpr int STS_PER_WARP = 16 * Mma_tile::MMAS_N / ELEMENTS_PER_STS;
         int write_col = warp * STS_PER_WARP + lane % STS_PER_WARP;
 
+        // if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+        //     printf("write_row = %d, write_col = %d\n", write_row, write_col);
+        // }
+
+        // if ((blockIdx.x == 0) && (blockIdx.y == 0) && (write_row == 0) && (write_col == 0)) {
+        //     printf("threadIdx.x = %d\n", threadIdx.x);
+        // }
+
         // Assemble the write pointer.
         smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
@@ -1062,9 +1070,15 @@ struct Smem_tile_o {
         int read_col = tidx % THREADS_PER_ROW;
 
         // Take the XOR pattern into account for the column.
-        // read_col ^= 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 4 : 8)));
-        read_col ^= 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 4 : (Cta_tile::N == 128 ? 16 : 8))));
+        read_col ^= 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 4 : 8)));
+        // read_col ^= 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 4 : (Cta_tile::N == 128 ? 16 : 8))));
 
+        // if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+        //     printf("read_row = %d, read_col = %d\n", read_row, read_col);
+        // }
+        // if ((blockIdx.x == 0) && (blockIdx.y == 0) && (read_row == 0) && (read_col == 0)) {
+        //     printf("threadIdx.x = %d\n", threadIdx.x);
+        // }
         // Assemble the read pointer.
         this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
@@ -1085,16 +1099,31 @@ struct Smem_tile_o {
             #pragma unroll
             for( int jj = 0; jj < Cta_tile::WARPS_K; ++jj ) {
                 int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT;
+                uint32_t smem_read = this->smem_read_ + imm;
+                // TD [2022-06-05] Ugly fix for d=128, maybe there's a better way.
+                if ((Cta_tile::N == 128) && (ii % 2 == 1)) {
+                    smem_read ^= 8 * BYTES_PER_LDS;
+                }
+                // if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+                //     printf("imm diff = %d\n", smem_read - this->smem_read_);
+                // }
                 if( !HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_) ) {
-                    fmha::lds(tmp[jj], this->smem_read_ + imm);
+                    // fmha::lds(tmp[jj], this->smem_read_ + imm);
+                    fmha::lds(tmp[jj], smem_read);
                 }
             }
 
             // Perform the reduction.
             out[ii] = zero_init ? tmp[0] : fmha::fadd4(out[ii], tmp[0]);
+            // if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+            //     printf("out reduction: out = %.6f\n", reinterpret_cast<float (&)[4]>(out[ii])[0]);
+            // }
 
             #pragma unroll
             for( int jj = 1; jj < Cta_tile::WARPS_K; ++jj ) {
                 out[ii] = fmha::fadd4(out[ii], tmp[jj]);
+                // if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+                //     printf("out reduction tmp = %.6f, out = %.6f\n", reinterpret_cast<float (&)[4]>(tmp[jj])[0], reinterpret_cast<float (&)[4]>(out[ii])[0]);
+                // }
             }
         }
     }
@@ -1102,6 +1131,7 @@ struct Smem_tile_o {
     // Store the accumulators.
     template<typename Accumulator, int M, int N>
     inline __device__ void store(const Accumulator (&acc)[M][N], int mi) {
+        // uint32_t smem_write_og = this->smem_write_;
         static constexpr int M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA;
         #pragma unroll
         for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
@@ -1126,7 +1156,15 @@ struct Smem_tile_o {
                 fmha::sts(this->smem_write_ + row_0, tmp0);
                 fmha::sts(this->smem_write_ + row_1, tmp1);
             }
+            // if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+            //     printf("smem_write diff = %d\n", this->smem_write_ - smem_write_og);
+            // }
 
+            // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+            //     uint4 read_tmp;
+            //     fmha::lds(read_tmp, this->smem_read_);
+            //     printf("smem_o = %.6f\n", reinterpret_cast<float (&)[4]>(read_tmp)[0]);
+            // }
             // Swizzle the write pointer using a XOR of 16B.
             this->smem_write_ ^= 32;
@@ -1148,8 +1186,25 @@ struct Smem_tile_o {
                 fmha::sts(this->smem_write_ + row_1, tmp1);
             }
 
+            // if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+            //     printf("smem_write diff = %d\n", this->smem_write_ - smem_write_og);
+            // }
+
             // Cancel the previous XOR of 1 + swizzle the write pointer using a XOR of 32B or 64B.
-            this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32;
+            static_assert(Mma_tile::MMAS_N <= 8, "Not implemented");
+            if( Mma_tile::MMAS_N >= 8 && ni % 4 == 3 ) {
+                this->smem_write_ ^= 15 * 32;
+            } else if( Mma_tile::MMAS_N >= 4 && ni % 2 == 1 ) {
+                this->smem_write_ ^= 7 * 32;
+            } else if( Mma_tile::MMAS_N >= 2 ) {
+                this->smem_write_ ^= 3 * 32;
+            }
+            // this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32;
+            // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+            //     uint4 read_tmp;
+            //     fmha::lds(read_tmp, this->smem_read_);
+            //     printf("smem_o = %.6f\n", reinterpret_cast<float (&)[4]>(read_tmp)[0]);
+            // }
         }
     }
 };
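Note on the store() swizzle above: with d=128, a row of MMA tiles is Mma_tile::MMAS_N = 8 wide, and the old two-term pattern `(ni & 1) ? 7 * 32 : 3 * 32` only walks four 64B-apart columns, so it would revisit offset 0 at ni = 4 and clobber the first tile. The new ladder extends the walk to eight columns. A standalone sketch that replays the pointer arithmetic on the host (assuming MMAS_N = 8, the largest value the static_assert allows; the device code XORs a real shared-memory address, here stood in for by a plain byte offset):

    #include <cstdio>

    // Replays the write-pointer updates of Smem_tile_o::store for MMAS_N = 8
    // and prints the offset at which the first sts of each ni lands.
    int main() {
        const int MMAS_N = 8;
        unsigned smem_write = 0;   // byte-offset stand-in for smem_write_
        for (int ni = 0; ni < MMAS_N; ++ni) {
            printf("ni = %d: sts at offset %u\n", ni, smem_write);
            smem_write ^= 32;      // the XOR between the two sts calls
            if (MMAS_N >= 8 && ni % 4 == 3) {
                smem_write ^= 15 * 32;
            } else if (MMAS_N >= 4 && ni % 2 == 1) {
                smem_write ^= 7 * 32;
            } else if (MMAS_N >= 2) {
                smem_write ^= 3 * 32;
            }
        }
        return 0;
    }

This prints offsets 0, 64, 128, ..., 448: eight distinct columns. On the read side, load() compensates for the d=128 layout with the `smem_read ^= 8 * BYTES_PER_LDS` adjustment for odd ii, which the TD comment flags as an acknowledged ugly fix.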
diff --git a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
index 5cafc30..1281647 100644
--- a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
+++ b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
@@ -121,8 +121,21 @@ void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &l
             }
         }
     } else if (launch_params.params.d == 128) {
-        using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
-        run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        if( launch_params.params.s == 128 ) {
+            using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
+            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        } else {
+            auto dprops = at::cuda::getCurrentDeviceProperties();
+            if (dprops->major == 8 && dprops->minor >= 0 && !is_dropout) {
+                // TD [2022-06-05] Keep K in registers to reduce register spilling
+                // Gives about 6% speedup compared to using block size 128.
+                using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u>;
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+            } else {  // Need to use the same block size as backward
+                using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+            }
+        }
     }
     // if (launch_params.params.d == 64) {
     //     // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
@@ -151,4 +164,21 @@ void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &l
     //         }
     //     }
     // }
+    // if (launch_params.params.d == 128) {
+    //     if( launch_params.params.s == 128 ) {
+    //         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
+    //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+    //     } else {
+    //         auto dprops = at::cuda::getCurrentDeviceProperties();
+    //         if (dprops->major == 8 && dprops->minor >= 0 && !is_dropout) {
+    //             // TD [2022-06-05] Keep K in registers to reduce register spilling
+    //             // Gives about 6% speedup compared to using block size 128.
+    //             using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u>;
+    //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+    //         } else {  // Need to use the same block size as backward
+    //             using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
+    //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+    //         }
+    //     }
+    // }
 }
\ No newline at end of file
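Note on the dispatch above: for d = 128 the launcher now distinguishes short sequences, SM80 without dropout, and everything else. A host-side sketch of the decision tree (the trait arguments are copied from the hunk; reading 256 vs 128 as the sequence-block size and 0x18u vs 0x08u as the flag word follows the surrounding code and patch comments, and should be treated as an assumption):

    #include <cstdio>

    // Which Kernel_traits the d == 128 branch of run_fmha_fp16_sm80 selects.
    const char *pick_kernel_d128(int s, int sm_major, int sm_minor, bool is_dropout) {
        if (s == 128) {
            return "FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>";
        }
        if (sm_major == 8 && sm_minor >= 0 && !is_dropout) {
            // Wider 256 block with 0x18u flags: keeps K in registers to reduce
            // spilling; about 6% faster than the 128 block per the patch comment.
            return "FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u>";
        }
        // With dropout (or below SM80): reuse the block size the backward pass needs.
        return "FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>";
    }

    int main() {
        printf("%s\n", pick_kernel_d128(512, 8, 0, false));  // A100, no dropout
        printf("%s\n", pick_kernel_d128(512, 8, 0, true));   // A100, dropout
        return 0;
    }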
diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
index 93aef12..0ba609a 100644
--- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
+++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
@@ -498,10 +498,19 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
         #pragma unroll
         for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
             fmha::gemm_cl(acc_o, frag_p[ki], frag_v[ki]);
+            // if ((threadIdx.x == 4) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
+            //     float2 tmp_p = __half22float2(reinterpret_cast<__half2 &>(frag_p[ki]));
+            //     float2 tmp_v = __half22float2(reinterpret_cast<__half2 &>(frag_v[ki]));
+            //     printf("Per warp, threadIdx.x = %d, frag_p = %.6f, %.6f, frag_v = %.6f, %.6f, acc_o=%.6f\n", threadIdx.x, tmp_p.x, tmp_p.y, tmp_v.x, tmp_v.y, acc_o[0][0].elt(0));
+            // }
         }
 
-        // The mapping from tidx to rows changes between the softmax and the O-reduction.
-        // So we recalculate the max.
+        // if ((threadIdx.x % 32 == 16) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
+        //     printf("Per warp, threadIdx.x = %d, acc_o=%.6f\n", threadIdx.x, acc_o[0][2].elt(0));
+        // }
+
+        // The mapping from tidx to rows changes between the softmax and the
+        // O-reduction. So we recalculate the max.
         float p_max_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
         // TODO: not sure if this is right for seqlen 128 or 256
         int rows[Gmem_tile_o::STGS_PER_LOOP];