diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp
index 8b2a314..b5084c2 100644
--- a/csrc/flash_attn/flash_api.cpp
+++ b/csrc/flash_attn/flash_api.cpp
@@ -260,7 +260,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
         const float p_dropout,
         const float softmax_scale,
         bool is_causal,
-        const int window_size_left,
+        int window_size_left,
         int window_size_right,
         const bool return_softmax,
         c10::optional<at::Generator> gen_) {
@@ -300,6 +300,9 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
+    if (window_size_left >= seqlen_k) { window_size_left = -1; }
+    if (window_size_right >= seqlen_k) { window_size_right = -1; }
+
     // causal=true is the same as causal=false in this case
     if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
     if (is_causal) { window_size_right = 0; }
@@ -465,7 +468,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                const float softmax_scale,
                const bool zero_tensors,
                const bool is_causal,
-               const int window_size_left,
+               int window_size_left,
                int window_size_right,
                const bool return_softmax,
                c10::optional<at::Generator> gen_) {
@@ -512,6 +515,9 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
+    if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
+    if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
+
     CHECK_SHAPE(q, total_q, num_heads, head_size_og);
     CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
     CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
@@ -675,7 +681,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
         const float p_dropout,         // probability to drop
         const float softmax_scale,
         const bool is_causal,
-        const int window_size_left,
+        int window_size_left,
         int window_size_right,
         const bool deterministic,
         c10::optional<at::Generator> gen_,
@@ -738,6 +744,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
 
     TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
 
+    if (window_size_left >= seqlen_k) { window_size_left = -1; }
+    if (window_size_right >= seqlen_k) { window_size_right = -1; }
+
     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
     CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
     CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
@@ -912,7 +921,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
                const float softmax_scale,
                const bool zero_tensors,
                const bool is_causal,
-               const int window_size_left,
+               int window_size_left,
                int window_size_right,
                const bool deterministic,
                c10::optional<at::Generator> gen_,
@@ -979,6 +988,9 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
 
     TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
 
+    if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
+    if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
+
     CHECK_SHAPE(q, total_q, num_heads, head_size);
     CHECK_SHAPE(k, total_k, num_heads_k, head_size);
     CHECK_SHAPE(v, total_k, num_heads_k, head_size);
@@ -1160,7 +1172,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                 c10::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x head_size
                 const float softmax_scale,
                 bool is_causal,
-                const int window_size_left,
+                int window_size_left,
                 int window_size_right,
                 bool is_rotary_interleaved,   // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
                 int num_splits
@@ -1216,6 +1228,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         num_heads = num_heads_k;
     }
 
+    if (window_size_left >= seqlen_k) { window_size_left = -1; }
+    if (window_size_right >= seqlen_k) { window_size_right = -1; }
+
     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
     CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
     CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
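
Note (not part of the patch): each hunk applies the same guard before any masking decisions are made. A minimal standalone sketch of that behavior follows, using a hypothetical helper name normalize_window_size; in the actual code the clamping is done inline on the window_size_left/window_size_right parameters against seqlen_k (or max_seqlen_k in the varlen paths).

// Illustrative sketch only, not flash_api.cpp code.
// A sliding window that already covers every key position is equivalent to
// no window at all, which FlashAttention encodes as -1.
#include <cassert>

int normalize_window_size(int window_size, int seqlen_k) {
    // Mirrors the added guard: window sizes >= seqlen_k collapse to -1.
    return window_size >= seqlen_k ? -1 : window_size;
}

int main() {
    const int seqlen_k = 4096;
    assert(normalize_window_size(8192, seqlen_k) == -1);  // wider than the keys: unbounded
    assert(normalize_window_size(256, seqlen_k) == 256);  // genuinely local: unchanged
    assert(normalize_window_size(-1, seqlen_k) == -1);    // already unbounded: unchanged
    return 0;
}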