Support AMD ROCm on FlashAttention 2 (#1010)
* Support ck in fmha
* Add ck submodule
* Do not return lse if return_softmax == false
* Use receipt to speed up ck compile time
* Integrate new version of ck_tile
* Support dropout for mha_fwd()
* Add dropout to mha_varlen_fwd()
* Update ck to develop
* Extract padding function for dropout randval
* Extract randval transformation function
* Sync the code structure and coding style with FA
* Remove this line; the C++ API will handle it. Sync with test_flash_attn.py
* Fix compile error
* Add mha_bwd
* Generate dropout seed and offset from user generator
* Update CK
* Add mha_varlen_bwd
* Use the same Python interpreter that builds flash-attn to generate the ck kernels
* Fix bug in group-mode fwd when returning softmax lse
* Increase the test tolerance
* Add test_flash_attn_output() and test_flash_attn_varlen_output()
* Always fill softmax_lse
* Remove duplicate benchmark script, since we already implement mha_bwd
* Refine getting values from the tuple
* Use default parameter for stream_config
* Unblock all platforms
* Add comment
* Refine the test code
* Refine naming
* Add unpack to namespace
* Do not hardcode the warp size 64
* Add more targets
* Add README
* Optimize mha_fwd if seqlen_q == 1
* Support get_wheel_url for ROCm
* Detect the ROCm environment via PyTorch's IS_HIP_EXTENSION
* Update to latest ck
* Add necessary compile flag
* Sync the API with upstream FA

---------

Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: Yichen Yan <wenji.yyc@alibaba-inc.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: Yichen Yan <oraluben@outlook.com>
parent dfe1a59e4b · commit d8f104e97a

3  .gitmodules (vendored)
@@ -1,3 +1,6 @@
[submodule "csrc/cutlass"]
	path = csrc/cutlass
	url = https://github.com/NVIDIA/cutlass.git
[submodule "csrc/composable_kernel"]
	path = csrc/composable_kernel
	url = https://github.com/ROCm/composable_kernel.git
27  README.md
@@ -434,6 +434,33 @@ This new release of FlashAttention-2 has been tested on several GPT-style
models, mostly on A100 GPUs.

If you encounter bugs, please open a GitHub Issue!

## AMD GPU/ROCm Support

The ROCm version uses [composable_kernel](https://github.com/ROCm/composable_kernel) as its backend, which provides the implementation of FlashAttention-2.

## Installation and features

Requirements:
- ROCm 6.0+
- PyTorch 1.12.1+

We recommend the [PyTorch](https://hub.docker.com/r/rocm/pytorch) container from ROCm, which has all the required tools to install FlashAttention.

To compile from source:
```sh
python setup.py install
```

FlashAttention-2 on ROCm currently supports:
1. MI200 or MI300 GPUs.
2. Datatypes fp16 and bf16.
3. Forward pass head dimensions up to 256; backward pass head dimensions up to 128.

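Once built, the ROCm backend is used through the same Python interface as the CUDA build. The snippet below is a minimal usage sketch, not part of this diff: `flash_attn_func` is the upstream FlashAttention Python API, and the shapes, dtype, and device string are illustrative (on ROCm, PyTorch exposes HIP devices under the `"cuda"` device string).

```python
import torch
from flash_attn import flash_attn_func

# q, k, v: (batch, seqlen, nheads, headdim), fp16 or bf16, on the GPU
q = torch.randn(2, 1024, 16, 128, dtype=torch.float16, device="cuda")
k = torch.randn(2, 1024, 16, 128, dtype=torch.float16, device="cuda")
v = torch.randn(2, 1024, 16, 128, dtype=torch.float16, device="cuda")

out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
print(out.shape)  # torch.Size([2, 1024, 16, 128])
```
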
## Tests

To run the tests:
```sh
pytest tests/test_flash_attn_ck.py
```

## Citation

If you use this codebase, or otherwise found our work valuable, please cite:

1  csrc/composable_kernel (Submodule)
@@ -0,0 +1 @@
Subproject commit 8182976c37433808b5e3a27a6536d1b74b0c23a1
99  csrc/flash_attn_ck/flash_api.cpp (Normal file)
@@ -0,0 +1,99 @@
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#include "flash_common.hpp"

std::vector<at::Tensor>
mha_fwd(at::Tensor &q,
        const at::Tensor &k,
        const at::Tensor &v,
        c10::optional<at::Tensor> &out_,
        c10::optional<at::Tensor> &alibi_slopes_,
        const float p_dropout,
        const float softmax_scale,
        bool is_causal,
        int window_size_left,
        int window_size_right,
        const float softcap,
        const bool return_softmax,
        c10::optional<at::Generator> gen_);

std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q,                            // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,                      // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
               const at::Tensor &v,                      // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
               c10::optional<at::Tensor> &out_,          // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &cu_seqlens_q,           // b+1
               const at::Tensor &cu_seqlens_k,           // b+1
               c10::optional<at::Tensor> &seqused_k,     // b. If given, only this many elements of each batch element's keys are used.
               c10::optional<const at::Tensor> &leftpad_k_, // batch_size
               c10::optional<at::Tensor> &block_table_,  // batch_size x max_num_blocks_per_seq
               c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
               int max_seqlen_q,
               const int max_seqlen_k,
               const float p_dropout,
               const float softmax_scale,
               const bool zero_tensors,
               bool is_causal,
               int window_size_left,
               int window_size_right,
               const float softcap,
               const bool return_softmax,
               c10::optional<at::Generator> gen_);

std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num_heads, x head_size_og
        const at::Tensor &q,                      // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,                      // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,                      // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &out,                    // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &softmax_lse,            // b x h x seqlen_q
        c10::optional<at::Tensor> &dq_,           // batch_size x seqlen_q x num_heads x head_size
        c10::optional<at::Tensor> &dk_,           // batch_size x seqlen_k x num_heads_k x head_size
        c10::optional<at::Tensor> &dv_,           // batch_size x seqlen_k x num_heads_k x head_size
        c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
        const float p_dropout,                    // probability to drop
        const float softmax_scale,
        const bool is_causal,
        int window_size_left,
        int window_size_right,
        const float softcap,
        const bool deterministic,
        c10::optional<at::Generator> gen_,
        c10::optional<at::Tensor> &rng_state);

std::vector<at::Tensor>
mha_varlen_bwd(const at::Tensor &dout,                   // total_q x num_heads x head_size
               const at::Tensor &q,                      // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,                      // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &v,                      // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &out,                    // total_q x num_heads x head_size
               const at::Tensor &softmax_lse,            // b x h x s softmax logsumexp
               c10::optional<at::Tensor> &dq_,           // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               c10::optional<at::Tensor> &dk_,           // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               c10::optional<at::Tensor> &dv_,           // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &cu_seqlens_q,           // b+1
               const at::Tensor &cu_seqlens_k,           // b+1
               c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
               const int max_seqlen_q,
               const int max_seqlen_k,                   // max sequence length to choose the kernel
               const float p_dropout,                    // probability to drop
               const float softmax_scale,
               const bool zero_tensors,
               const bool is_causal,
               int window_size_left,
               int window_size_right,
               const float softcap,
               const bool deterministic,
               c10::optional<at::Generator> gen_,
               c10::optional<at::Tensor> &rng_state);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.doc() = "FlashAttention";
    m.def("fwd", &mha_fwd, "Forward pass");
    m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
    m.def("bwd", &mha_bwd, "Backward pass");
    m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
}
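For orientation, the `PYBIND11_MODULE` block above is the surface the Python package calls into. The sketch below (not part of this diff) shows how the raw `fwd` binding lines up with the `mha_fwd` declaration: the argument order and the returned tensor list are taken from this diff, while the module name `flash_attn_2_cuda` is an assumption based on the upstream Python wrapper. Calling the binding directly is for illustration only; normal code should go through the `flash_attn` Python API.

```python
import torch
import flash_attn_2_cuda as flash_api  # module name assumed; see setup.py

q = torch.randn(2, 1024, 16, 128, dtype=torch.float16, device="cuda")
k = torch.randn(2, 1024, 16, 128, dtype=torch.float16, device="cuda")
v = torch.randn(2, 1024, 16, 128, dtype=torch.float16, device="cuda")

# fwd(q, k, v, out_, alibi_slopes_, p_dropout, softmax_scale, is_causal,
#     window_size_left, window_size_right, softcap, return_softmax, gen_)
ret = flash_api.fwd(q, k, v, None, None, 0.0, q.shape[-1] ** -0.5, True,
                    -1, -1, 0.0, False, None)
out, softmax_lse = ret[0], ret[5]  # indices per mha_fwd.cpp's return statement below
```
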
38  csrc/flash_attn_ck/flash_common.hpp (Normal file)
@@ -0,0 +1,38 @@
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif

#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

namespace flash {
// Copy from PyTorch
// https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17
static std::tuple<uint64_t, uint64_t> unpack(at::PhiloxCudaState arg) {
    if (arg.captured_) {
        // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
        // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
        // For most threads' reads it will hit in cache, so it shouldn't hurt performance.
        return std::make_tuple(static_cast<uint64_t>(*arg.seed_.ptr), static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
    } else {
        return std::make_tuple(arg.seed_.val, arg.offset_.val);
    }
}

} // namespace flash
379  csrc/flash_attn_ck/mha_bwd.cpp (Normal file)
@@ -0,0 +1,379 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#include "flash_common.hpp"
|
||||
|
||||
#include "fmha_bwd.hpp"
|
||||
#include "mask.hpp"
|
||||
|
||||
fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
|
||||
std::string dtype,
|
||||
int head_size,
|
||||
bool has_dropout,
|
||||
bool enable_alibi)
|
||||
{
|
||||
return fmha_bwd_traits{head_size,
|
||||
head_size,
|
||||
dtype,
|
||||
false, // is_group_mode
|
||||
mask.type,
|
||||
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
|
||||
false, // has_dbias
|
||||
has_dropout};
|
||||
}
|
||||
|
||||
fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
|
||||
// sizes
|
||||
const int b,
|
||||
const int seqlen_q,
|
||||
const int seqlen_k,
|
||||
const int h,
|
||||
const int h_k,
|
||||
const int hdim,
|
||||
// device pointers
|
||||
const at::Tensor q,
|
||||
const at::Tensor k,
|
||||
const at::Tensor v,
|
||||
c10::optional<at::Tensor> &alibi_slopes_,
|
||||
const at::Tensor out,
|
||||
const at::Tensor softmax_lse,
|
||||
const at::Tensor dout,
|
||||
at::Tensor d,
|
||||
at::Tensor dq,
|
||||
at::Tensor dk,
|
||||
at::Tensor dv,
|
||||
float softmax_scale,
|
||||
float p_dropout,
|
||||
uint64_t drop_seed,
|
||||
uint64_t drop_offset)
|
||||
{
|
||||
// q: (batch_size, seqlen_q, nheads, hdim)
|
||||
// k: (batch_size, seqlen_k, nheads_k, hdim)
|
||||
// v: (batch_size, seqlen_k, nheads_k, hdim)
|
||||
// o: (batch_size, seqlen_q, nheads, hdim)
|
||||
// dq: (batch_size, seqlen_q, nheads, hdim)
|
||||
// dk_expanded: (batch_size, seqlen_k, nheads, hdim)
|
||||
// dv_expanded: (batch_size, seqlen_k, nheads, hdim)
|
||||
// do: (batch_size, seqlen_q, nheads, hdim)
|
||||
|
||||
// alibi_slopes:(batch_size, nheads) or (nhead)
|
||||
// lse: (batch_size, nheads, seqlen_q)
|
||||
// d: (batch_size, nheads, seqlen_q)
|
||||
|
||||
ck_tile::index_t stride_q = q.stride(1);
|
||||
ck_tile::index_t stride_k = k.stride(1);
|
||||
ck_tile::index_t stride_v = v.stride(1);
|
||||
ck_tile::index_t stride_o = out.stride(1);
|
||||
ck_tile::index_t stride_do = dout.stride(1);
|
||||
ck_tile::index_t stride_dk = dk.stride(1);
|
||||
ck_tile::index_t stride_dv = dv.stride(1);
|
||||
|
||||
ck_tile::index_t nhead_stride_q = q.stride(2);
|
||||
ck_tile::index_t nhead_stride_k = k.stride(2);
|
||||
ck_tile::index_t nhead_stride_v = v.stride(2);
|
||||
ck_tile::index_t nhead_stride_o = out.stride(2);
|
||||
ck_tile::index_t nhead_stride_do = dout.stride(2);
|
||||
ck_tile::index_t nhead_stride_lse = softmax_lse.stride(1);
|
||||
|
||||
ck_tile::index_t batch_stride_q = q.stride(0);
|
||||
ck_tile::index_t batch_stride_k = k.stride(0);
|
||||
ck_tile::index_t batch_stride_v = v.stride(0);
|
||||
ck_tile::index_t batch_stride_o = out.stride(0);
|
||||
ck_tile::index_t batch_stride_do = dout.stride(0);
|
||||
ck_tile::index_t batch_stride_lse = softmax_lse.stride(0);
|
||||
ck_tile::index_t batch_stride_dk = dk.stride(0);
|
||||
ck_tile::index_t batch_stride_dv = dv.stride(0);
|
||||
|
||||
float p_undrop = 1.0 - p_dropout;
|
||||
|
||||
void *alibi_slopes_ptr = nullptr;
|
||||
ck_tile::index_t stride_alibi_slopes = 0;
|
||||
|
||||
if (alibi_slopes_.has_value()) {
|
||||
auto alibi_slopes = alibi_slopes_.value();
|
||||
CHECK_DEVICE(alibi_slopes);
|
||||
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
|
||||
alibi_slopes_ptr = alibi_slopes.data_ptr();
|
||||
stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
|
||||
}
|
||||
|
||||
return fmha_bwd_args{q.data_ptr(),
|
||||
k.data_ptr(),
|
||||
v.data_ptr(),
|
||||
alibi_slopes_ptr, // bias
|
||||
out.data_ptr(),
|
||||
softmax_lse.data_ptr(),
|
||||
dout.data_ptr(),
|
||||
d.data_ptr(),
|
||||
nullptr, // rand_val
|
||||
dq.data_ptr(),
|
||||
dk.data_ptr(),
|
||||
dv.data_ptr(),
|
||||
nullptr, // dbias
|
||||
nullptr, // seqstart_q
|
||||
nullptr, // seqstart_k
|
||||
nullptr, // seqlen_k_ptr
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
b,
|
||||
seqlen_q, // max_seqlen_q
|
||||
seqlen_k, // max_seqlen_k
|
||||
hdim, // hdim_q
|
||||
hdim, // hdim_v
|
||||
h, // nhead
|
||||
h_k, // nhead_k
|
||||
softmax_scale,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_alibi_slopes,
|
||||
stride_o,
|
||||
0, // stride_randval
|
||||
stride_do,
|
||||
stride_dk,
|
||||
stride_dv,
|
||||
0, // stride_dbias, FA without bias
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
0, // nhead_stride_bias, FA without bias
|
||||
nhead_stride_o,
|
||||
0, // nhead_stride_randval
|
||||
nhead_stride_do,
|
||||
nhead_stride_lse,
|
||||
0, // nhead_stride_dbias, FA without dbias
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
0 , // batch_stride_bias, FA without bias
|
||||
batch_stride_o,
|
||||
0, // batch_stride_randval
|
||||
batch_stride_do,
|
||||
batch_stride_lse,
|
||||
batch_stride_dk,
|
||||
batch_stride_dv,
|
||||
0 , // batch_stride_dbias, FA without dbias
|
||||
mask.left,
|
||||
mask.right,
|
||||
static_cast<ck_tile::index_t>(mask.type),
|
||||
p_dropout,
|
||||
p_undrop,
|
||||
false, // s_randval
|
||||
{drop_seed, drop_offset}};
|
||||
}
|
||||
|
||||
std::vector<at::Tensor>
|
||||
mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
|
||||
const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
|
||||
const at::Tensor &softmax_lse, // b x h x seqlen_q
|
||||
c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
|
||||
c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
|
||||
const float p_dropout, // probability to drop
|
||||
const float softmax_scale,
|
||||
const bool is_causal,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
const float /*softcap*/,
|
||||
const bool deterministic,
|
||||
c10::optional<at::Generator> gen_,
|
||||
c10::optional<at::Tensor> &rng_state)
|
||||
{
|
||||
#ifdef FLASHATTENTION_DISABLE_BACKWARD
|
||||
TORCH_CHECK(false, "This flash attention build does not support backward.");
|
||||
#endif
|
||||
if (is_causal) { window_size_right = 0; }
|
||||
|
||||
bool is_dropout = p_dropout > 0.0;
|
||||
auto stream = at::cuda::getCurrentHIPStream().stream();
|
||||
|
||||
auto q_dtype = q.dtype();
|
||||
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
|
||||
"FlashAttention only support fp16 and bf16 data type");
|
||||
|
||||
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
|
||||
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
|
||||
TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
|
||||
TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
|
||||
|
||||
std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
|
||||
|
||||
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
|
||||
CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
|
||||
|
||||
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
|
||||
|
||||
const auto sizes = q.sizes();
|
||||
|
||||
const int batch_size = sizes[0];
|
||||
const int seqlen_q = sizes[1];
|
||||
const int num_heads = sizes[2];
|
||||
const int head_size_og = dout.size(3); // unpadded hdim
|
||||
const int head_size_8x = sizes[3];
|
||||
const int seqlen_k = k.size(1);
|
||||
const int num_heads_k = k.size(2);
|
||||
TORCH_CHECK(batch_size > 0, "batch size must be positive");
|
||||
TORCH_CHECK(head_size_8x % 8 == 0, "head_size_8x should be a multiple of 8");
|
||||
TORCH_CHECK(head_size_8x <= 128, "CK FlashAttention backward only supports head dimension at most 128");
|
||||
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
TORCH_CHECK(head_size_8x == round_multiple(head_size_og, 8), "head_size_8x must be head_size_og rounded to a multiple of 8");
|
||||
|
||||
if (window_size_left >= seqlen_k) { window_size_left = -1; }
|
||||
if (window_size_right >= seqlen_k) { window_size_right = -1; }
|
||||
|
||||
mask_info mask;
|
||||
if (is_causal) {
|
||||
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
|
||||
mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // causal
|
||||
}
|
||||
else if (window_size_left == -1 && window_size_right == -1) {
|
||||
mask = mask_info::decode("0", seqlen_q, seqlen_k); // no mask
|
||||
}
|
||||
else {
|
||||
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
|
||||
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
|
||||
mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local
|
||||
}
|
||||
|
||||
// q, k, v, out have already been padded in mha_fwd
|
||||
// dq_, dk_, dv_ are also padded tensor
|
||||
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_8x);
|
||||
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_8x);
|
||||
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_8x);
|
||||
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_8x);
|
||||
CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);
|
||||
|
||||
at::Tensor dq, dk, dv;
|
||||
if (dq_.has_value()) {
|
||||
dq = dq_.value();
|
||||
TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
|
||||
CHECK_DEVICE(dq);
|
||||
TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
|
||||
CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size_8x);
|
||||
} else {
|
||||
dq = torch::empty_like(q);
|
||||
}
|
||||
if (dk_.has_value()) {
|
||||
dk = dk_.value();
|
||||
TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
|
||||
CHECK_DEVICE(dk);
|
||||
TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
|
||||
CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size_8x);
|
||||
} else {
|
||||
dk = torch::empty_like(k);
|
||||
}
|
||||
if (dv_.has_value()) {
|
||||
dv = dv_.value();
|
||||
TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
|
||||
CHECK_DEVICE(dv);
|
||||
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
|
||||
CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size_8x);
|
||||
} else {
|
||||
dv = torch::empty_like(v);
|
||||
}
|
||||
|
||||
at::Tensor dout_padded;
|
||||
if (head_size_og % 8 != 0) {
|
||||
dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
|
||||
} else {
|
||||
dout_padded = dout;
|
||||
}
|
||||
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
|
||||
auto opts = q.options();
|
||||
auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
||||
// TODO - CK does not support dq_accum
|
||||
|
||||
at::Tensor dk_expanded, dv_expanded;
|
||||
if (num_heads_k != num_heads) { // MQA / GQA
|
||||
dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size_8x}, opts);
|
||||
dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size_8x}, opts);
|
||||
} else {
|
||||
dk_expanded = dk;
|
||||
dv_expanded = dv;
|
||||
}
|
||||
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
||||
gen_, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
|
||||
uint64_t drop_seed = 1, drop_offset = 0;
|
||||
int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
|
||||
|
||||
if (rng_state.has_value()) {
|
||||
uint64_t* d = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
|
||||
drop_seed = d[0];
|
||||
drop_offset = d[1];
|
||||
} else if(is_dropout) {
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
auto philox_args = gen->philox_cuda_state(counter_offset);
|
||||
std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
|
||||
}
|
||||
|
||||
if (seqlen_q > 0) {
|
||||
ck_tile::stream_config stream_config{stream};
|
||||
dq.zero_(); // CK uses atomic operations on dq
|
||||
|
||||
auto traits =
|
||||
get_ck_fmha_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value());
|
||||
|
||||
auto args =
|
||||
get_ck_fmha_bwd_args(
|
||||
mask,
|
||||
batch_size,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
num_heads,
|
||||
num_heads_k,
|
||||
head_size_8x,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
alibi_slopes_,
|
||||
out,
|
||||
softmax_lse,
|
||||
dout_padded,
|
||||
softmax_d,
|
||||
dq,
|
||||
dk_expanded,
|
||||
dv_expanded,
|
||||
softmax_scale,
|
||||
p_dropout,
|
||||
drop_seed,
|
||||
drop_offset);
|
||||
|
||||
fmha_bwd(traits, args, stream_config);
|
||||
} else {
|
||||
// If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
|
||||
dk_expanded.zero_();
|
||||
dv_expanded.zero_();
|
||||
softmax_d.zero_();
|
||||
}
|
||||
|
||||
// For MQA/GQA we need to sum dK and dV across the groups
|
||||
if (num_heads_k != num_heads) {
|
||||
at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {3});
|
||||
at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {3});
|
||||
}
|
||||
if (head_size_og % 8 != 0) {
|
||||
dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
|
||||
dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
|
||||
dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
|
||||
}
|
||||
|
||||
return { dq, dk, dv, softmax_d };
|
||||
}
|
||||
348  csrc/flash_attn_ck/mha_fwd.cpp (Normal file)
@@ -0,0 +1,348 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#include "flash_common.hpp"
|
||||
|
||||
#include "fmha_fwd.hpp"
|
||||
#include "mask.hpp"
|
||||
|
||||
fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask,
|
||||
std::string dtype,
|
||||
int head_size,
|
||||
bool has_dropout,
|
||||
bool has_lse,
|
||||
bool enable_alibi)
|
||||
{
|
||||
return fmha_fwd_traits{head_size,
|
||||
head_size,
|
||||
dtype,
|
||||
false, // is_group_mode
|
||||
true, // is_v_rowmajor
|
||||
mask.type,
|
||||
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
|
||||
has_lse,
|
||||
has_dropout,
|
||||
false}; // do_fp8_static_quant
|
||||
}
|
||||
|
||||
fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
|
||||
bool has_dropout_randval,
|
||||
const mask_info &mask,
|
||||
// sizes
|
||||
const int b,
|
||||
const int seqlen_q,
|
||||
const int seqlen_k,
|
||||
const int h,
|
||||
const int h_k,
|
||||
const int d,
|
||||
// device pointers
|
||||
const at::Tensor q,
|
||||
const at::Tensor k,
|
||||
const at::Tensor v,
|
||||
c10::optional<at::Tensor> &alibi_slopes_,
|
||||
at::Tensor out,
|
||||
at::Tensor softmax_lse,
|
||||
at::Tensor dropout_randval,
|
||||
float softmax_scale,
|
||||
float p_dropout,
|
||||
uint64_t drop_seed,
|
||||
uint64_t drop_offset)
|
||||
{
|
||||
// q: (batch_size, seqlen_q, nheads, d)
|
||||
// k: (batch_size, seqlen_k, nheads_k, d)
|
||||
// v: (batch_size, seqlen_k, nheads_k, d)
|
||||
// o: (batch_size, seqlen_q, nheads, d)
|
||||
|
||||
// alibi_slopes:(batch_size, nheads) or (nhead)
|
||||
// lse: (batch_size, nheads, seqlen_q)
|
||||
// randval: (batch_size, nheads, seqlen_q, seqlen_k)
|
||||
|
||||
ck_tile::index_t stride_q = q.stride(1);
|
||||
ck_tile::index_t stride_k = k.stride(1);
|
||||
ck_tile::index_t stride_v = v.stride(1);
|
||||
ck_tile::index_t stride_o = out.stride(1);
|
||||
ck_tile::index_t stride_randval = has_dropout_randval ? dropout_randval.stride(2) : 0;
|
||||
|
||||
ck_tile::index_t nhead_stride_q = q.stride(2);
|
||||
ck_tile::index_t nhead_stride_k = k.stride(2);
|
||||
ck_tile::index_t nhead_stride_v = v.stride(2);
|
||||
ck_tile::index_t nhead_stride_o = out.stride(2);
|
||||
ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(1) : 0;
|
||||
ck_tile::index_t nhead_stride_randval = has_dropout_randval ? dropout_randval.stride(1) : 0;
|
||||
|
||||
ck_tile::index_t batch_stride_q = q.stride(0);
|
||||
ck_tile::index_t batch_stride_k = k.stride(0);
|
||||
ck_tile::index_t batch_stride_v = v.stride(0);
|
||||
ck_tile::index_t batch_stride_o = out.stride(0);
|
||||
|
||||
ck_tile::index_t batch_stride_lse = has_lse ? softmax_lse.stride(0) : 0;
|
||||
ck_tile::index_t batch_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0;
|
||||
|
||||
void *alibi_slopes_ptr = nullptr;
|
||||
ck_tile::index_t stride_alibi_slopes = 0;
|
||||
|
||||
if (alibi_slopes_.has_value()) {
|
||||
auto alibi_slopes = alibi_slopes_.value();
|
||||
CHECK_DEVICE(alibi_slopes);
|
||||
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
|
||||
alibi_slopes_ptr = alibi_slopes.data_ptr();
|
||||
stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
|
||||
}
|
||||
|
||||
return fmha_fwd_args{q.data_ptr(),
|
||||
k.data_ptr(),
|
||||
v.data_ptr(),
|
||||
alibi_slopes_ptr, // bias
|
||||
has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
|
||||
nullptr, // lse_acc
|
||||
nullptr, // o_acc
|
||||
has_lse ? softmax_lse.data_ptr() : nullptr,
|
||||
out.data_ptr(),
|
||||
nullptr, // seqstart_q
|
||||
nullptr, // seqstart_k
|
||||
nullptr,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
b,
|
||||
seqlen_q, // max_seqlen_q
|
||||
d, // hdim_q
|
||||
d, // hdim_v
|
||||
h, // nhead
|
||||
h_k, // nhead_k
|
||||
1, // num_splits
|
||||
softmax_scale, // scale_s
|
||||
1, // scale_p
|
||||
1, // scale_o
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_alibi_slopes,
|
||||
stride_randval,
|
||||
0, // stride_o_acc,
|
||||
stride_o,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
0, // nhead_stride_bias, FA without bias
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
0, // nhead_stride_lse_acc
|
||||
0, // nhead_stride_o_acc
|
||||
nhead_stride_o,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
0, // batch_stride_bias, FA without bias
|
||||
batch_stride_randval,
|
||||
batch_stride_lse,
|
||||
0, // batch_stride_lse_acc
|
||||
0, // batch_stride_o_acc
|
||||
batch_stride_o,
|
||||
0, // split_stride_lse_acc
|
||||
0, // split_stride_o_acc
|
||||
mask.left,
|
||||
mask.right,
|
||||
static_cast<ck_tile::index_t>(mask.type),
|
||||
p_dropout,
|
||||
has_dropout_randval,
|
||||
{drop_seed, drop_offset}};
|
||||
}
|
||||
|
||||
std::vector<at::Tensor>
|
||||
mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
|
||||
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
|
||||
const float p_dropout,
|
||||
const float softmax_scale,
|
||||
bool is_causal,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
const float /*softcap*/,
|
||||
const bool return_dropout_randval,
|
||||
c10::optional<at::Generator> gen_)
|
||||
{
|
||||
auto q_dtype = q.dtype();
|
||||
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
|
||||
"FlashAttention only support fp16 and bf16 data type");
|
||||
|
||||
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
|
||||
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
|
||||
|
||||
std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
|
||||
|
||||
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
|
||||
|
||||
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
|
||||
const auto sizes = q.sizes();
|
||||
|
||||
const int batch_size = sizes[0];
|
||||
int seqlen_q = sizes[1];
|
||||
int num_heads = sizes[2];
|
||||
const int head_size_og = sizes[3];
|
||||
const int seqlen_k = k.size(1);
|
||||
const int num_heads_k = k.size(2);
|
||||
TORCH_CHECK(batch_size > 0, "batch size must be positive");
|
||||
TORCH_CHECK(head_size_og <= 256, "CK only supports head dimension at most 256");
|
||||
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
|
||||
if (window_size_left >= seqlen_k) { window_size_left = -1; }
|
||||
if (window_size_right >= seqlen_k) { window_size_right = -1; }
|
||||
|
||||
// causal=true is the same as causal=false in this case
|
||||
if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
|
||||
|
||||
mask_info mask;
|
||||
if (is_causal) {
|
||||
// Causal is the special case where window_size_right == 0 and window_size_left < 0.
|
||||
window_size_right = 0;
|
||||
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
|
||||
mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // causal
|
||||
}
|
||||
else if (window_size_left == -1 && window_size_right == -1) {
|
||||
mask = mask_info::decode("0", seqlen_q, seqlen_k); // no mask
|
||||
}
|
||||
else {
|
||||
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
|
||||
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
|
||||
mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local
|
||||
}
|
||||
|
||||
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
|
||||
// H/t Daniel Haziza
|
||||
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
|
||||
const int ngroups = num_heads / num_heads_k;
|
||||
if (seqlenq_ngroups_swapped) {
|
||||
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
|
||||
seqlen_q = ngroups;
|
||||
num_heads = num_heads_k;
|
||||
}
|
||||
|
||||
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
|
||||
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
|
||||
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
|
||||
|
||||
at::Tensor q_padded, k_padded, v_padded;
|
||||
if (head_size_og % 8 != 0) {
|
||||
q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
|
||||
k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
|
||||
v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
|
||||
}
|
||||
else {
|
||||
q_padded = q;
|
||||
k_padded = k;
|
||||
v_padded = v;
|
||||
}
|
||||
|
||||
at::Tensor out;
|
||||
if (out_.has_value()) {
|
||||
out = out_.value();
|
||||
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
|
||||
CHECK_DEVICE(out);
|
||||
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
|
||||
CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og);
|
||||
if (seqlenq_ngroups_swapped) {
|
||||
out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
|
||||
}
|
||||
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
|
||||
}
|
||||
else {
|
||||
out = torch::empty_like(q_padded);
|
||||
}
|
||||
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
const int head_size_8x = round_multiple(head_size_og, 8);
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
|
||||
auto opts = q.options();
|
||||
bool has_lse = true;
|
||||
bool has_dropout = p_dropout > 0.0f;
|
||||
|
||||
at::Tensor softmax_lse;
|
||||
// TODO - check gradient; only training requires lse
|
||||
softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(torch::kFloat32));
|
||||
|
||||
at::Tensor p;
|
||||
if (return_dropout_randval) {
|
||||
TORCH_CHECK(has_dropout, "return_dropout_randval require p_dropout > 0");
|
||||
p = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts.dtype(torch::kUInt8));
|
||||
}
|
||||
|
||||
uint64_t drop_seed = 1, drop_offset = 0;
|
||||
int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
|
||||
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
|
||||
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
|
||||
|
||||
if (p_dropout > 0.0) {
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
||||
gen_, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
auto philox_args = gen->philox_cuda_state(counter_offset);
|
||||
std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
|
||||
}
|
||||
|
||||
rng_state[0] = *(reinterpret_cast<int64_t*>(&drop_seed));
|
||||
rng_state[1] = *(reinterpret_cast<int64_t*>(&drop_offset));
|
||||
|
||||
if (seqlen_k > 0) {
|
||||
auto stream = at::cuda::getCurrentHIPStream().stream();
|
||||
ck_tile::stream_config stream_config{stream};
|
||||
|
||||
auto traits =
|
||||
get_ck_fmha_fwd_traits(mask, q_dtype_str, head_size_8x, has_dropout, has_lse, alibi_slopes_.has_value());
|
||||
|
||||
auto args =
|
||||
get_ck_fmha_fwd_args(
|
||||
has_lse,
|
||||
return_dropout_randval,
|
||||
mask,
|
||||
batch_size,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
num_heads,
|
||||
num_heads_k,
|
||||
head_size_8x,
|
||||
q_padded,
|
||||
k_padded,
|
||||
v_padded,
|
||||
alibi_slopes_,
|
||||
out,
|
||||
softmax_lse,
|
||||
p,
|
||||
softmax_scale,
|
||||
p_dropout,
|
||||
drop_seed,
|
||||
drop_offset);
|
||||
|
||||
fmha_fwd(traits, args, stream_config);
|
||||
}
|
||||
else {
|
||||
// If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
|
||||
out.zero_();
|
||||
softmax_lse.fill_(std::numeric_limits<float>::infinity());
|
||||
}
|
||||
|
||||
at::Tensor out_padded = out;
|
||||
if (head_size_og % 8 != 0) {
|
||||
out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
|
||||
if (out_.has_value()) { out_.value().copy_(out); }
|
||||
}
|
||||
|
||||
if (seqlenq_ngroups_swapped) {
|
||||
out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
|
||||
out_padded = out_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
|
||||
q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
|
||||
softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
|
||||
}
|
||||
return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
|
||||
}
|
||||
406  csrc/flash_attn_ck/mha_varlen_bwd.cpp (Normal file)
@@ -0,0 +1,406 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#include "flash_common.hpp"
|
||||
|
||||
#include "fmha_bwd.hpp"
|
||||
#include "mask.hpp"
|
||||
|
||||
fmha_bwd_traits get_ck_fmha_varlen_bwd_traits(const mask_info &mask,
|
||||
std::string dtype,
|
||||
int head_size,
|
||||
bool has_dropout,
|
||||
bool enable_alibi)
|
||||
{
|
||||
return fmha_bwd_traits{head_size,
|
||||
head_size,
|
||||
dtype,
|
||||
true, // is_group_mode
|
||||
mask.type,
|
||||
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
|
||||
false, // has_dbias
|
||||
has_dropout};
|
||||
}
|
||||
|
||||
fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
|
||||
// sizes
|
||||
const int b,
|
||||
const int max_seqlen_q,
|
||||
const int max_seqlen_k,
|
||||
const int h,
|
||||
const int h_k,
|
||||
const int hdim,
|
||||
// device pointers
|
||||
const at::Tensor q,
|
||||
const at::Tensor k,
|
||||
const at::Tensor v,
|
||||
const at::Tensor seqlens_q,
|
||||
const at::Tensor seqlens_k,
|
||||
c10::optional<at::Tensor> &alibi_slopes_,
|
||||
const at::Tensor out,
|
||||
const at::Tensor softmax_lse,
|
||||
const at::Tensor dout,
|
||||
at::Tensor d,
|
||||
at::Tensor dq,
|
||||
at::Tensor dk,
|
||||
at::Tensor dv,
|
||||
float softmax_scale,
|
||||
float p_dropout,
|
||||
uint64_t drop_seed,
|
||||
uint64_t drop_offset)
|
||||
{
|
||||
// q: (total_q, nheads, hdim)
|
||||
// k: (total_k, nheads_k, hdim)
|
||||
// v: (total_k, nheads_k, hdim)
|
||||
// o: (total_q, nheads, hdim)
|
||||
// dq: (total_q, nheads, hdim)
|
||||
// dk_expanded: (total_k, nheads, hdim)
|
||||
// dv_expanded: (total_k, nheads, hdim)
|
||||
// do: (total_q, nheads, hdim)
|
||||
|
||||
// alibi_slopes:(batch_size, nheads) or (nhead)
|
||||
// lse: (batch_size, nheads, max_seqlen_q)
|
||||
// d: (batch_size, nheads, max_seqlen_q)
|
||||
|
||||
ck_tile::index_t total_q = q.size(0);
|
||||
ck_tile::index_t total_k = k.size(0);
|
||||
|
||||
ck_tile::index_t stride_q = q.stride(0);
|
||||
ck_tile::index_t stride_k = k.stride(0);
|
||||
ck_tile::index_t stride_v = v.stride(0);
|
||||
ck_tile::index_t stride_o = out.stride(0);
|
||||
ck_tile::index_t stride_do = dout.stride(0);
|
||||
ck_tile::index_t stride_dk = dk.stride(0);
|
||||
ck_tile::index_t stride_dv = dv.stride(0);
|
||||
|
||||
ck_tile::index_t nhead_stride_q = q.stride(1);
|
||||
ck_tile::index_t nhead_stride_k = k.stride(1);
|
||||
ck_tile::index_t nhead_stride_v = v.stride(1);
|
||||
ck_tile::index_t nhead_stride_o = out.stride(1);
|
||||
ck_tile::index_t nhead_stride_do = dout.stride(1);
|
||||
ck_tile::index_t nhead_stride_lse = softmax_lse.stride(1);
|
||||
|
||||
ck_tile::index_t batch_stride_q = 0;
|
||||
ck_tile::index_t batch_stride_k = 0;
|
||||
ck_tile::index_t batch_stride_v = 0;
|
||||
ck_tile::index_t batch_stride_o = 0;
|
||||
ck_tile::index_t batch_stride_do = 0;
|
||||
ck_tile::index_t batch_stride_lse = softmax_lse.stride(0);
|
||||
ck_tile::index_t batch_stride_dk = 0;
|
||||
ck_tile::index_t batch_stride_dv = 0;
|
||||
|
||||
float p_undrop = 1.0 - p_dropout;
|
||||
|
||||
void *alibi_slopes_ptr = nullptr;
|
||||
ck_tile::index_t stride_alibi_slopes = 0;
|
||||
|
||||
if (alibi_slopes_.has_value()) {
|
||||
auto alibi_slopes = alibi_slopes_.value();
|
||||
CHECK_DEVICE(alibi_slopes);
|
||||
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
|
||||
alibi_slopes_ptr = alibi_slopes.data_ptr();
|
||||
stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
|
||||
}
|
||||
|
||||
return fmha_bwd_args{q.data_ptr(),
|
||||
k.data_ptr(),
|
||||
v.data_ptr(),
|
||||
alibi_slopes_ptr, // bias
|
||||
out.data_ptr(),
|
||||
softmax_lse.data_ptr(),
|
||||
dout.data_ptr(),
|
||||
d.data_ptr(),
|
||||
nullptr, // rand_val
|
||||
dq.data_ptr(),
|
||||
dk.data_ptr(),
|
||||
dv.data_ptr(),
|
||||
nullptr, // dbias
|
||||
seqlens_q.data_ptr(), // seqstart_q
|
||||
seqlens_k.data_ptr(), // seqstart_k
|
||||
nullptr, // seqlen_k_ptr
|
||||
total_q,
|
||||
total_k,
|
||||
b,
|
||||
max_seqlen_q, // max_seqlen_q
|
||||
max_seqlen_k, // max_seqlen_k
|
||||
hdim, // hdim_q
|
||||
hdim, // hdim_v
|
||||
h, // nhead
|
||||
h_k, // nhead_k
|
||||
softmax_scale,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_alibi_slopes,
|
||||
stride_o,
|
||||
0, // stride_randval
|
||||
stride_do,
|
||||
stride_dk,
|
||||
stride_dv,
|
||||
0, // stride_dbias, FA without bias
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
0, // nhead_stride_bias, FA without bias
|
||||
nhead_stride_o,
|
||||
0, // nhead_stride_randval
|
||||
nhead_stride_do,
|
||||
nhead_stride_lse,
|
||||
0, // nhead_stride_dbias, FA without dbias
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
0 , // batch_stride_bias, FA without bias
|
||||
batch_stride_o,
|
||||
0, // batch_stride_randval
|
||||
batch_stride_do,
|
||||
batch_stride_lse,
|
||||
batch_stride_dk,
|
||||
batch_stride_dv,
|
||||
0 , // batch_stride_dbias, FA without dbias
|
||||
mask.left,
|
||||
mask.right,
|
||||
static_cast<ck_tile::index_t>(mask.type),
|
||||
p_dropout,
|
||||
p_undrop,
|
||||
false, // s_randval
|
||||
{drop_seed, drop_offset}};
|
||||
}
|
||||
|
||||
std::vector<at::Tensor>
|
||||
mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads x head_size
|
||||
const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &out, // total_q x num_heads x head_size
|
||||
const at::Tensor &softmax_lse, // b x h x s softmax logsumexp
|
||||
c10::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
c10::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &cu_seqlens_q, // b+1
|
||||
const at::Tensor &cu_seqlens_k, // b+1
|
||||
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
|
||||
const int max_seqlen_q,
|
||||
const int max_seqlen_k, // max sequence length to choose the kernel
|
||||
const float p_dropout, // probability to drop
|
||||
const float softmax_scale,
|
||||
const bool zero_tensors,
|
||||
const bool is_causal,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
const float /*softcap*/,
|
||||
const bool deterministic,
|
||||
c10::optional<at::Generator> gen_,
|
||||
c10::optional<at::Tensor> &rng_state)
|
||||
{
|
||||
#ifdef FLASHATTENTION_DISABLE_BACKWARD
|
||||
TORCH_CHECK(false, "This flash attention build does not support backward.");
|
||||
#endif
|
||||
if (is_causal) { window_size_right = 0; }
|
||||
|
||||
bool is_dropout = p_dropout > 0.0;
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
auto q_dtype = q.dtype();
|
||||
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
|
||||
"FlashAttention only support fp16 and bf16 data type");
|
||||
|
||||
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
|
||||
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
|
||||
TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
|
||||
TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
|
||||
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
|
||||
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
|
||||
|
||||
std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
|
||||
|
||||
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
|
||||
CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
|
||||
CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);
|
||||
|
||||
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
|
||||
CHECK_CONTIGUOUS(cu_seqlens_q);
|
||||
CHECK_CONTIGUOUS(cu_seqlens_k);
|
||||
|
||||
const auto sizes = q.sizes();
|
||||
|
||||
const int total_q = sizes[0];
|
||||
const int batch_size = cu_seqlens_q.numel() - 1;
|
||||
const int num_heads = sizes[1];
|
||||
const int head_size_og = dout.size(2);
|
||||
const int head_size_8x = sizes[2];
|
||||
const int total_k = k.size(0);
|
||||
const int num_heads_k = k.size(1);
|
||||
TORCH_CHECK(batch_size > 0, "batch size must be positive");
|
||||
TORCH_CHECK(head_size_8x % 8 == 0, "head_size should be a multiple of 8");
|
||||
TORCH_CHECK(head_size_8x <= 128, "CK FlashAttention backward only supports head dimension at most 128");
|
||||
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
TORCH_CHECK(head_size_8x == round_multiple(head_size_og, 8), "head_size_8x must be head_size_og rounded to a multiple of 8");
|
||||
|
||||
if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
|
||||
if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
|
||||
|
||||
mask_info mask;
|
||||
if (is_causal) {
|
||||
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
|
||||
mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // causal
|
||||
}
|
||||
else if (window_size_left == -1 && window_size_right == -1) {
|
||||
mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask
|
||||
}
|
||||
else {
|
||||
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
|
||||
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
|
||||
mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local
|
||||
}
|
||||
|
||||
// q, k, v, out have already been padded in mha_fwd
|
||||
// dq_, dk_, dv_ are also padded tensor
|
||||
CHECK_SHAPE(q, total_q, num_heads, head_size_8x);
|
||||
CHECK_SHAPE(k, total_k, num_heads_k, head_size_8x);
|
||||
CHECK_SHAPE(v, total_k, num_heads_k, head_size_8x);
|
||||
CHECK_SHAPE(out, total_q, num_heads, head_size_8x);
|
||||
CHECK_SHAPE(dout, total_q, num_heads, head_size_og);
|
||||
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
|
||||
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
|
||||
|
||||
at::Tensor dq, dk, dv;
|
||||
if (dq_.has_value()) {
|
||||
dq = dq_.value();
|
||||
TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
|
||||
CHECK_DEVICE(dq);
|
||||
TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
|
||||
CHECK_SHAPE(dq, total_q, num_heads, head_size_8x);
|
||||
} else {
|
||||
dq = torch::empty_like(q);
|
||||
}
|
||||
if (dk_.has_value()) {
|
||||
dk = dk_.value();
|
||||
TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
|
||||
CHECK_DEVICE(dk);
|
||||
TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
|
||||
CHECK_SHAPE(dk, total_k, num_heads_k, head_size_8x);
|
||||
} else {
|
||||
dk = torch::empty_like(k);
|
||||
}
|
||||
if (dv_.has_value()) {
|
||||
dv = dv_.value();
|
||||
TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
|
||||
CHECK_DEVICE(dv);
|
||||
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
|
||||
CHECK_SHAPE(dv, total_k, num_heads_k, head_size_8x);
|
||||
} else {
|
||||
dv = torch::empty_like(v);
|
||||
}
|
||||
|
||||
at::Tensor dout_padded;
|
||||
if (head_size_og % 8 != 0) {
|
||||
dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
|
||||
} else {
|
||||
dout_padded = dout;
|
||||
}
|
||||
|
||||
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
|
||||
auto opts = q.options();
|
||||
auto softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
|
||||
// TODO - CK does not support dq_accum
|
||||
|
||||
at::Tensor dk_expanded, dv_expanded;
|
||||
if (num_heads_k != num_heads) { // MQA / GQA
|
||||
dk_expanded = torch::empty({total_k, num_heads, head_size_8x}, opts);
|
||||
dv_expanded = torch::empty({total_k, num_heads, head_size_8x}, opts);
|
||||
} else {
|
||||
dk_expanded = dk;
|
||||
dv_expanded = dv;
|
||||
}
|
||||
|
||||
if(zero_tensors) {
|
||||
dq.zero_();
|
||||
dk_expanded.zero_();
|
||||
dv_expanded.zero_();
|
||||
softmax_d.zero_();
|
||||
}
|
||||
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
||||
gen_, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
|
||||
uint64_t drop_seed = 1, drop_offset = 0;
|
||||
int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
|
||||
|
||||
if (rng_state.has_value()) {
|
||||
uint64_t* d = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
|
||||
drop_seed = d[0];
|
||||
drop_offset = d[1];
|
||||
} else if(is_dropout) {
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
auto philox_args = gen->philox_cuda_state(counter_offset);
|
||||
std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
|
||||
}
|
||||
|
||||
if (max_seqlen_q > 0) {
|
||||
ck_tile::stream_config stream_config{stream};
|
||||
dq.zero_(); // CK uses atomic operations on dq
|
||||
|
||||
auto traits =
|
||||
get_ck_fmha_varlen_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value());
|
||||
|
||||
auto args =
|
||||
get_ck_fmha_varlen_bwd_args(
|
||||
mask,
|
||||
batch_size,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
num_heads,
|
||||
num_heads_k,
|
||||
head_size_8x,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
alibi_slopes_,
|
||||
out,
|
||||
softmax_lse,
|
||||
dout_padded,
|
||||
softmax_d,
|
||||
dq,
|
||||
dk_expanded,
|
||||
dv_expanded,
|
||||
softmax_scale,
|
||||
p_dropout,
|
||||
drop_seed,
|
||||
drop_offset);
|
||||
|
||||
fmha_bwd(traits, args, stream_config);
|
||||
} else {
|
||||
// If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
|
||||
dk_expanded.zero_();
|
||||
dv_expanded.zero_();
|
||||
softmax_d.zero_();
|
||||
}
|
||||
|
||||
// For MQA/GQA we need to sum dK and dV across the groups
|
||||
if (num_heads_k != num_heads) {
|
||||
at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {2});
|
||||
at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {2});
|
||||
}
|
||||
if (head_size_og % 8 != 0) {
|
||||
dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
|
||||
dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
|
||||
dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
|
||||
}
|
||||
|
||||
return { dq, dk, dv, softmax_d };
|
||||
}
|
||||
371  csrc/flash_attn_ck/mha_varlen_fwd.cpp (Normal file)
@@ -0,0 +1,371 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#include "flash_common.hpp"
|
||||
|
||||
#include "fmha_fwd.hpp"
|
||||
#include "mask.hpp"
|
||||
|
||||
fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask,
|
||||
std::string dtype,
|
||||
int head_size,
|
||||
bool has_dropout,
|
||||
bool has_lse,
|
||||
bool enable_alibi)
|
||||
{
|
||||
return fmha_fwd_traits{head_size,
|
||||
head_size,
|
||||
dtype,
|
||||
true, // is_group_mode
|
||||
true, // is_v_rowmajor
|
||||
mask.type,
|
||||
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
|
||||
has_lse,
|
||||
has_dropout,
|
||||
false}; // do_fp8_static_quant
|
||||
}
|
||||
|
||||
fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
|
||||
bool has_dropout_randval,
|
||||
const mask_info &mask,
|
||||
// sizes
|
||||
const int b,
|
||||
const int max_seqlen_q,
|
||||
const int h,
|
||||
const int h_k,
|
||||
const int d,
|
||||
// device pointers
|
||||
const at::Tensor q,
|
||||
const at::Tensor k,
|
||||
const at::Tensor v,
|
||||
const at::Tensor seqlens_q,
|
||||
const at::Tensor seqlens_k,
|
||||
c10::optional<at::Tensor> &alibi_slopes_,
|
||||
at::Tensor out,
|
||||
at::Tensor softmax_lse,
|
||||
at::Tensor dropout_randval,
|
||||
float softmax_scale,
|
||||
float p_dropout,
|
||||
uint64_t drop_seed,
|
||||
uint64_t drop_offset)
|
||||
{
|
||||
// q: (total_q, nheads, d)
|
||||
// k: (total_k, nheads_k, d)
|
||||
// v: (total_k, nheads_k, d)
|
||||
// o: (total_q, nheads, d)
|
||||
|
||||
// alibi_slopes:(batch, nheads) or (nhead)
|
||||
// lse: (batch, nheads, max_seqlen_q)
|
||||
// randval: (nheads, total_q, max_seqlen_k)
|
||||
|
||||
ck_tile::index_t total_q = q.size(0);
|
||||
ck_tile::index_t total_k = k.size(0);
|
||||
|
||||
ck_tile::index_t stride_q = q.stride(0);
|
||||
ck_tile::index_t stride_k = k.stride(0);
|
||||
ck_tile::index_t stride_v = v.stride(0);
|
||||
ck_tile::index_t stride_o = out.stride(0);
|
||||
ck_tile::index_t stride_randval = has_dropout_randval ? dropout_randval.stride(1) : 0;
|
||||
|
||||
ck_tile::index_t nhead_stride_q = q.stride(1);
|
||||
ck_tile::index_t nhead_stride_k = k.stride(1);
|
||||
ck_tile::index_t nhead_stride_v = v.stride(1);
|
||||
ck_tile::index_t nhead_stride_o = out.stride(1);
|
||||
ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(1) : 0;
|
||||
ck_tile::index_t nhead_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0;
|
||||
|
||||
ck_tile::index_t batch_stride_q = 0;
|
||||
ck_tile::index_t batch_stride_k = 0;
|
||||
ck_tile::index_t batch_stride_v = 0;
|
||||
ck_tile::index_t batch_stride_o = 0;
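// q/k/v/o batch strides are zero in group (varlen) mode: each sequence is located
// through the seqstart_q/seqstart_k offsets instead of a fixed per-batch stride.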
|
||||
|
||||
ck_tile::index_t batch_stride_lse = has_lse ? softmax_lse.stride(0) : 0;
|
||||
ck_tile::index_t batch_stride_randval = 0;
|
||||
|
||||
void *alibi_slopes_ptr = nullptr;
|
||||
ck_tile::index_t stride_alibi_slopes = 0;
|
||||
|
||||
if (alibi_slopes_.has_value()) {
|
||||
auto alibi_slopes = alibi_slopes_.value();
|
||||
CHECK_DEVICE(alibi_slopes);
|
||||
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
|
||||
alibi_slopes_ptr = alibi_slopes.data_ptr();
|
||||
stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
|
||||
}
|
||||
|
||||
return fmha_fwd_args{q.data_ptr(),
|
||||
k.data_ptr(),
|
||||
v.data_ptr(),
|
||||
alibi_slopes_ptr, // bias
|
||||
has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
|
||||
nullptr, // lse_acc
|
||||
nullptr, // o_acc
|
||||
has_lse ? softmax_lse.data_ptr() : nullptr,
|
||||
out.data_ptr(),
|
||||
seqlens_q.data_ptr(), // seqstart_q
|
||||
seqlens_k.data_ptr(), // seqstart_k
|
||||
nullptr, // seqlen_kpads
|
||||
total_q,
|
||||
total_k,
|
||||
b,
|
||||
max_seqlen_q,
|
||||
d, // hdim_q
|
||||
d, // hdim_v
|
||||
h, // nhead
|
||||
h_k, // nhead_k
|
||||
1, // num_splits
|
||||
softmax_scale, // scale_s
|
||||
1, // scale_p
|
||||
1, // scale_o
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_alibi_slopes,
|
||||
stride_randval,
|
||||
0, // stride_o_acc,
|
||||
stride_o,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
0, // nhead_stride_bias, FA without bias
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
0, // nhead_stride_lse_acc
|
||||
0, // nhead_stride_o_acc
|
||||
nhead_stride_o,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
0, // batch_stride_bias, FA without bias
|
||||
batch_stride_randval,
|
||||
batch_stride_lse,
|
||||
0, // batch_stride_lse_acc
|
||||
0, // batch_stride_o_acc
|
||||
batch_stride_o,
|
||||
0, // split_stride_lse_acc
|
||||
0, // split_stride_o_acc
|
||||
mask.left,
|
||||
mask.right,
|
||||
static_cast<ck_tile::index_t>(mask.type),
|
||||
p_dropout,
|
||||
has_dropout_randval,
|
||||
{drop_seed, drop_offset}};
|
||||
}
|
||||
|
||||
std::vector<at::Tensor>
|
||||
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
|
||||
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
|
||||
c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &cu_seqlens_q, // b+1
|
||||
const at::Tensor &cu_seqlens_k, // b+1
|
||||
c10::optional<at::Tensor> & /*seqused_k*/,
|
||||
c10::optional<const at::Tensor> &/*leftpad_k_*/, // batch_size
|
||||
c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
|
||||
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
|
||||
int max_seqlen_q,
|
||||
const int max_seqlen_k,
|
||||
const float p_dropout,
|
||||
const float softmax_scale,
|
||||
const bool zero_tensors,
|
||||
bool is_causal,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
const float /*softcap*/,
|
||||
const bool return_dropout_randval,
|
||||
c10::optional<at::Generator> gen_)
|
||||
{
|
||||
auto q_dtype = q.dtype();
|
||||
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
|
||||
"FlashAttention only support fp16 and bf16 data type");
|
||||
|
||||
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
|
||||
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
|
||||
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
|
||||
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
|
||||
|
||||
std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
|
||||
|
||||
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
|
||||
CHECK_DEVICE(cu_seqlens_q);
|
||||
CHECK_DEVICE(cu_seqlens_k);
|
||||
|
||||
// TODO - Support paged_KV
|
||||
const bool paged_KV = block_table_.has_value();
|
||||
TORCH_CHECK(!paged_KV, "CK does not support paged_KV yet");
|
||||
|
||||
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
CHECK_CONTIGUOUS(cu_seqlens_q);
|
||||
CHECK_CONTIGUOUS(cu_seqlens_k);
|
||||
|
||||
const auto sizes = q.sizes();
|
||||
|
||||
const int batch_size = cu_seqlens_q.numel() - 1;
|
||||
int num_heads = sizes[1];
|
||||
const int head_size_og = sizes[2];
|
||||
const int num_heads_k = k.size(1);
|
||||
|
||||
const int max_num_blocks_per_seq = 0;
|
||||
const int num_blocks = 0;
|
||||
|
||||
if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case
|
||||
|
||||
// TODO
|
||||
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
|
||||
// H/t Daniel Haziza
|
||||
|
||||
const int total_q = q.size(0);
|
||||
const int total_k = k.size(0);
|
||||
|
||||
TORCH_CHECK(batch_size > 0, "batch size must be positive");
|
||||
TORCH_CHECK(head_size_og <= 256, "CK only supports head dimension at most 256");
|
||||
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
|
||||
if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
|
||||
if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
|
||||
|
||||
mask_info mask;
|
||||
|
||||
if (is_causal) {
|
||||
// Causal is the special case where window_size_right == 0 and window_size_left < 0.
|
||||
window_size_right = 0;
|
||||
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
|
||||
mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // causal
|
||||
}
|
||||
else if (window_size_left == -1 && window_size_right == -1) {
|
||||
mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask
|
||||
}
|
||||
else {
|
||||
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
|
||||
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
|
||||
mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local
|
||||
}
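// mask_info::decode() takes "0" for no masking, or "b:<left>,<right>" for a windowed mask
// (the "b" prefix appears to select the bottom-right-aligned window convention FlashAttention uses).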
|
||||
|
||||
CHECK_SHAPE(q, total_q, num_heads, head_size_og);
|
||||
CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
|
||||
CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
|
||||
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
|
||||
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
|
||||
|
||||
at::Tensor q_padded, k_padded, v_padded;
|
||||
if (head_size_og % 8 != 0) {
|
||||
q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
|
||||
k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
|
||||
v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
|
||||
}
|
||||
else {
|
||||
q_padded = q;
|
||||
k_padded = k;
|
||||
v_padded = v;
|
||||
}
|
||||
|
||||
at::Tensor out;
|
||||
if (out_.has_value()) {
|
||||
out = out_.value();
|
||||
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
|
||||
CHECK_DEVICE(out);
|
||||
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
|
||||
CHECK_SHAPE(out, total_q, num_heads, head_size_og);
|
||||
|
||||
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
|
||||
}
|
||||
else {
|
||||
out = torch::empty_like(q_padded);
|
||||
}
|
||||
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
const int head_size_8x = round_multiple(head_size_og, 8);
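// e.g. head_size_og = 59 -> head_size_8x = 64; q/k/v were zero-padded up to this multiple of 8 above.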
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
|
||||
auto opts = q.options();
|
||||
bool has_lse = true;
|
||||
bool has_dropout = p_dropout > 0.0f;
|
||||
|
||||
at::Tensor softmax_lse;
|
||||
// TODO - check whether the gradient is needed; only training requires the LSE
|
||||
softmax_lse = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(torch::kFloat32));
|
||||
|
||||
at::Tensor p;
|
||||
if (return_dropout_randval) {
|
||||
TORCH_CHECK(has_dropout, "return_dropout_randval requires p_dropout > 0");
|
||||
p = torch::empty({num_heads, total_q, max_seqlen_k}, opts.dtype(torch::kUInt8));
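// CK returns the raw uint8 Philox draws here rather than a 0/1 keep mask;
// the tests convert them with ck_randval_to_dropout_mask().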
|
||||
}
|
||||
|
||||
if (zero_tensors)
|
||||
{
|
||||
out.zero_();
|
||||
softmax_lse.fill_(-std::numeric_limits<float>::infinity());
|
||||
if (return_dropout_randval) {p.zero_();}
|
||||
}
|
||||
|
||||
uint64_t drop_seed = 1, drop_offset = 0;
|
||||
int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
|
||||
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
|
||||
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
|
||||
|
||||
if (p_dropout > 0.0) {
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
||||
gen_, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
auto philox_args = gen->philox_cuda_state(counter_offset);
|
||||
std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
|
||||
}
|
||||
|
||||
rng_state[0] = *(reinterpret_cast<int64_t*>(&drop_seed));
|
||||
rng_state[1] = *(reinterpret_cast<int64_t*>(&drop_offset));
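// rng_state keeps the Philox seed/offset of this launch so the backward pass can
// regenerate exactly the same dropout mask.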
|
||||
|
||||
if (max_seqlen_k > 0) {
|
||||
auto stream = at::cuda::getCurrentHIPStream().stream();
|
||||
ck_tile::stream_config stream_config{stream};
|
||||
|
||||
auto traits =
|
||||
get_ck_fmha_varlen_fwd_traits(mask, q_dtype_str, head_size_8x, has_dropout, has_lse, alibi_slopes_.has_value());
|
||||
|
||||
auto args =
|
||||
get_ck_fmha_varlen_fwd_args(
|
||||
has_lse,
|
||||
return_dropout_randval,
|
||||
mask,
|
||||
batch_size,
|
||||
max_seqlen_q,
|
||||
num_heads,
|
||||
num_heads_k,
|
||||
head_size_8x,
|
||||
q_padded,
|
||||
k_padded,
|
||||
v_padded,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
alibi_slopes_,
|
||||
out,
|
||||
softmax_lse,
|
||||
p,
|
||||
softmax_scale,
|
||||
p_dropout,
|
||||
drop_seed,
|
||||
drop_offset);
|
||||
|
||||
fmha_fwd(traits, args, stream_config);
|
||||
}
|
||||
else {
|
||||
// If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
|
||||
out.zero_();
|
||||
softmax_lse.fill_(std::numeric_limits<float>::infinity());
|
||||
}
|
||||
|
||||
at::Tensor out_padded = out;
|
||||
if (head_size_og % 8 != 0) {
|
||||
out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
|
||||
if (out_.has_value()) { out_.value().copy_(out); }
|
||||
}
|
||||
|
||||
return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
|
||||
}
|
||||
168
setup.py
@ -5,6 +5,8 @@ import warnings
|
||||
import os
|
||||
import re
|
||||
import ast
|
||||
import glob
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from packaging.version import parse, Version
|
||||
import platform
|
||||
@ -22,6 +24,8 @@ from torch.utils.cpp_extension import (
|
||||
CppExtension,
|
||||
CUDAExtension,
|
||||
CUDA_HOME,
|
||||
ROCM_HOME,
|
||||
IS_HIP_EXTENSION,
|
||||
)
|
||||
|
||||
|
||||
@ -32,6 +36,19 @@ with open("README.md", "r", encoding="utf-8") as fh:
|
||||
# ninja build does not work unless include_dirs are abs path
|
||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
BUILD_TARGET = os.environ.get("BUILD_TARGET", "auto")
|
||||
|
||||
if BUILD_TARGET == "auto":
|
||||
if IS_HIP_EXTENSION:
|
||||
IS_ROCM = True
|
||||
else:
|
||||
IS_ROCM = False
|
||||
else:
|
||||
if BUILD_TARGET == "cuda":
|
||||
IS_ROCM = False
|
||||
elif BUILD_TARGET == "rocm":
|
||||
IS_ROCM = True
|
||||
|
||||
PACKAGE_NAME = "flash_attn"
|
||||
|
||||
BASE_WHEEL_URL = (
|
||||
@ -82,19 +99,47 @@ def check_if_cuda_home_none(global_option: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def check_if_rocm_home_none(global_option: str) -> None:
|
||||
if ROCM_HOME is not None:
|
||||
return
|
||||
# warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary
|
||||
# in that case.
|
||||
warnings.warn(
|
||||
f"{global_option} was requested, but hipcc was not found."
|
||||
)
|
||||
|
||||
|
||||
def append_nvcc_threads(nvcc_extra_args):
|
||||
nvcc_threads = os.getenv("NVCC_THREADS") or "4"
|
||||
return nvcc_extra_args + ["--threads", nvcc_threads]
|
||||
|
||||
|
||||
def rename_cpp_to_cu(cpp_files):
|
||||
for entry in cpp_files:
|
||||
shutil.copy(entry, os.path.splitext(entry)[0] + ".cu")
|
||||
|
||||
|
||||
def validate_and_update_archs(archs):
|
||||
# List of allowed architectures
|
||||
allowed_archs = ["native", "gfx90a", "gfx940", "gfx941", "gfx942"]
|
||||
|
||||
# Validate if each element in archs is in allowed_archs
|
||||
assert all(
|
||||
arch in allowed_archs for arch in archs
|
||||
), f"One of GPU archs of {archs} is invalid or not supported by Flash-Attention"
|
||||
|
||||
|
||||
cmdclass = {}
|
||||
ext_modules = []
|
||||
|
||||
# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp
|
||||
# files included in the source distribution, in case the user compiles from source.
|
||||
subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
|
||||
if IS_ROCM:
|
||||
subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"])
|
||||
else:
|
||||
subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
|
||||
|
||||
if not SKIP_CUDA_BUILD:
|
||||
if not SKIP_CUDA_BUILD and not IS_ROCM:
|
||||
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
|
||||
TORCH_MAJOR = int(torch.__version__.split(".")[0])
|
||||
TORCH_MINOR = int(torch.__version__.split(".")[1])
|
||||
@ -250,6 +295,95 @@ if not SKIP_CUDA_BUILD:
|
||||
],
|
||||
)
|
||||
)
|
||||
elif not SKIP_CUDA_BUILD and IS_ROCM:
|
||||
ck_dir = "csrc/composable_kernel"
|
||||
|
||||
# use the CK codegen script to generate the fmha kernel dispatch sources
|
||||
if not os.path.exists("./build"):
|
||||
os.makedirs("build")
|
||||
|
||||
os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd --output_dir build --receipt 2")
|
||||
os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d bwd --output_dir build --receipt 2")
|
||||
|
||||
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
|
||||
TORCH_MAJOR = int(torch.__version__.split(".")[0])
|
||||
TORCH_MINOR = int(torch.__version__.split(".")[1])
|
||||
|
||||
# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
|
||||
# See https://github.com/pytorch/pytorch/pull/70650
|
||||
generator_flag = []
|
||||
torch_dir = torch.__path__[0]
|
||||
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
|
||||
generator_flag = ["-DOLD_GENERATOR_PATH"]
|
||||
|
||||
check_if_rocm_home_none("flash_attn")
|
||||
cc_flag = []
|
||||
|
||||
archs = os.getenv("GPU_ARCHS", "native").split(";")
|
||||
validate_and_update_archs(archs)
|
||||
|
||||
cc_flag = [f"--offload-arch={arch}" for arch in archs]
|
||||
|
||||
# HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
|
||||
# torch._C._GLIBCXX_USE_CXX11_ABI
|
||||
# https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
|
||||
if FORCE_CXX11_ABI:
|
||||
torch._C._GLIBCXX_USE_CXX11_ABI = True
|
||||
|
||||
sources = ["csrc/flash_attn_ck/flash_api.cpp",
|
||||
"csrc/flash_attn_ck/mha_bwd.cpp",
|
||||
"csrc/flash_attn_ck/mha_fwd.cpp",
|
||||
"csrc/flash_attn_ck/mha_varlen_bwd.cpp",
|
||||
"csrc/flash_attn_ck/mha_varlen_fwd.cpp"] + glob.glob(
|
||||
f"build/fmha_*wd*.cpp"
|
||||
)
|
||||
|
||||
rename_cpp_to_cu(sources)
|
||||
|
||||
renamed_sources = ["csrc/flash_attn_ck/flash_api.cu",
|
||||
"csrc/flash_attn_ck/mha_bwd.cu",
|
||||
"csrc/flash_attn_ck/mha_fwd.cu",
|
||||
"csrc/flash_attn_ck/mha_varlen_bwd.cu",
|
||||
"csrc/flash_attn_ck/mha_varlen_fwd.cu"] + glob.glob(f"build/fmha_*wd*.cu")
|
||||
extra_compile_args = {
|
||||
"cxx": ["-O3", "-std=c++17"] + generator_flag,
|
||||
"nvcc":
|
||||
[
|
||||
"-O3","-std=c++17",
|
||||
"-mllvm", "-enable-post-misched=0",
|
||||
"-DCK_TILE_FMHA_FWD_FAST_EXP2=1",
|
||||
"-fgpu-flush-denormals-to-zero",
|
||||
"-DCK_ENABLE_BF16",
|
||||
"-DCK_ENABLE_BF8",
|
||||
"-DCK_ENABLE_FP16",
|
||||
"-DCK_ENABLE_FP32",
|
||||
"-DCK_ENABLE_FP64",
|
||||
"-DCK_ENABLE_FP8",
|
||||
"-DCK_ENABLE_INT8",
|
||||
"-DCK_USE_XDL",
|
||||
"-DUSE_PROF_API=1",
|
||||
"-D__HIP_PLATFORM_HCC__=1",
|
||||
# "-DFLASHATTENTION_DISABLE_BACKWARD",
|
||||
]
|
||||
+ generator_flag
|
||||
+ cc_flag
|
||||
,
|
||||
}
|
||||
|
||||
include_dirs = [
|
||||
Path(this_dir) / "csrc" / "composable_kernel" / "include",
|
||||
Path(this_dir) / "csrc" / "composable_kernel" / "library" / "include",
|
||||
Path(this_dir) / "csrc" / "composable_kernel" / "example" / "ck_tile" / "01_fmha",
|
||||
]
|
||||
|
||||
ext_modules.append(
|
||||
CUDAExtension(
|
||||
name="flash_attn_2_cuda",
|
||||
sources=renamed_sources,
|
||||
extra_compile_args=extra_compile_args,
|
||||
include_dirs=include_dirs,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def get_package_version():
|
||||
@ -264,25 +398,33 @@ def get_package_version():
|
||||
|
||||
|
||||
def get_wheel_url():
|
||||
# Determine the version numbers that will be used to determine the correct wheel
|
||||
# We're using the CUDA version used to build torch, not the one currently installed
|
||||
# _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
|
||||
torch_cuda_version = parse(torch.version.cuda)
|
||||
torch_version_raw = parse(torch.__version__)
|
||||
# For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
|
||||
# to save CI time. Minor versions should be compatible.
|
||||
torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3")
|
||||
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
|
||||
platform_name = get_platform()
|
||||
flash_version = get_package_version()
|
||||
# cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
|
||||
cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
|
||||
torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
|
||||
cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
|
||||
|
||||
# Determine wheel URL based on CUDA version, torch version, python version and OS
|
||||
wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
|
||||
if IS_ROCM:
|
||||
torch_hip_version = parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+'))
|
||||
hip_version = f"{torch_hip_version.major}{torch_hip_version.minor}"
|
||||
wheel_filename = f"{PACKAGE_NAME}-{flash_version}+rocm{hip_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
|
||||
else:
|
||||
# Determine the version numbers that will be used to determine the correct wheel
|
||||
# We're using the CUDA version used to build torch, not the one currently installed
|
||||
# _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
|
||||
torch_cuda_version = parse(torch.version.cuda)
|
||||
# For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
|
||||
# to save CI time. Minor versions should be compatible.
|
||||
torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3")
|
||||
# cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
|
||||
cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
|
||||
|
||||
# Determine wheel URL based on CUDA version, torch version, python version and OS
|
||||
wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
|
||||
|
||||
wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
|
||||
|
||||
return wheel_url, wheel_filename
|
||||
|
||||
|
||||
|
||||
754
tests/test_flash_attn_ck.py
Normal file
@ -0,0 +1,754 @@
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from flash_attn import (
|
||||
flash_attn_func,
|
||||
flash_attn_kvpacked_func,
|
||||
flash_attn_qkvpacked_func,
|
||||
flash_attn_varlen_func,
|
||||
flash_attn_varlen_kvpacked_func,
|
||||
flash_attn_varlen_qkvpacked_func,
|
||||
)
|
||||
|
||||
from test_flash_attn import (
|
||||
attn_bias_from_alibi_slopes,
|
||||
convert_flash_attn_S_to_softmax,
|
||||
generate_qkv,
|
||||
generate_random_padding_mask,
|
||||
attention_ref,
|
||||
attention_kvpacked_ref,
|
||||
attention_qkvpacked_ref,
|
||||
)
|
||||
|
||||
def is_bwd_hdim_supported(d):
|
||||
return d <= 128 and d % 2 == 0
|
||||
|
||||
|
||||
def ck_randval_to_dropout_mask(randval, p):
    # If p = 0.3, randval in 255 * (0.7, 1.0] will be dropped,
    # while randval in 255 * [0, 0.7] will be kept.
    # Wherever the returned dropout_mask is >= 0, the value is kept.
    return torch.floor(255.0 * (1 - p) - randval)
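# Worked example with p = 0.17: 255 * (1 - 0.17) = 211.65, so randval = 200 gives
# floor(11.65) = 11 >= 0 (kept) while randval = 230 gives floor(-18.35) = -19 (dropped).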
|
||||
|
||||
|
||||
def pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens_q, seqlen_q_rounded, seqlen_k_rounded):
|
||||
""" pad + rearrange [nheads, total_q, max_seqlen_k] into [b, nheads, seqlen_q_rounded, seqlen_k_rounded]
|
||||
Arguments:
|
||||
S_dmask: (nheads, total_q, max_seqlen_k)
|
||||
cu_seqlens_q: (b + 1)
|
||||
Output:
|
||||
S_dmask: (b, nheads, seqlen_q_rounded, seqlen_k_rounded)
|
||||
"""
|
||||
batch_size = cu_seqlens_q.numel() - 1
|
||||
seqlens_q = torch.roll(cu_seqlens_q, shifts = -1) - cu_seqlens_q
|
||||
seqlens_q = seqlens_q[0:batch_size].tolist()
|
||||
S_dmask = torch.split(S_dmask, seqlens_q, dim=1)
|
||||
# [(nheads, seqlen_q0, max_seqlen_k), (nheads, seqlen_q1, max_seqlen_k), ..., (nheads, seqlen_qb, max_seqlen_k)]
|
||||
masks = ()
|
||||
for mask in S_dmask:
|
||||
# (nheads, seqlen_qi, max_seqlen_k) -> (nheads, seqlen_q_rounded, seqlen_k_rounded)
|
||||
mask = F.pad(mask, (0, seqlen_k_rounded - mask.shape[2], 0, seqlen_q_rounded - mask.shape[1], 0, 0)).unsqueeze(1)
|
||||
masks = masks + (mask, )
|
||||
S_dmask = torch.cat(masks, dim=1)
|
||||
|
||||
S_dmask = S_dmask.transpose(0, 1)
|
||||
return S_dmask
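# A minimal shape check of the helper above (hypothetical sizes), handy when debugging
# the dropout-mask plumbing:
#   S = torch.zeros(3, 10, 8)                         # (nheads, total_q, max_seqlen_k)
#   cu = torch.tensor([0, 4, 10], dtype=torch.int32)  # batch of 2 with seqlens 4 and 6
#   out = pad_rearrange_dropout_mask_hts_to_bhss(S, cu, 16, 16)
#   assert out.shape == (2, 3, 16, 16)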
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("deterministic", [False])
|
||||
@pytest.mark.parametrize("alibi", [False, True])
|
||||
@pytest.mark.parametrize("local", [False, True])
|
||||
@pytest.mark.parametrize("causal", [False, True])
|
||||
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
|
||||
@pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048])
|
||||
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
|
||||
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
|
||||
if d > 256:
|
||||
pytest.skip()
|
||||
|
||||
device = "cuda"
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch_size = 4
|
||||
nheads = 9
|
||||
window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
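# window_size = (left, right); (-1, -1) disables local attention, any other pair
# restricts attention to that sliding window around the diagonal.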
|
||||
|
||||
qkv = torch.randn(
|
||||
batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
|
||||
)
|
||||
|
||||
if alibi:
|
||||
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
|
||||
attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal)
|
||||
else:
|
||||
alibi_slopes, attn_bias = None, None
|
||||
out, lse, S_dmask = flash_attn_qkvpacked_func(
|
||||
qkv,
|
||||
dropout_p,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
deterministic=deterministic,
|
||||
return_attn_probs=True,
|
||||
)
|
||||
if dropout_p > 0.0:
|
||||
# TODO - move to c++ mha_varlen_fwd()
|
||||
S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)
|
||||
S_dmask_converted = convert_flash_attn_S_to_softmax(
|
||||
S_dmask,
|
||||
seqlen,
|
||||
seqlen,
|
||||
None,
|
||||
None,
|
||||
d,
|
||||
dropout_p > 0.0,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
dropout_mask = S_dmask_converted >= 0
|
||||
# CK does not return P. Hence, we don't test the attn here.
|
||||
else:
|
||||
dropout_mask = None
|
||||
|
||||
out_ref, attn_ref = attention_qkvpacked_ref(
|
||||
qkv, None, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size
|
||||
)
|
||||
out_pt, attn_pt = attention_qkvpacked_ref(
|
||||
qkv,
|
||||
None,
|
||||
attn_bias,
|
||||
dropout_p,
|
||||
dropout_mask,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
upcast=False,
|
||||
reorder_ops=True,
|
||||
)
|
||||
|
||||
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
||||
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
||||
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
|
||||
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
|
||||
|
||||
# Check that FlashAttention's numerical error is at most twice the numerical error
|
||||
# of a Pytorch implementation.
|
||||
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
|
||||
|
||||
g = torch.randn_like(out)
|
||||
if is_bwd_hdim_supported(d):
|
||||
(dqkv,) = torch.autograd.grad(out, qkv, g)
|
||||
(dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
|
||||
(dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
|
||||
print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
|
||||
print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
|
||||
print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
|
||||
print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
|
||||
print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
|
||||
print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
|
||||
print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
|
||||
print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")
|
||||
|
||||
# TODO - using a 10x tolerance for now; tighten once CK changes the dq type to fp32
|
||||
assert (dqkv - dqkv_ref).abs().max().item() <= 10 * (dqkv_pt - dqkv_ref).abs().max().item()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("deterministic", [False])
|
||||
@pytest.mark.parametrize("alibi", [False, True])
|
||||
@pytest.mark.parametrize("local", [False, True])
|
||||
@pytest.mark.parametrize("causal", [False, True])
|
||||
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
|
||||
@pytest.mark.parametrize("seqlen", [97, 128, 200, 257, 384, 512, 768, 1025, 2048])
|
||||
@pytest.mark.parametrize("dropout_p", [0, 0.17])
|
||||
def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
|
||||
if d > 256:
|
||||
pytest.skip()
|
||||
|
||||
device = "cuda"
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch_size = 5
|
||||
nheads = 6
|
||||
window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
|
||||
qkv = torch.randn(
|
||||
batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
|
||||
)
|
||||
|
||||
key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random")
|
||||
# key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
|
||||
if alibi:
|
||||
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
|
||||
attn_bias = attn_bias_from_alibi_slopes(
|
||||
alibi_slopes, seqlen, seqlen, key_padding_mask, key_padding_mask, causal=causal
|
||||
)
|
||||
else:
|
||||
alibi_slopes, attn_bias = None, None
|
||||
|
||||
qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
|
||||
*qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True
|
||||
)
|
||||
|
||||
out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func(
|
||||
qkv_unpad,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
dropout_p,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
deterministic=deterministic,
|
||||
return_attn_probs=True,
|
||||
)
|
||||
out = output_pad_fn(out_unpad)
|
||||
if dropout_p > 0.0:
|
||||
# TODO - move to c++ mha_varlen_fwd()
|
||||
S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)
|
||||
S_dmask = pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens, seqlen, seqlen)
|
||||
|
||||
S_dmask_converted = convert_flash_attn_S_to_softmax(
|
||||
S_dmask,
|
||||
seqlen,
|
||||
seqlen,
|
||||
key_padding_mask,
|
||||
key_padding_mask,
|
||||
d,
|
||||
dropout_p > 0.0,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
|
||||
dropout_mask = S_dmask_converted >= 0
|
||||
# CK does not return P. Hence, we don't test the attn here.
|
||||
else:
|
||||
dropout_mask = None
|
||||
|
||||
out_ref, attn_ref = attention_qkvpacked_ref(
|
||||
qkv,
|
||||
key_padding_mask,
|
||||
attn_bias,
|
||||
dropout_p,
|
||||
dropout_mask,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
out_pt, attn_pt = attention_qkvpacked_ref(
|
||||
qkv,
|
||||
key_padding_mask,
|
||||
attn_bias,
|
||||
dropout_p,
|
||||
dropout_mask,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
upcast=False,
|
||||
reorder_ops=True,
|
||||
)
|
||||
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
||||
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
||||
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
|
||||
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
|
||||
|
||||
# Check that FlashAttention's numerical error is at most twice the numerical error
|
||||
# of a Pytorch implementation.
|
||||
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
|
||||
|
||||
g = torch.randn_like(out)
|
||||
if is_bwd_hdim_supported(d):
|
||||
(dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)
|
||||
dqkv = dqkv_pad_fn(dqkv_unpad)
|
||||
(dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
|
||||
(dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
|
||||
print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
|
||||
print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
|
||||
print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
|
||||
print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
|
||||
print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
|
||||
print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
|
||||
print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
|
||||
print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")
|
||||
|
||||
# TODO - using a 10x tolerance for now; tighten once CK changes the dq type to fp32
|
||||
assert (dqkv - dqkv_ref).abs().max().item() <= 10 * (dqkv_pt - dqkv_ref).abs().max().item()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kvpacked", [True, False])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
|
||||
@pytest.mark.parametrize("deterministic", [False])
|
||||
@pytest.mark.parametrize("alibi", [False, True])
|
||||
@pytest.mark.parametrize("local", [False, True])
|
||||
@pytest.mark.parametrize("causal", [False, True])
|
||||
@pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])
|
||||
@pytest.mark.parametrize(
|
||||
"seqlen_q,seqlen_k",
|
||||
[
|
||||
(113, 203),
|
||||
(128, 217),
|
||||
(113, 211),
|
||||
(108, 256),
|
||||
(256, 512),
|
||||
(512, 256),
|
||||
(1024, 1024),
|
||||
(1023, 1024),
|
||||
(1024, 1023),
|
||||
(2048, 2048),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
|
||||
def test_flash_attn_output(
|
||||
seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
|
||||
):
|
||||
device = "cuda"
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch_size = 4
|
||||
nheads = 9
|
||||
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
|
||||
assert nheads % nheads_k == 0
|
||||
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
|
||||
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
|
||||
if kvpacked:
|
||||
kv = torch.randn(
|
||||
batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
|
||||
)
|
||||
else:
|
||||
k = torch.randn(
|
||||
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
|
||||
)
|
||||
v = torch.randn(
|
||||
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
|
||||
)
|
||||
if alibi:
|
||||
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
|
||||
attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
|
||||
else:
|
||||
alibi_slopes, attn_bias = None, None
|
||||
|
||||
if kvpacked:
|
||||
out, lse, S_dmask = flash_attn_kvpacked_func(
|
||||
q,
|
||||
kv,
|
||||
dropout_p,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
deterministic=deterministic,
|
||||
return_attn_probs=True,
|
||||
)
|
||||
else:
|
||||
out, lse, S_dmask = flash_attn_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
dropout_p,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
deterministic=deterministic,
|
||||
return_attn_probs=True,
|
||||
)
|
||||
if dropout_p > 0.0:
|
||||
# TODO - move to c++ mha_varlen_fwd()
|
||||
S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)
|
||||
S_dmask_converted = convert_flash_attn_S_to_softmax(
|
||||
S_dmask,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
None,
|
||||
None,
|
||||
d,
|
||||
dropout_p > 0.0,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
dropout_mask = S_dmask_converted >= 0
|
||||
if kvpacked:
|
||||
kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
|
||||
k_rep, v_rep = kv_rep.unbind(dim=2)
|
||||
else:
|
||||
k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
|
||||
v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
|
||||
# CK does not return P. Hence, we don't test the attn here.
|
||||
else:
|
||||
dropout_mask = None
|
||||
|
||||
if kvpacked:
|
||||
out_ref, attn_ref = attention_kvpacked_ref(
|
||||
q,
|
||||
kv,
|
||||
None,
|
||||
None,
|
||||
attn_bias,
|
||||
dropout_p,
|
||||
dropout_mask,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
out_pt, attn_pt = attention_kvpacked_ref(
|
||||
q,
|
||||
kv,
|
||||
None,
|
||||
None,
|
||||
attn_bias,
|
||||
dropout_p,
|
||||
dropout_mask,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
upcast=False,
|
||||
reorder_ops=True,
|
||||
)
|
||||
else:
|
||||
out_ref, attn_ref = attention_ref(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
None,
|
||||
None,
|
||||
attn_bias,
|
||||
dropout_p,
|
||||
dropout_mask,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
out_pt, attn_pt = attention_ref(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
None,
|
||||
None,
|
||||
attn_bias,
|
||||
dropout_p,
|
||||
dropout_mask,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
upcast=False,
|
||||
reorder_ops=True,
|
||||
)
|
||||
|
||||
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
||||
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
||||
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
|
||||
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
|
||||
|
||||
# Check that FlashAttention's numerical error is at most twice the numerical error
|
||||
# of a Pytorch implementation.
|
||||
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
|
||||
|
||||
g = torch.randn_like(out)
|
||||
if is_bwd_hdim_supported(d):
|
||||
if kvpacked:
|
||||
(
|
||||
dq,
|
||||
dkv,
|
||||
) = torch.autograd.grad(out, (q, kv), g)
|
||||
dk, dv = dkv.unbind(2)
|
||||
(
|
||||
dq_ref,
|
||||
dkv_ref,
|
||||
) = torch.autograd.grad(out_ref, (q, kv), g)
|
||||
dk_ref, dv_ref = dkv_ref.unbind(2)
|
||||
(
|
||||
dq_pt,
|
||||
dkv_pt,
|
||||
) = torch.autograd.grad(out_pt, (q, kv), g)
|
||||
dk_pt, dv_pt = dkv_pt.unbind(2)
|
||||
else:
|
||||
(
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
) = torch.autograd.grad(out, (q, k, v), g)
|
||||
(
|
||||
dq_ref,
|
||||
dk_ref,
|
||||
dv_ref,
|
||||
) = torch.autograd.grad(out_ref, (q, k, v), g)
|
||||
(
|
||||
dq_pt,
|
||||
dk_pt,
|
||||
dv_pt,
|
||||
) = torch.autograd.grad(out_pt, (q, k, v), g)
|
||||
print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
|
||||
print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
|
||||
print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
|
||||
print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
|
||||
print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
|
||||
print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
|
||||
print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
|
||||
print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
|
||||
print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
|
||||
print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
|
||||
print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
|
||||
print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
|
||||
|
||||
# TODO - using a 10x tolerance for now; tighten once CK changes the dq type to fp32
|
||||
assert (dq - dq_ref).abs().max().item() <= 10 * (dq_pt - dq_ref).abs().max().item()
|
||||
assert (dk - dk_ref).abs().max().item() <= 10 * (dk_pt - dk_ref).abs().max().item()
|
||||
assert (dv - dv_ref).abs().max().item() <= 10 * (dv_pt - dv_ref).abs().max().item()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kvpacked", [True, False])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
|
||||
@pytest.mark.parametrize("deterministic", [False, True])
|
||||
@pytest.mark.parametrize("alibi", [False, True])
|
||||
@pytest.mark.parametrize("local", [False, True])
|
||||
@pytest.mark.parametrize("causal", [False, True])
|
||||
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
|
||||
@pytest.mark.parametrize(
|
||||
"seqlen_q,seqlen_k",
|
||||
[
|
||||
(1, 147),
|
||||
(113, 203),
|
||||
(128, 217),
|
||||
(113, 211),
|
||||
(108, 256),
|
||||
(256, 512),
|
||||
(512, 256),
|
||||
(1024, 1024),
|
||||
(1023, 1024),
|
||||
(1024, 1023),
|
||||
(2048, 2048),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
|
||||
def test_flash_attn_varlen_output(
|
||||
seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
|
||||
):
|
||||
device = "cuda"
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch_size = 4
|
||||
nheads = 9
|
||||
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
|
||||
assert nheads % nheads_k == 0
|
||||
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
|
||||
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
|
||||
if kvpacked:
|
||||
kv = torch.randn(
|
||||
batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
|
||||
)
|
||||
else:
|
||||
k = torch.randn(
|
||||
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
|
||||
)
|
||||
v = torch.randn(
|
||||
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
|
||||
)
|
||||
|
||||
query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
|
||||
key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
|
||||
# key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
|
||||
if alibi:
|
||||
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
|
||||
attn_bias = attn_bias_from_alibi_slopes(
|
||||
alibi_slopes, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, causal=causal
|
||||
)
|
||||
else:
|
||||
alibi_slopes, attn_bias = None, None
|
||||
|
||||
if kvpacked:
|
||||
(
|
||||
q_unpad,
|
||||
kv_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
q,
|
||||
kv,
|
||||
output_pad_fn,
|
||||
dq_pad_fn,
|
||||
dkv_pad_fn,
|
||||
) = generate_qkv(q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True)
|
||||
out_unpad, sm_lse, S_dmask = flash_attn_varlen_kvpacked_func(
|
||||
q_unpad,
|
||||
kv_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
dropout_p,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
deterministic=deterministic,
|
||||
return_attn_probs=True,
|
||||
)
|
||||
else:
|
||||
(
|
||||
q_unpad,
|
||||
k_unpad,
|
||||
v_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
output_pad_fn,
|
||||
dq_pad_fn,
|
||||
dk_pad_fn,
|
||||
) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
|
||||
out_unpad, sm_lse, S_dmask = flash_attn_varlen_func(
|
||||
q_unpad,
|
||||
k_unpad,
|
||||
v_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
dropout_p,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
deterministic=deterministic,
|
||||
return_attn_probs=True,
|
||||
)
|
||||
out = output_pad_fn(out_unpad)
|
||||
if dropout_p > 0.0:
|
||||
# TODO - move to c++ mha_varlen_fwd()
|
||||
S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)
|
||||
S_dmask = pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens_q, seqlen_q, seqlen_k)
|
||||
S_dmask_converted = convert_flash_attn_S_to_softmax(
|
||||
S_dmask,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
query_padding_mask,
|
||||
key_padding_mask,
|
||||
d,
|
||||
dropout_p > 0.0,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
dropout_mask = S_dmask_converted >= 0
|
||||
if kvpacked:
|
||||
kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
|
||||
k_rep, v_rep = kv_rep.unbind(dim=2)
|
||||
else:
|
||||
k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
|
||||
v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
|
||||
# CK does not return P. Hence, we don't test the attn here.
|
||||
else:
|
||||
dropout_mask = None
|
||||
|
||||
if kvpacked:
|
||||
out_ref, attn_ref = attention_kvpacked_ref(
|
||||
q,
|
||||
kv,
|
||||
query_padding_mask,
|
||||
key_padding_mask,
|
||||
attn_bias,
|
||||
dropout_p,
|
||||
dropout_mask,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
out_pt, attn_pt = attention_kvpacked_ref(
|
||||
q,
|
||||
kv,
|
||||
query_padding_mask,
|
||||
key_padding_mask,
|
||||
attn_bias,
|
||||
dropout_p,
|
||||
dropout_mask,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
upcast=False,
|
||||
reorder_ops=True,
|
||||
)
|
||||
else:
|
||||
out_ref, attn_ref = attention_ref(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
query_padding_mask,
|
||||
key_padding_mask,
|
||||
attn_bias,
|
||||
dropout_p,
|
||||
dropout_mask,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
out_pt, attn_pt = attention_ref(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
query_padding_mask,
|
||||
key_padding_mask,
|
||||
attn_bias,
|
||||
dropout_p,
|
||||
dropout_mask,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
upcast=False,
|
||||
reorder_ops=True,
|
||||
)
|
||||
|
||||
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
||||
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
||||
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
|
||||
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
|
||||
|
||||
# Check that FlashAttention's numerical error is at most 4 times the numerical error
|
||||
# of a Pytorch implementation.
|
||||
assert (out - out_ref).abs().max().item() <= 4 * (out_pt - out_ref).abs().max().item()
|
||||
|
||||
g = torch.randn_like(out)
|
||||
if is_bwd_hdim_supported(d):
|
||||
if kvpacked:
|
||||
(
|
||||
dq_unpad,
|
||||
dkv_unpad,
|
||||
) = torch.autograd.grad(out, (q_unpad, kv_unpad), g)
|
||||
dk, dv = dkv_pad_fn(dkv_unpad).unbind(2)
|
||||
(
|
||||
dq_ref,
|
||||
dkv_ref,
|
||||
) = torch.autograd.grad(out_ref, (q, kv), g)
|
||||
dk_ref, dv_ref = dkv_ref.unbind(2)
|
||||
(
|
||||
dq_pt,
|
||||
dkv_pt,
|
||||
) = torch.autograd.grad(out_pt, (q, kv), g)
|
||||
dk_pt, dv_pt = dkv_pt.unbind(2)
|
||||
else:
|
||||
(
|
||||
dq_unpad,
|
||||
dk_unpad,
|
||||
dv_unpad,
|
||||
) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
|
||||
dk = dk_pad_fn(dk_unpad)
|
||||
dv = dk_pad_fn(dv_unpad)
|
||||
(
|
||||
dq_ref,
|
||||
dk_ref,
|
||||
dv_ref,
|
||||
) = torch.autograd.grad(out_ref, (q, k, v), g)
|
||||
(
|
||||
dq_pt,
|
||||
dk_pt,
|
||||
dv_pt,
|
||||
) = torch.autograd.grad(out_pt, (q, k, v), g)
|
||||
dq = dq_pad_fn(dq_unpad)
|
||||
print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
|
||||
print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
|
||||
print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
|
||||
print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
|
||||
print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
|
||||
print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
|
||||
print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
|
||||
print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
|
||||
print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
|
||||
print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
|
||||
print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
|
||||
print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
|
||||
|
||||
# TODO - using a 10x tolerance for now; tighten once CK changes the dq type to fp32
|
||||
assert (dq - dq_ref).abs().max().item() <= 10 * (dq_pt - dq_ref).abs().max().item()
|
||||
assert (dk - dk_ref).abs().max().item() <= 10 * (dk_pt - dk_ref).abs().max().item()
|
||||
assert (dv - dv_ref).abs().max().item() <= 10 * (dv_pt - dv_ref).abs().max().item()
|
||||