Fix dv = torch::empty_like(k) for mha_bwd_varlen as well
This commit is contained in:
parent
a190df011c
commit
d9a5cb291c
@ -1069,7 +1069,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
|
||||
CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
|
||||
} else {
|
||||
dv = torch::empty_like(k);
|
||||
dv = torch::empty_like(v);
|
||||
}
|
||||
|
||||
at::Tensor dout_padded;
|
||||
|
||||
Loading…
Reference in New Issue
Block a user