From 468793641374d00a0cd7017cc0b2310b4c710376 Mon Sep 17 00:00:00 2001 From: Grigory Sizov Date: Thu, 8 Feb 2024 02:41:53 +0100 Subject: [PATCH] Fix Windows build (#816) --- csrc/flash_attn/flash_api.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 84cb71f..79284dc 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -696,8 +696,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s } if (seqlenq_ngroups_swapped) { - long size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size_og}; - long size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size_og}; + int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size_og}; + int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size_og}; out = out.reshape(size_before).transpose(1, 2).reshape(size_after); out_padded = out_padded.reshape(size_before).transpose(1, 2).reshape(size_after); q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);