Fixing argument checking when using seqlenq_ngroups_swapped. (#976)

When user send `out` as a parameter of the function
`seqlenq_ngroups_swapped` with parameters that trigger,
the CHECK_SHAPE is incorrect (since q shape is modified.)
This commit is contained in:
Nicolas Patry 2024-07-01 07:39:22 +02:00 committed by GitHub
parent ab59ec3590
commit 5bf201966a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -637,7 +637,6 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
CHECK_SHAPE(out, total_q, num_heads, head_size_og);
CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
if (seqlenq_ngroups_swapped) {
out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});