diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 80dd5a3..2288850 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -1069,7 +1069,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); CHECK_SHAPE(dv, total_k, num_heads_k, head_size); } else { - dv = torch::empty_like(k); + dv = torch::empty_like(v); } at::Tensor dout_padded;