/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#include "flash_common.hpp"

std::vector<at::Tensor>
mha_fwd(at::Tensor &q,                            // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,                      // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,                      // batch_size x seqlen_k x num_heads_k x head_size
        c10::optional<at::Tensor> &out_,          // batch_size x seqlen_q x num_heads x head_size
        c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
        const float p_dropout,
        const float softmax_scale,
        bool is_causal,
        int window_size_left,
        int window_size_right,
        const float softcap,
        const bool return_softmax,
        c10::optional<at::Generator> gen_);

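// Illustrative sketch only, not part of the original file: it shows how the
// mha_fwd declaration above might be driven directly from C++. The function
// name, tensor shapes, and dtype chosen here are assumptions; the supported
// entry point is normally the Python binding registered at the bottom of this
// file.
static std::vector<at::Tensor> example_mha_fwd_call()
{
    const int64_t batch_size = 2, seqlen = 128, num_heads = 8, head_size = 64;
    auto opts = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);

    at::Tensor q = at::randn({batch_size, seqlen, num_heads, head_size}, opts);
    at::Tensor k = at::randn({batch_size, seqlen, num_heads, head_size}, opts);
    at::Tensor v = at::randn({batch_size, seqlen, num_heads, head_size}, opts);

    c10::optional<at::Tensor> out_;          // let the kernel allocate the output
    c10::optional<at::Tensor> alibi_slopes_; // no ALiBi slopes
    c10::optional<at::Generator> gen_;       // default RNG (only relevant when p_dropout > 0)

    const float softmax_scale = 0.125f;      // 1 / sqrt(head_size)
    return mha_fwd(q, k, v, out_, alibi_slopes_,
                   /*p_dropout=*/0.0f, softmax_scale,
                   /*is_causal=*/true,
                   /*window_size_left=*/-1, /*window_size_right=*/-1,
                   /*softcap=*/0.0f,
                   /*return_softmax=*/false, gen_);
}
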
std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q,                               // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,                         // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
               const at::Tensor &v,                         // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
               c10::optional<at::Tensor> &out_,             // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &cu_seqlens_q,              // b+1
               const at::Tensor &cu_seqlens_k,              // b+1
               c10::optional<at::Tensor> &seqused_k,        // b. If given, only this many elements of each batch element's keys are used.
               c10::optional<const at::Tensor> &leftpad_k_, // batch_size
               c10::optional<at::Tensor> &block_table_,     // batch_size x max_num_blocks_per_seq
               c10::optional<at::Tensor> &alibi_slopes_,    // num_heads or b x num_heads
               int max_seqlen_q,
               const int max_seqlen_k,
               const float p_dropout,
               const float softmax_scale,
               const bool zero_tensors,
               bool is_causal,
               int window_size_left,
               int window_size_right,
               const float softcap,
               const bool return_softmax,
               c10::optional<at::Generator> gen_);

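// Clarifying note (an addition, not in the original header comments): for the
// varlen entry points the batch is packed along the token dimension, and
// cu_seqlens_q / cu_seqlens_k hold the exclusive prefix sums of the per-sequence
// lengths. For example, with per-sequence query lengths {3, 5, 2}:
//   cu_seqlens_q = [0, 3, 8, 10]   // int32 tensor of length b + 1 = 4
//   total_q      = 10              // first dimension of the packed q tensor
// max_seqlen_q / max_seqlen_k are the largest individual lengths (here, 5).
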
std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num_heads x head_size_og
        const at::Tensor &q,                      // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,                      // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,                      // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &out,                    // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &softmax_lse,            // b x h x seqlen_q
        c10::optional<at::Tensor> &dq_,           // batch_size x seqlen_q x num_heads x head_size
        c10::optional<at::Tensor> &dk_,           // batch_size x seqlen_k x num_heads_k x head_size
        c10::optional<at::Tensor> &dv_,           // batch_size x seqlen_k x num_heads_k x head_size
        c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
        const float p_dropout,                    // probability to drop
        const float softmax_scale,
        const bool is_causal,
        int window_size_left,
        int window_size_right,
        const float softcap,
        const bool deterministic,
        c10::optional<at::Generator> gen_,
        c10::optional<at::Tensor> &rng_state);

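// Note on dropout state (an added remark mirroring upstream FlashAttention
// behaviour rather than anything stated in this header): when p_dropout > 0,
// the forward pass draws a seed/offset pair from gen_ and reports it via
// rng_state, and the same rng_state is passed back into the backward pass so
// the identical dropout mask can be regenerated instead of stored in memory.
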
std::vector<at::Tensor>
mha_varlen_bwd(const at::Tensor &dout,                   // total_q x num_heads x head_size
               const at::Tensor &q,                      // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,                      // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &v,                      // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &out,                    // total_q x num_heads x head_size
               const at::Tensor &softmax_lse,            // b x h x s softmax logsumexp
               c10::optional<at::Tensor> &dq_,           // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               c10::optional<at::Tensor> &dk_,           // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               c10::optional<at::Tensor> &dv_,           // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &cu_seqlens_q,           // b+1
               const at::Tensor &cu_seqlens_k,           // b+1
               c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
               const int max_seqlen_q,
               const int max_seqlen_k,                   // max sequence length to choose the kernel
               const float p_dropout,                    // probability to drop
               const float softmax_scale,
               const bool zero_tensors,
               const bool is_causal,
               int window_size_left,
               int window_size_right,
               const float softcap,
               const bool deterministic,
               c10::optional<at::Generator> gen_,
               c10::optional<at::Tensor> &rng_state);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.doc() = "FlashAttention";
    m.def("fwd", &mha_fwd, "Forward pass");
    m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
    m.def("bwd", &mha_bwd, "Backward pass");
    m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
}
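
// Usage note (an added remark; the extension module's import name and the
// composition of the returned list depend on how the package is built): once
// compiled as a Torch extension, the functions above are reached from Python
// through these bindings, e.g. module.fwd(q, k, v, ...) with positional
// arguments in declaration order, and each call returns the
// std::vector<at::Tensor> as a Python list of tensors.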