From 2423cca3ad20c98ad452c7944feca3c222b88668 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Sat, 10 Feb 2024 04:01:27 -0500 Subject: [PATCH] fix backward for when query and key have different contiguity (#818) --- csrc/flash_attn/flash_api.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 79284dc..80dd5a3 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -830,7 +830,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size); } else { - dv = torch::empty_like(k); + dv = torch::empty_like(v); } at::Tensor dout_padded;