Do P * dP (pointwise) in the bwd in fp32 instead of fp16

Tri Dao 2022-07-03 17:52:05 -07:00
parent 6c3a8c65af
commit a5559a0e75
3 changed files with 22 additions and 43 deletions
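The point of the change: in the backward kernel, the elementwise product of the attention probabilities P and the gradient dP was previously formed in fp16 (the removed __hmul2 path below), which rounds P, dP, and their product to half precision. After this commit the product is taken from the fp32 values already in registers (softmax.elt_ for P, the acc_dp accumulator for dP), so the result is rounded to fp16 only once, when it is packed for the next matmul. A rough PyTorch illustration of the difference on random data (not the kernel code):

    import torch

    torch.manual_seed(0)
    p = torch.rand(4096)      # stand-in for softmax probabilities P, fp32
    dp = torch.randn(4096)    # stand-in for dP, fp32

    exact = p.double() * dp.double()
    # Old path, roughly: round P and dP to fp16 and multiply in fp16.
    old = (p.half() * dp.half()).double()
    # New path, roughly: multiply in fp32, round the product to fp16 once at the end.
    new = (p * dp).half().double()

    print("max |error|, product in fp16:", (old - exact).abs().max().item())
    print("max |error|, product in fp32:", (new - exact).abs().max().item())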

@@ -96,7 +96,7 @@ struct FMHA_fprop_params : public Qkv_params {
    void * __restrict__ softmax_lse_ptr;
    // The dimensions.
    int b, seqlen_q, seqlen_k, d, seqlen_q_rounded;
    int b, seqlen_q, seqlen_k, d;
    // The scaling factors for the kernel.
    float scale_bmm1f;

@@ -389,6 +389,24 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
        }
    }
    auto pointwise_mult = [](float p, float dp, float d) {
        return p * ((!Is_dropout) || p >= 0.f ? dp : d);
    };
    #pragma unroll
    for (int mi = 0; mi < Mma_tile_p::MMAS_M; mi++) {
        #pragma unroll
        for (int ni = 0; ni < Mma_tile_p::MMAS_N; ni++) {
            softmax.elt_[2 * mi + 0][4 * ni + 0] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 0], acc_dp[mi][ni].elt(0), dp_sum[2 * mi + 0]);
            softmax.elt_[2 * mi + 0][4 * ni + 1] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 1], acc_dp[mi][ni].elt(1), dp_sum[2 * mi + 0]);
            softmax.elt_[2 * mi + 0][4 * ni + 2] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 2], acc_dp[mi][ni].elt(4), dp_sum[2 * mi + 0]);
            softmax.elt_[2 * mi + 0][4 * ni + 3] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 3], acc_dp[mi][ni].elt(5), dp_sum[2 * mi + 0]);
            softmax.elt_[2 * mi + 1][4 * ni + 0] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 0], acc_dp[mi][ni].elt(2), dp_sum[2 * mi + 1]);
            softmax.elt_[2 * mi + 1][4 * ni + 1] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 1], acc_dp[mi][ni].elt(3), dp_sum[2 * mi + 1]);
            softmax.elt_[2 * mi + 1][4 * ni + 2] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 2], acc_dp[mi][ni].elt(6), dp_sum[2 * mi + 1]);
            softmax.elt_[2 * mi + 1][4 * ni + 3] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 3], acc_dp[mi][ni].elt(7), dp_sum[2 * mi + 1]);
        }
    }
    // Load the fragments for K^T.
    typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N];
    smem_kt.load(frag_kt[0], 0);
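The block added above forms P * dP while both operands are still fp32: P lives in softmax.elt_, dP in the fp32 accumulator acc_dp, and dp_sum holds the per-row sum D. With dropout, dropped entries of P are stored negated (as the comment in the removed fp16 code below notes), so the select `p >= 0.f ? dp : d` uses dP for kept elements and dp_sum for dropped ones; without dropout the condition is always true and this is simply p * dp. A scalar PyTorch sketch of that select, with made-up values:

    import torch

    # Softmax values for one row, with the dropped entry stored as -p.
    p_enc = torch.tensor([0.30, -0.20, 0.45, 0.05])   # second element was dropped
    dp = torch.tensor([1.10, -0.40, 0.70, 0.20])      # dP for the same row, fp32
    dp_sum = torch.tensor(0.95)                        # per-row sum D, broadcast over the row

    # fp32 equivalent of pointwise_mult: p * (p >= 0 ? dp : dp_sum)
    out = p_enc * torch.where(p_enc >= 0, dp, dp_sum)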
@@ -404,46 +422,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
        }
    }
    softmax.unpack_noscale(acc_dp);
    // // TD [2022-04-01]: Don't need to apply mask since the corresponding value in softmax
    // // will be zero.
    // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { dp_sum[mi] *= params.p_dropout; }
    Frag_p frag_dp[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
    softmax.pack(frag_dp);
    if (!Is_dropout) {
        #pragma unroll
        for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
            #pragma unroll
            for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
                frag_p[mi][ni].hmul(frag_dp[mi][ni]);
            }
        }
    } else {
        __half2 dp_sum_half[Mma_tile_p::MMAS_M * 2];
        for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) {
            dp_sum_half[mi] = __float2half2_rn(dp_sum[mi]);
        }
        const __half zero_h = __half(0.f);
        #pragma unroll
        for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
            #pragma unroll
            for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
                #pragma unroll
                for (int ii = 0; ii < 4; ++ii) {
                    const __half2 p = frag_p[mi][ni].template elt_as<__half2>(ii);
                    const __half2 pdp = __hmul2(p, frag_dp[mi][ni].template elt_as<__half2>(ii));
                    // If this element is dropped, then frag_p stores -p instead of p.
                    // So pd holds -p * dp_sum in that case.
                    const __half2 pd = __hmul2(p, dp_sum_half[mi * 2 + (ii % 2)]);
                    const __half low = __low2half(p) >= zero_h ? __low2half(pdp) : __low2half(pd);
                    const __half high = __high2half(p) >= zero_h ? __high2half(pdp) : __high2half(pd);
                    frag_p[mi][ni].template elt_as<__half2>(ii) = __halves2half2(low, high);
                }
            }
        }
    }
    softmax.pack(frag_p);
    // Store dp to smem for transpose
    smem_dp.store(frag_p);

@@ -215,8 +215,8 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
    """dropout_p should be set to 0.0 during evaluation
    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, 2, nheads, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, 2, nheads, headdim), where total_k = total number of key tokens in the batch.
        k: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
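For reference, a call with the corrected k/v shapes might look like the sketch below. Only the arguments up to max_seqlen_q are visible in the hunk header above; the trailing max_seqlen_k and dropout_p arguments and the flash_attn.flash_attn_interface module path are assumptions here, so treat this as a sketch rather than the exact interface.

    import torch
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func  # assumed module path

    batch_size, seqlen, nheads, headdim = 2, 128, 4, 64
    total = batch_size * seqlen   # no padding: every sequence has length seqlen
    q = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)
    k = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)   # (total_k, nheads, headdim)
    v = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)
    # Cumulative sequence lengths: [0, seqlen, 2 * seqlen, ...]
    cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, seqlen,
                              dtype=torch.int32, device="cuda")

    out = flash_attn_unpadded_func(q, k, v, cu_seqlens, cu_seqlens,
                                   seqlen, seqlen, dropout_p=0.0)   # max_seqlen_k, dropout_p assumed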