Implement fwd for head dim 128
commit 0d854692c6 (parent 0a398dfc37)
@@ -118,6 +118,7 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot

    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
    bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
    TORCH_CHECK((dprops->major == 8 && dprops->minor >= 0) || is_sm75);
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    bool is_dropout = p_dropout > 0.0;
@@ -144,7 +145,7 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
    TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);

    // int base_N = head_size == 16 ? 512 : (head_size == 128 ? 128 : 256);
    int base_N = (head_size == 128 || (is_sm75 && head_size == 64 && is_dropout)) ? 128 : 256;
    int base_N = ((head_size == 128 && (is_dropout || !is_sm80)) || (is_sm75 && head_size == 64 && is_dropout)) ? 128 : 256;
    // int base_N = 256;
    int seq_len = 512;
    if( max_seq_len <= 128 ) {
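For reference, the base_N selection above can be read as a small standalone helper; this is only a sketch restating the condition from the diff (the helper name and free-standing form are illustrative, not part of the repo):

// Sketch only: restates the base_N selection above. head_size 128 falls back to a
// 128-column tile when dropout is enabled or when the GPU is not SM80; SM75 with
// head_size 64 and dropout also uses 128. Everything else keeps the 256-column tile.
static int choose_base_N(int head_size, bool is_sm75, bool is_sm80, bool is_dropout) {
    if ((head_size == 128 && (is_dropout || !is_sm80)) ||
        (is_sm75 && head_size == 64 && is_dropout)) {
        return 128;
    }
    return 256;
}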
@@ -1054,6 +1054,14 @@ struct Smem_tile_o {
        constexpr int STS_PER_WARP = 16 * Mma_tile::MMAS_N / ELEMENTS_PER_STS;
        int write_col = warp * STS_PER_WARP + lane % STS_PER_WARP;

        // if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
        //     printf("write_row = %d, write_col = %d\n", write_row, write_col);
        // }

        // if ((blockIdx.x == 0) && (blockIdx.y == 0) && (write_row == 0) && (write_col == 0)) {
        //     printf("threadIdx.x = %d\n", threadIdx.x);
        // }

        // Assemble the write pointer.
        smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;

@@ -1062,9 +1070,15 @@ struct Smem_tile_o {
        int read_col = tidx % THREADS_PER_ROW;

        // Take the XOR pattern into account for the column.
        // read_col ^= 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 4 : 8)));
        read_col ^= 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 4 : (Cta_tile::N == 128 ? 16 : 8))));
        read_col ^= 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 4 : 8)));
        // read_col ^= 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 4 : (Cta_tile::N == 128 ? 16 : 8))));

        // if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
        //     printf("read_row = %d, read_col = %d\n", read_row, read_col);
        // }
        // if ((blockIdx.x == 0) && (blockIdx.y == 0) && (read_row == 0) && (read_col == 0)) {
        //     printf("threadIdx.x = %d\n", threadIdx.x);
        // }
        // Assemble the read pointer.
        this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;

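The read-column swizzle above is meant to undo the XOR pattern used when the O tile was written, so the loads avoid shared-memory bank conflicts; a host-side trace of the Cta_tile::N == 128 branch (the THREADS_PER_ROW value is an assumption for the trace, not taken from the kernel traits):

#include <cstdio>

// Prints how the read column is permuted per row for the N == 128 case above:
// the column is XOR-ed with 2 * (row % 16), so consecutive rows read from
// different banks. Purely illustrative constants.
int main() {
    const int THREADS_PER_ROW = 16;  // assumed value for the trace
    for (int read_row = 0; read_row < 16; ++read_row) {
        for (int tidx = 0; tidx < 4; ++tidx) {
            int read_col = tidx % THREADS_PER_ROW;
            read_col ^= 2 * (read_row % 16);  // Cta_tile::N == 128 branch
            printf("read_row %2d, tidx %d -> read_col %2d\n", read_row, tidx, read_col);
        }
    }
    return 0;
}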
@@ -1085,16 +1099,31 @@ struct Smem_tile_o {
            #pragma unroll
            for( int jj = 0; jj < Cta_tile::WARPS_K; ++jj ) {
                int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT;
                uint32_t smem_read = this->smem_read_ + imm;
                // TD [2022-06-05] Ugly fix for d=128, maybe there's a better way.
                if ((Cta_tile::N == 128) && (ii % 2 == 1)) {
                    smem_read ^= 8 * BYTES_PER_LDS;
                }
                // if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
                //     printf("imm diff = %d\n", smem_read - this->smem_read_);
                // }
                if( !HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_) ) {
                    fmha::lds(tmp[jj], this->smem_read_ + imm);
                    // fmha::lds(tmp[jj], this->smem_read_ + imm);
                    fmha::lds(tmp[jj], smem_read);
                }
            }

            // Perform the reduction.
            out[ii] = zero_init ? tmp[0] : fmha::fadd4(out[ii], tmp[0]);
            // if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
            //     printf("out reduction: out = %.6f\n", reinterpret_cast<float (&)[4]>(out[ii])[0]);
            // }
            #pragma unroll
            for( int jj = 1; jj < Cta_tile::WARPS_K; ++jj ) {
                out[ii] = fmha::fadd4(out[ii], tmp[jj]);
                // if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
                //     printf("out reduction tmp = %.6f, out = %.6f\n", reinterpret_cast<float (&)[4]>(tmp[jj])[0], reinterpret_cast<float (&)[4]>(out[ii])[0]);
                // }
            }
        }
    }
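fmha::fadd4 above performs the split-K reduction on a uint4 register holding four packed floats; a minimal sketch of those semantics (not the repo's implementation, which lives in the fmha headers):

#include <cuda_runtime.h>

// Sketch of fadd4-style semantics: treat each 32-bit lane of a uint4 as a float
// and add element-wise, as used for the WARPS_K reduction above.
__device__ uint4 fadd4_sketch(uint4 a, uint4 b) {
    uint4 c;
    c.x = __float_as_uint(__uint_as_float(a.x) + __uint_as_float(b.x));
    c.y = __float_as_uint(__uint_as_float(a.y) + __uint_as_float(b.y));
    c.z = __float_as_uint(__uint_as_float(a.z) + __uint_as_float(b.z));
    c.w = __float_as_uint(__uint_as_float(a.w) + __uint_as_float(b.w));
    return c;
}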
@@ -1102,6 +1131,7 @@ struct Smem_tile_o {
    // Store the accumulators.
    template <int M, int N>
    inline __device__ void store(const Accumulator (&acc)[M][N], int mi) {
        // uint32_t smem_write_og = this->smem_write_;
        static constexpr int M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA;
        #pragma unroll
        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
@@ -1126,7 +1156,15 @@ struct Smem_tile_o {
                fmha::sts(this->smem_write_ + row_0, tmp0);
                fmha::sts(this->smem_write_ + row_1, tmp1);
            }
            // if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
            //     printf("smem_write diff = %d\n", this->smem_write_ - smem_write_og);
            // }

            // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
            //     uint4 read_tmp;
            //     fmha::lds(read_tmp, this->smem_read_);
            //     printf("smem_o = %.6f\n", reinterpret_cast<float (&)[4]>(read_tmp)[0]);
            // }
            // Swizzle the write pointer using a XOR of 16B.
            this->smem_write_ ^= 32;

@@ -1148,8 +1186,25 @@ struct Smem_tile_o {
                fmha::sts(this->smem_write_ + row_1, tmp1);
            }

            // if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
            //     printf("smem_write diff = %d\n", this->smem_write_ - smem_write_og);
            // }

            // Cancel the previous XOR of 1 + swizzle the write pointer using a XOR of 32B or 64B.
            this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32;
            static_assert(Mma_tile::MMAS_N <= 8, "Not implemented");
            if( Mma_tile::MMAS_N >= 8 && ni % 4 == 3 ) {
                this->smem_write_ ^= 15 * 32;
            } else if( Mma_tile::MMAS_N >= 4 && ni % 2 == 1 ) {
                this->smem_write_ ^= 7 * 32;
            } else if( Mma_tile::MMAS_N >= 2 ) {
                this->smem_write_ ^= 3 * 32;
            }
            // this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32;
            // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
            //     uint4 read_tmp;
            //     fmha::lds(read_tmp, this->smem_read_);
            //     printf("smem_o = %.6f\n", reinterpret_cast<float (&)[4]>(read_tmp)[0]);
            // }
        }
    }
};
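The new XOR chain above replaces the fixed (ni & 1) ? 7*32 : 3*32 pattern so that larger MMAS_N values (needed for d = 128) cycle through more 32-byte offsets before repeating; a host-side trace of the offsets, assuming MMAS_N is 8:

#include <cstdio>

// Sketch: trace of the write-pointer swizzle from the hunk above for MMAS_N == 8.
// Offsets are in bytes relative to the pointer's starting value; the pattern is
// the 3*32 / 7*32 / 15*32 XOR chain, and it returns to 0 after the full ni loop.
int main() {
    const int MMAS_N = 8;            // illustrative; the kernel derives this from the tile shape
    unsigned offset = 0;
    for (int ni = 0; ni < MMAS_N; ++ni) {
        printf("ni = %d stores at base offset %u\n", ni, offset);
        if (MMAS_N >= 8 && ni % 4 == 3) {
            offset ^= 15 * 32;
        } else if (MMAS_N >= 4 && ni % 2 == 1) {
            offset ^= 7 * 32;
        } else if (MMAS_N >= 2) {
            offset ^= 3 * 32;
        }
    }
    printf("after the loop the offset is back to %u\n", offset);   // prints 0
    return 0;
}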
@@ -121,8 +121,21 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l
            }
        }
    } else if (launch_params.params.d == 128) {
        using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
        run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        if( launch_params.params.s == 128 ) {
            using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        } else {
            auto dprops = at::cuda::getCurrentDeviceProperties();
            if (dprops->major == 8 && dprops->minor >= 0 && !is_dropout) {
                // TD [2022-06-05] Keep K in registers to reduce register spilling
                // Gives about 6% speedup compared to using block size 128.
                using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u>;
                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
            } else { // Need to use the same block size as backward
                using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
            }
        }
    }
    // if (launch_params.params.d == 64) {
    //     // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
@@ -151,4 +164,21 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l
    // }
    // }
    // }
    // if (launch_params.params.d == 128) {
    //     if( launch_params.params.s == 128 ) {
    //         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
    //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
    //     } else {
    //         auto dprops = at::cuda::getCurrentDeviceProperties();
    //         if (dprops->major == 8 && dprops->minor >= 0 && !is_dropout) {
    //             // TD [2022-06-05] Keep K in registers to reduce register spilling
    //             // Gives about 6% speedup compared to using block size 128.
    //             using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u>;
    //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
    //         } else { // Need to use the same block size as backward
    //             using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
    //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
    //         }
    //     }
    // }
}
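Condensed, the new d == 128 dispatch above picks between the two kernel traits from the hunk based on sequence length, architecture, and dropout; the helper below is hypothetical and only mirrors that decision (it does not exist in the repo):

#include <cuda_runtime.h>

// Hypothetical summary of the d == 128 dispatch: returns true when the larger
// FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u> configuration would be launched
// (SM80+, no dropout, seqlen > 128); false means the 128-sized traits that match
// the backward pass.
static bool use_256_traits_for_d128(int seqlen, bool is_dropout) {
    if (seqlen == 128) {
        return false;                      // FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>
    }
    cudaDeviceProp dprops{};
    cudaGetDeviceProperties(&dprops, /*device=*/0);
    // Keep K in registers to reduce register spilling (about 6% faster per the diff).
    return dprops.major == 8 && dprops.minor >= 0 && !is_dropout;
}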
@@ -498,10 +498,19 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
        #pragma unroll
        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
            fmha::gemm_cl(acc_o, frag_p[ki], frag_v[ki]);
            // if ((threadIdx.x == 4) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
            //     float2 tmp_p = __half22float2(reinterpret_cast<__half2 &>(frag_p[ki]));
            //     float2 tmp_v = __half22float2(reinterpret_cast<__half2 &>(frag_v[ki]));
            //     printf("Per warp, threadIdx.x = %d, frag_p = %.6f, %.6f, frag_v = %.6f, %.6f, acc_o=%.6f\n", threadIdx.x, tmp_p.x, tmp_p.y, tmp_v.x, tmp_v.y, acc_o[0][0].elt(0));
            // }
        }

        // The mapping from tidx to rows changes between the softmax and the O-reduction.
        // So we recalculate the max.
        // if ((threadIdx.x % 32 == 16) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
        //     printf("Per warp, threadIdx.x = %d, acc_o=%.6f\n", threadIdx.x, acc_o[0][2].elt(0));
        // }

        // The mapping from tidx to rows changes between the softmax and the
        // O-reduction. So we recalculate the max.
        float p_max_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
        // TODO: not sure if this is right for seqlen 128 or 256
        int rows[Gmem_tile_o::STGS_PER_LOOP];
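For orientation, the MMAS_K fragment loop above accumulates acc_o = P · V over the shared K dimension; a plain scalar reference of that accumulation (shapes and names are illustrative, not the kernel's fragment layout):

// Scalar sketch of what the fragment loop computes per tile: acc_o += P * V.
// M, N, K stand in for the per-tile extents.
template <int M, int N, int K>
void pv_accumulate_reference(float (&acc_o)[M][N],
                             const float (&p)[M][K],
                             const float (&v)[K][N]) {
    for (int k = 0; k < K; ++k) {
        for (int m = 0; m < M; ++m) {
            for (int n = 0; n < N; ++n) {
                acc_o[m][n] += p[m][k] * v[k][n];
            }
        }
    }
}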