From d9a5cb291c733a31a968ac145f622b687ddd1c66 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 10 Feb 2024 01:03:00 -0800 Subject: [PATCH] Fix dv = torch::empty_like(k) for mha_varlen_bwd as well --- csrc/flash_attn/flash_api.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 80dd5a3..2288850 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -1069,7 +1069,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); CHECK_SHAPE(dv, total_k, num_heads_k, head_size); } else { - dv = torch::empty_like(k); + dv = torch::empty_like(v); } at::Tensor dout_padded;