diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp index 0539a60..da95634 100644 --- a/csrc/flash_attn/fmha_api.cpp +++ b/csrc/flash_attn/fmha_api.cpp @@ -216,7 +216,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q TORCH_CHECK(k.stride(-1) == 1); TORCH_CHECK(v.stride(-1) == 1); TORCH_CHECK(out.stride(-1) == 1); - TORCH_CHECK(cu_seqlens_k.is_contiguous()); + TORCH_CHECK(cu_seqlens_q.is_contiguous()); TORCH_CHECK(cu_seqlens_k.is_contiguous()); const auto sizes = q.sizes();