Transpose out when swapping seqlen_q and num_groups
This commit is contained in:
parent
f692b98d80
commit
9eb3d099c1
@ -282,7 +282,8 @@ void set_params_splitkv(Flash_fwd_params ¶ms, const int batch_size,
|
||||
params.num_splits = num_splits;
|
||||
if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
|
||||
if (num_splits < 1) {
|
||||
params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
|
||||
// We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
|
||||
params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount * 2, num_n_blocks, 128);
|
||||
}
|
||||
if (params.num_splits > 1) {
|
||||
at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
|
||||
@ -372,8 +373,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
|
||||
// H/t Daniel Haziza
|
||||
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
|
||||
const int ngroups = num_heads / num_heads_k;
|
||||
if (seqlenq_ngroups_swapped) {
|
||||
const int ngroups = num_heads / num_heads_k;
|
||||
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
|
||||
seqlen_q = ngroups;
|
||||
num_heads = num_heads_k;
|
||||
@ -400,7 +401,10 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
|
||||
CHECK_DEVICE(out);
|
||||
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
|
||||
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
|
||||
CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og);
|
||||
if (seqlenq_ngroups_swapped) {
|
||||
out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
|
||||
}
|
||||
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
|
||||
} else {
|
||||
out = torch::empty_like(q_padded);
|
||||
@ -571,8 +575,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
|
||||
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
|
||||
// H/t Daniel Haziza
|
||||
const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
|
||||
const int ngroups = num_heads / num_heads_k;
|
||||
if (seqlenq_ngroups_swapped) {
|
||||
const int ngroups = num_heads / num_heads_k;
|
||||
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
|
||||
max_seqlen_q = ngroups;
|
||||
num_heads = num_heads_k;
|
||||
@ -627,6 +631,10 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
|
||||
CHECK_DEVICE(out);
|
||||
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
|
||||
CHECK_SHAPE(out, total_q, num_heads, head_size_og);
|
||||
CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
|
||||
if (seqlenq_ngroups_swapped) {
|
||||
out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
|
||||
}
|
||||
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
|
||||
} else {
|
||||
out = torch::empty_like(q_padded);
|
||||
|
||||
Loading…
Reference in New Issue
Block a user