Implement fwd for head dim 128

This commit is contained in:
Tri Dao 2022-06-05 22:30:09 -07:00
parent 0a398dfc37
commit 0d854692c6
4 changed files with 104 additions and 9 deletions

View File

@ -118,6 +118,7 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
TORCH_CHECK((dprops->major == 8 && dprops->minor >= 0) || is_sm75);
auto stream = at::cuda::getCurrentCUDAStream().stream();
bool is_dropout = p_dropout > 0.0;
@ -144,7 +145,7 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
// int base_N = head_size == 16 ? 512 : (head_size == 128 ? 128 : 256);
int base_N = (head_size == 128 || (is_sm75 && head_size == 64 && is_dropout)) ? 128 : 256;
int base_N = ((head_size == 128 && (is_dropout || !is_sm80)) || (is_sm75 && head_size == 64 && is_dropout)) ? 128 : 256;
// int base_N = 256;
int seq_len = 512;
if( max_seq_len <= 128 ) {

View File

@ -1054,6 +1054,14 @@ struct Smem_tile_o {
constexpr int STS_PER_WARP = 16 * Mma_tile::MMAS_N / ELEMENTS_PER_STS;
int write_col = warp * STS_PER_WARP + lane % STS_PER_WARP;
// if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("write_row = %d, write_col = %d\n", write_row, write_col);
// }
// if ((blockIdx.x == 0) && (blockIdx.y == 0) && (write_row == 0) && (write_col == 0)) {
// printf("threadIdx.x = %d\n", threadIdx.x);
// }
// Assemble the write pointer.
smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
@ -1062,9 +1070,15 @@ struct Smem_tile_o {
int read_col = tidx % THREADS_PER_ROW;
// Take the XOR pattern into account for the column.
// read_col ^= 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 4 : 8)));
read_col ^= 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 4 : (Cta_tile::N == 128 ? 16 : 8))));
read_col ^= 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 4 : 8)));
// read_col ^= 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 4 : (Cta_tile::N == 128 ? 16 : 8))));
// if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("read_row = %d, read_col = %d\n", read_row, read_col);
// }
// if ((blockIdx.x == 0) && (blockIdx.y == 0) && (read_row == 0) && (read_col == 0)) {
// printf("threadIdx.x = %d\n", threadIdx.x);
// }
// Assemble the read pointer.
this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
@ -1085,16 +1099,31 @@ struct Smem_tile_o {
#pragma unroll
for( int jj = 0; jj < Cta_tile::WARPS_K; ++jj ) {
int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT;
uint32_t smem_read = this->smem_read_ + imm;
// TD [2022-06-05] Ugly fix for d=128: on odd ii, XOR the read address with 8 * BYTES_PER_LDS. Maybe there's a better way.
if ((Cta_tile::N == 128) && (ii % 2 == 1)) {
smem_read ^= 8 * BYTES_PER_LDS;
}
// if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("imm diff = %d\n", smem_read - this->smem_read_);
// }
if( !HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_) ) {
fmha::lds(tmp[jj], this->smem_read_ + imm);
// fmha::lds(tmp[jj], this->smem_read_ + imm);
fmha::lds(tmp[jj], smem_read);
}
}
// Perform the reduction.
out[ii] = zero_init ? tmp[0] : fmha::fadd4(out[ii], tmp[0]);
// if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("out reduction: out = %.6f\n", reinterpret_cast<float (&)[4]>(out[ii])[0]);
// }
#pragma unroll
for( int jj = 1; jj < Cta_tile::WARPS_K; ++jj ) {
out[ii] = fmha::fadd4(out[ii], tmp[jj]);
// if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("out reduction tmp = %.6f, out = %.6f\n", reinterpret_cast<float (&)[4]>(tmp[jj])[0], reinterpret_cast<float (&)[4]>(out[ii])[0]);
// }
}
}
}
@ -1102,6 +1131,7 @@ struct Smem_tile_o {
// Store the accumulators.
template <int M, int N>
inline __device__ void store(const Accumulator (&acc)[M][N], int mi) {
// uint32_t smem_write_og = this->smem_write_;
static constexpr int M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA;
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
@ -1126,7 +1156,15 @@ struct Smem_tile_o {
fmha::sts(this->smem_write_ + row_0, tmp0);
fmha::sts(this->smem_write_ + row_1, tmp1);
}
// if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("smem_write diff = %d\n", this->smem_write_ - smem_write_og);
// }
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// uint4 read_tmp;
// fmha::lds(read_tmp, this->smem_read_);
// printf("smem_o = %.6f\n", reinterpret_cast<float (&)[4]>(read_tmp)[0]);
// }
// Swizzle the write pointer using a XOR of 16B.
this->smem_write_ ^= 32;
@ -1148,8 +1186,25 @@ struct Smem_tile_o {
fmha::sts(this->smem_write_ + row_1, tmp1);
}
// if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("smem_write diff = %d\n", this->smem_write_ - smem_write_og);
// }
// Cancel the previous XOR of 32 and swizzle the write pointer further: the XOR mask is 3 * 32, 7 * 32 or 15 * 32, depending on ni and Mma_tile::MMAS_N.
this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32;
static_assert(Mma_tile::MMAS_N <= 8, "Not implemented");
if( Mma_tile::MMAS_N >= 8 && ni % 4 == 3 ) {
this->smem_write_ ^= 15 * 32;
} else if( Mma_tile::MMAS_N >= 4 && ni % 2 == 1 ) {
this->smem_write_ ^= 7 * 32;
} else if( Mma_tile::MMAS_N >= 2 ) {
this->smem_write_ ^= 3 * 32;
}
// this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// uint4 read_tmp;
// fmha::lds(read_tmp, this->smem_read_);
// printf("smem_o = %.6f\n", reinterpret_cast<float (&)[4]>(read_tmp)[0]);
// }
}
}
};

View File

@ -121,8 +121,21 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l
}
}
} else if (launch_params.params.d == 128) {
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
if( launch_params.params.s == 128 ) {
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else {
auto dprops = at::cuda::getCurrentDeviceProperties();
if (dprops->major == 8 && dprops->minor >= 0 && !is_dropout) {
// TD [2022-06-05] Keep K in registers to reduce register spilling
// Gives about 6% speedup compared to using block size 128.
using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else { // Need to use the same block size as backward
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
}
}
}
// if (launch_params.params.d == 64) {
// // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
@ -151,4 +164,21 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l
// }
// }
// }
// if (launch_params.params.d == 128) {
// if( launch_params.params.s == 128 ) {
// using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
// } else {
// auto dprops = at::cuda::getCurrentDeviceProperties();
// if (dprops->major == 8 && dprops->minor >= 0 && !is_dropout) {
// // TD [2022-06-05] Keep K in registers to reduce register spilling
// // Gives about 6% speedup compared to using block size 128.
// using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
// } else { // Need to use the same block size as backward
// using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
// }
// }
// }
}

View File

@ -498,10 +498,19 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
fmha::gemm_cl(acc_o, frag_p[ki], frag_v[ki]);
// if ((threadIdx.x == 4) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
// float2 tmp_p = __half22float2(reinterpret_cast<__half2 &>(frag_p[ki]));
// float2 tmp_v = __half22float2(reinterpret_cast<__half2 &>(frag_v[ki]));
// printf("Per warp, threadIdx.x = %d, frag_p = %.6f, %.6f, frag_v = %.6f, %.6f, acc_o=%.6f\n", threadIdx.x, tmp_p.x, tmp_p.y, tmp_v.x, tmp_v.y, acc_o[0][0].elt(0));
// }
}
// The mapping from tidx to rows changes between the softmax and the O-reduction.
// So we recalculate the max.
// if ((threadIdx.x % 32 == 16) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
// printf("Per warp, threadIdx.x = %d, acc_o=%.6f\n", threadIdx.x, acc_o[0][2].elt(0));
// }
// The mapping from tidx to rows changes between the softmax and the
// O-reduction. So we recalculate the max.
float p_max_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
// TODO: not sure if this is right for seqlen 128 or 256
int rows[Gmem_tile_o::STGS_PER_LOOP];