Fixing argument checking when using seqlenq_ngroups_swapped. (#976)
When user send `out` as a parameter of the function `seqlenq_ngroups_swapped` with parameters that trigger, the CHECK_SHAPE is incorrect (since q shape is modified.)
This commit is contained in:
parent
ab59ec3590
commit
5bf201966a
@ -637,7 +637,6 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
|
||||
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
|
||||
CHECK_DEVICE(out);
|
||||
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
|
||||
CHECK_SHAPE(out, total_q, num_heads, head_size_og);
|
||||
CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
|
||||
if (seqlenq_ngroups_swapped) {
|
||||
out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
|
||||
|
||||
Loading…
Reference in New Issue
Block a user