Do P * dP (pointwise) in the bwd in fp32 instead of fp16
parent 6c3a8c65af
commit a5559a0e75
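Context for the kernel hunks below: the backward pass needs the pointwise product P * dP before the dQ/dK GEMMs, with dropout folded in via a sign trick (dropped elements of P are stored negated, so the sign doubles as the keep/drop mask, per the comments in the removed code). Previously the product was computed on the packed fp16 fragments; this commit computes it on the fp32 accumulator values (softmax.elt_ and acc_dp) instead. A minimal standalone sketch of the new formulation, following the diff's conventions — the free-standing templated function here is illustrative; the commit itself uses a lambda that captures Is_dropout:

    // Sketch (CUDA) of the fp32 pointwise multiply introduced below.
    //   p  - recomputed softmax probability; under dropout, dropped
    //        elements are stored as -p, so the sign encodes the mask
    //   dp - element of dP = dO @ V^T, read from the fp32 accumulator
    //   d  - dp_sum, the per-row reduction used for the dropped case
    template <bool Is_dropout>
    __device__ __forceinline__ float pointwise_mult(float p, float dp, float d) {
        // Kept element (p >= 0), or dropout disabled: p * dp.
        // Dropped element (p < 0): p * d = -(|p| * dp_sum), preserving the
        // sign-encoded convention for the masked entry.
        return p * ((!Is_dropout) || p >= 0.f ? dp : d);
    }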
@@ -96,7 +96,7 @@ struct FMHA_fprop_params : public Qkv_params {
     void * __restrict__ softmax_lse_ptr;
 
     // The dimensions.
-    int b, seqlen_q, seqlen_k, d, seqlen_q_rounded;
+    int b, seqlen_q, seqlen_k, d;
 
     // The scaling factors for the kernel.
     float scale_bmm1f;
@@ -389,6 +389,24 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         }
     }
 
+    auto pointwise_mult = [](float p, float dp, float d) {
+        return p * ((!Is_dropout) || p >= 0.f ? dp : d);
+    };
+    #pragma unroll
+    for (int mi = 0; mi < Mma_tile_p::MMAS_M; mi++) {
+        #pragma unroll
+        for (int ni = 0; ni < Mma_tile_p::MMAS_N; ni++) {
+            softmax.elt_[2 * mi + 0][4 * ni + 0] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 0], acc_dp[mi][ni].elt(0), dp_sum[2 * mi + 0]);
+            softmax.elt_[2 * mi + 0][4 * ni + 1] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 1], acc_dp[mi][ni].elt(1), dp_sum[2 * mi + 0]);
+            softmax.elt_[2 * mi + 0][4 * ni + 2] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 2], acc_dp[mi][ni].elt(4), dp_sum[2 * mi + 0]);
+            softmax.elt_[2 * mi + 0][4 * ni + 3] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 3], acc_dp[mi][ni].elt(5), dp_sum[2 * mi + 0]);
+            softmax.elt_[2 * mi + 1][4 * ni + 0] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 0], acc_dp[mi][ni].elt(2), dp_sum[2 * mi + 1]);
+            softmax.elt_[2 * mi + 1][4 * ni + 1] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 1], acc_dp[mi][ni].elt(3), dp_sum[2 * mi + 1]);
+            softmax.elt_[2 * mi + 1][4 * ni + 2] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 2], acc_dp[mi][ni].elt(6), dp_sum[2 * mi + 1]);
+            softmax.elt_[2 * mi + 1][4 * ni + 3] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 3], acc_dp[mi][ni].elt(7), dp_sum[2 * mi + 1]);
+        }
+    }
+
     // Load the fragments for K^T.
     typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N];
     smem_kt.load(frag_kt[0], 0);
@@ -404,46 +422,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         }
     }
 
-    softmax.unpack_noscale(acc_dp);
-    // // TD [2022-04-01]: Don't need to apply mask since the corresponding value in softmax
-    // // will be zero.
-    // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { dp_sum[mi] *= params.p_dropout; }
-
-    Frag_p frag_dp[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
-    softmax.pack(frag_dp);
-
-    if (!Is_dropout) {
-        #pragma unroll
-        for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
-            #pragma unroll
-            for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
-                frag_p[mi][ni].hmul(frag_dp[mi][ni]);
-            }
-        }
-    } else {
-        __half2 dp_sum_half[Mma_tile_p::MMAS_M * 2];
-        for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) {
-            dp_sum_half[mi] = __float2half2_rn(dp_sum[mi]);
-        }
-        const __half zero_h = __half(0.f);
-        #pragma unroll
-        for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
-            #pragma unroll
-            for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
-                #pragma unroll
-                for (int ii = 0; ii < 4; ++ii) {
-                    const __half2 p = frag_p[mi][ni].template elt_as<__half2>(ii);
-                    const __half2 pdp = __hmul2(p, frag_dp[mi][ni].template elt_as<__half2>(ii));
-                    // If this element is dropped, then frag_p stores -p instead of p.
-                    // So pd holds -p * dp_sum in that case.
-                    const __half2 pd = __hmul2(p, dp_sum_half[mi * 2 + (ii % 2)]);
-                    const __half low = __low2half(p) >= zero_h ? __low2half(pdp) : __low2half(pd);
-                    const __half high = __high2half(p) >= zero_h ? __high2half(pdp) : __high2half(pd);
-                    frag_p[mi][ni].template elt_as<__half2>(ii) = __halves2half2(low, high);
-                }
-            }
-        }
-    }
     softmax.pack(frag_p);
 
     // Store dp to smem for transpose
     smem_dp.store(frag_p);
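The block removed above is the old fp16 path: P and dP were packed as __half2 fragments and multiplied with __hmul2, with per-half selects (__low2half/__high2half against zero) implementing the same sign-based dropout handling. The practical difference is where rounding happens; a rough illustration under the same conventions — these helper names are hypothetical, not from the commit:

    #include <cuda_fp16.h>

    // Old path, schematically: operands and product all round to fp16.
    __device__ float prod_fp16(float p, float dp) {
        return __half2float(__hmul(__float2half(p), __float2half(dp)));
    }

    // New path: the product stays in fp32; conversion to fp16 happens once,
    // later, when the result is packed into fragments for the next GEMM.
    __device__ float prod_fp32(float p, float dp) {
        return p * dp;
    }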
@@ -215,8 +215,8 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
     """dropout_p should be set to 0.0 during evaluation
     Arguments:
         q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
-        k: (total_k, 2, nheads, headdim), where total_k = total number of key tokens in the batch.
-        v: (total_k, 2, nheads, headdim), where total_k = total number of key tokens in the batch.
+        k: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
+        v: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
         cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
         cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths