Don't dispatch to local if window size >= seqlen_k

Author: Tri Dao
Date: 2023-12-23 20:59:26 -08:00
parent 732654583c
commit 0842ec0da4
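
The change is uniform across all five entry points: each one now clamps an over-wide sliding window to -1, the library's "unbounded" sentinel, before kernel dispatch. A window of seqlen_k or more can never mask out a key, so it is equivalent to no window at all, and treating it that way lets dispatch pick the regular non-local kernel instead of the slower local-attention variant. Dropping the const from window_size_left is what permits the in-place overwrite (window_size_right was already mutable). Below is a minimal standalone sketch of the normalization; the helper name and the assert-based demo are illustrative, not part of the FlashAttention API:

#include <cassert>

// Illustrative helper mirroring the checks this commit adds: a window of
// seqlen_k or more can never exclude a key, so it is equivalent to an
// unbounded (-1) window, which dispatches to the non-local kernel.
void normalize_window(int &window_size_left, int &window_size_right, int seqlen_k) {
    if (window_size_left >= seqlen_k) { window_size_left = -1; }
    if (window_size_right >= seqlen_k) { window_size_right = -1; }
}

int main() {
    int left = 4096, right = 0;
    normalize_window(left, right, /*seqlen_k=*/2048);
    assert(left == -1);  // window covered every key -> treated as unbounded
    assert(right == 0);  // still restrictive -> kept as-is
    return 0;
}
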


@@ -260,7 +260,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const float p_dropout,
const float softmax_scale,
bool is_causal,
const int window_size_left,
int window_size_left,
int window_size_right,
const bool return_softmax,
c10::optional<at::Generator> gen_) {
@@ -300,6 +300,9 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (window_size_left >= seqlen_k) { window_size_left = -1; }
if (window_size_right >= seqlen_k) { window_size_right = -1; }
// causal=true is the same as causal=false in this case
if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
if (is_causal) { window_size_right = 0; }
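
Why >= seqlen_k is the right threshold: with the bottom-right-aligned sliding window, query row i sees key column j only within [diag - window_size_left, diag + window_size_right], where diag = i + seqlen_k - seqlen_q. A hedged sketch of that predicate follows; the helper is hypothetical, since the real masking lives in the CUDA kernels, and the alignment convention should be checked against the kernel's mask code:

// Hypothetical visibility predicate for a bottom-right-aligned sliding
// window; -1 on either side means unbounded on that side.
bool key_visible(int i, int j, int seqlen_q, int seqlen_k,
                 int window_left, int window_right) {
    int diag = i + seqlen_k - seqlen_q;  // key column aligned with query i
    bool left_ok  = window_left  < 0 || j >= diag - window_left;
    bool right_ok = window_right < 0 || j <= diag + window_right;
    return left_ok && right_ok;
}

When seqlen_q <= seqlen_k, diag ranges over [seqlen_k - seqlen_q, seqlen_k - 1], so a window of at least seqlen_k on either side can never exclude a valid j in [0, seqlen_k), which is exactly the condition the new clamps test.
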
@@ -465,7 +468,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
const int window_size_left,
int window_size_left,
int window_size_right,
const bool return_softmax,
c10::optional<at::Generator> gen_) {
@@ -512,6 +515,9 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
CHECK_SHAPE(q, total_q, num_heads, head_size_og);
CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
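
In the varlen paths the clamp compares against max_seqlen_k rather than any per-sequence length: max_seqlen_k bounds every packed sequence, so a window at least that wide masks nothing for any sequence in the batch. (A window that exceeds only some shorter sequences is left alone, and the in-kernel masking handles those as before.) For illustration, assuming the usual cu_seqlens layout of cumulative token offsets, the relationship looks like this; the helper is hypothetical, since the API takes max_seqlen_k as an argument rather than deriving it:

#include <algorithm>
#include <vector>

// Illustrative: the longest per-sequence length recoverable from cumulative
// offsets (cu_seqlens has batch_size + 1 entries, starting at 0).
int max_seqlen_from_cu(const std::vector<int> &cu_seqlens) {
    int max_len = 0;
    for (std::size_t b = 1; b < cu_seqlens.size(); ++b) {
        max_len = std::max(max_len, cu_seqlens[b] - cu_seqlens[b - 1]);
    }
    return max_len;  // a window >= this value excludes nothing anywhere
}
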
@@ -675,7 +681,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
const float p_dropout, // probability to drop
const float softmax_scale,
const bool is_causal,
const int window_size_left,
int window_size_left,
int window_size_right,
const bool deterministic,
c10::optional<at::Generator> gen_,
@@ -738,6 +744,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
if (window_size_left >= seqlen_k) { window_size_left = -1; }
if (window_size_right >= seqlen_k) { window_size_right = -1; }
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
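
The backward entry points repeat the same clamp for consistency: mha_bwd and mha_varlen_bwd reapply the attention mask when recomputing the probabilities, so forward and backward must normalize the window identically or they would disagree about which keys were visible.
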
@@ -912,7 +921,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
const int window_size_left,
int window_size_left,
int window_size_right,
const bool deterministic,
c10::optional<at::Generator> gen_,
@@ -979,6 +988,9 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
CHECK_SHAPE(q, total_q, num_heads, head_size);
CHECK_SHAPE(k, total_k, num_heads_k, head_size);
CHECK_SHAPE(v, total_k, num_heads_k, head_size);
@@ -1160,7 +1172,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
const float softmax_scale,
bool is_causal,
const int window_size_left,
int window_size_left,
int window_size_right,
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
int num_splits
@@ -1216,6 +1228,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
num_heads = num_heads_k;
}
if (window_size_left >= seqlen_k) { window_size_left = -1; }
if (window_size_right >= seqlen_k) { window_size_right = -1; }
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
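
The kvcache path is where the clamp pays off most often in practice: during decoding, seqlen_k is the current cache length, so early in generation a model's sliding window frequently covers the whole cache and the local kernel buys nothing. A hedged example with made-up numbers (-1 is the library's no-window sentinel):

// Illustrative decode step: the cache holds 1024 tokens and the model uses
// a 4096-token sliding window. The window cannot yet exclude any cached
// key, so the clamp turns dispatch into plain non-local attention.
int window_size_left = 4096, window_size_right = 0;
const int seqlen_k = 1024;  // current KV-cache length
if (window_size_left >= seqlen_k) { window_size_left = -1; }   // fires: -> -1
if (window_size_right >= seqlen_k) { window_size_right = -1; } // 0 stays 0

Once the cache grows past the window length, the condition stops firing and the local kernel takes over again.
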