Support AMD ROCm on FlashAttention 2 (#1010)

* Support ck in fmha

* Add ck submodule

* Do not return lse if return_softmax == false

* Use receipt to speed up ck compile time

* Integrate new version of ck_tile

* Support dropout for mha_fwd()

* Add dropout to mha_varlen_fwd()

* Update ck to develop

* Extract padding function for dropout randval

* Extract randval transformation function

* Sync the code structure and coding style with FA

* Remove this line; the C++ API will handle it.
Sync with test_flash_attn.py

* Fix compile error

* Add mha_bwd

* Generate dropout seed and offset from user generator

* Update CK

* Add mha_varlen_bwd

* Use the same Python interpreter that builds flash-attn to generate the CK kernels

* Fix bug in group-mode fwd when returning softmax LSE

* Increase the test tolerance

* Add test_flash_attn_output() and test_flash_attn_varlen_output()

* Always fill softmax_lse

* Remove duplicate benchmark script, since mha_bwd is already implemented

* Refine get value from tuple

* Use default parameter for stream_config

* Unblock all platforms

* Add comment

* Refine the test code

* Refine naming

* Add unpack to namespace

* Do not hardcode the warp size 64

* Add more targets

* Add README

* Optimize mha_fwd if seqlen_q == 1

* Support get_wheel_url for rocm

* Detect the ROCm environment via PyTorch's IS_HIP_EXTENSION

* Update to the latest CK

* Add necessary compile flag

* Sync the api with upstream FA

---------

Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: Yichen Yan <wenji.yyc@alibaba-inc.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: Yichen Yan <oraluben@outlook.com>
Authored by rocking on 2024-07-23 12:34:37 +08:00; committed by GitHub
parent dfe1a59e4b
commit d8f104e97a
11 changed files with 2581 additions and 13 deletions

.gitmodules

@@ -1,3 +1,6 @@
[submodule "csrc/cutlass"]
path = csrc/cutlass
url = https://github.com/NVIDIA/cutlass.git
[submodule "csrc/composable_kernel"]
path = csrc/composable_kernel
url = https://github.com/ROCm/composable_kernel.git

@@ -434,6 +434,33 @@ This new release of FlashAttention-2 has been tested on several GPT-style
models, mostly on A100 GPUs.
If you encounter bugs, please open a GitHub Issue!
## AMD GPU/ROCm Support
The ROCm version uses [composable_kernel](https://github.com/ROCm/composable_kernel) as its backend, which provides the implementation of FlashAttention-2.
## Installation and features
Requirements:
- ROCm 6.0+
- PyTorch 1.12.1+
We recommend the
[PyTorch](https://hub.docker.com/r/rocm/pytorch)
container from ROCm, which has all the required tools to install FlashAttention.
To compile from source:
```sh
python setup.py install
```
FlashAttention-2 on ROCm currently supports:
1. MI200 and MI300 GPUs.
2. Datatypes fp16 and bf16.
3. Head dimensions up to 256 in the forward pass and up to 128 in the backward pass.
## Tests
To run the tests:
```sh
pytest tests/test_flash_attn_ck.py
```
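Once installed, the ROCm build exposes the same Python API as the CUDA build. Below is a minimal forward-pass sketch, assuming the standard `flash_attn_func` wrapper from the `flash_attn` package; shapes and dtypes follow the constraints listed above.
```python
# Minimal forward-pass sketch; assumes the standard flash_attn Python API.
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 16, 128
q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
v = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")

# Causal attention, dropout disabled. Output has the same shape as q.
# Under ROCm, PyTorch still uses the "cuda" device string.
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
```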
## Citation
If you use this codebase, or otherwise found our work valuable, please cite:

@@ -0,0 +1 @@
Subproject commit 8182976c37433808b5e3a27a6536d1b74b0c23a1

@@ -0,0 +1,99 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#include "flash_common.hpp"
std::vector<at::Tensor>
mha_fwd(at::Tensor &q,
const at::Tensor &k,
const at::Tensor &v,
c10::optional<at::Tensor> &out_,
c10::optional<at::Tensor> &alibi_slopes_,
const float p_dropout,
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool return_softmax,
c10::optional<at::Generator> gen_);
std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
c10::optional<const at::Tensor> &leftpad_k_, // batch_size
c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
const float softmax_scale,
const bool zero_tensors,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool return_softmax,
c10::optional<at::Generator> gen_);
std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &softmax_lse, // b x h x seqlen_q
c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout, // probability to drop
const float softmax_scale,
const bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool deterministic,
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state);
std::vector<at::Tensor>
mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads x head_size
const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &out, // total_q x num_heads x head_size
const at::Tensor &softmax_lse, // b x h x s softmax logsumexp
c10::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
c10::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
const int max_seqlen_q,
const int max_seqlen_k, // max sequence length to choose the kernel
const float p_dropout, // probability to drop
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool deterministic,
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.doc() = "FlashAttention";
m.def("fwd", &mha_fwd, "Forward pass");
m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
m.def("bwd", &mha_bwd, "Backward pass");
m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
}
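These declarations mirror the upstream FA2 interface, and the pybind module simply re-exports them. Purely as an illustration, a direct call to the bound `fwd` op could look like the sketch below; it assumes the compiled extension is importable as `flash_attn_2_cuda` (the name used by the upstream Python package, not shown in this diff), and the argument order follows the `mha_fwd` declaration above.
```python
# Hypothetical direct call to the bound fwd op; normally the high-level
# flash_attn Python wrappers are used instead.
import torch
import flash_attn_2_cuda as fa  # assumed extension name

q = torch.randn(2, 512, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(2, 512, 8, 64, dtype=torch.float16, device="cuda")
v = torch.randn(2, 512, 8, 64, dtype=torch.float16, device="cuda")

# Arguments follow mha_fwd: q, k, v, out_, alibi_slopes_, p_dropout,
# softmax_scale, is_causal, window_size_left, window_size_right, softcap,
# return_softmax, gen_.
results = fa.fwd(q, k, v, None, None, 0.0, q.shape[-1] ** -0.5,
                 True, -1, -1, 0.0, False, None)
out, softmax_lse = results[0], results[5]
```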

@@ -0,0 +1,38 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
namespace flash {
// Copy from PyTorch
// https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17
static std::tuple<uint64_t, uint64_t> unpack(at::PhiloxCudaState arg) {
if (arg.captured_) {
// static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
// *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
// For most threads' reads it will hit in cache, so it shouldn't hurt performance.
return std::make_tuple(static_cast<uint64_t>(*arg.seed_.ptr), static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
} else {
return std::make_tuple(arg.seed_.val, arg.offset_.val);
}
}
} // namespace flash

@@ -0,0 +1,379 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#include "flash_common.hpp"
#include "fmha_bwd.hpp"
#include "mask.hpp"
fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
std::string dtype,
int head_size,
bool has_dropout,
bool enable_alibi)
{
return fmha_bwd_traits{head_size,
head_size,
dtype,
false, // is_group_mode
mask.type,
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
false, // has_dbias
has_dropout};
}
fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
// sizes
const int b,
const int seqlen_q,
const int seqlen_k,
const int h,
const int h_k,
const int hdim,
// device pointers
const at::Tensor q,
const at::Tensor k,
const at::Tensor v,
c10::optional<at::Tensor> &alibi_slopes_,
const at::Tensor out,
const at::Tensor softmax_lse,
const at::Tensor dout,
at::Tensor d,
at::Tensor dq,
at::Tensor dk,
at::Tensor dv,
float softmax_scale,
float p_dropout,
uint64_t drop_seed,
uint64_t drop_offset)
{
// q: (batch_size, seqlen_q, nheads, hdim)
// k: (batch_size, seqlen_k, nheads_k, hdim)
// v: (batch_size, seqlen_k, nheads_k, hdim)
// o: (batch_size, seqlen_q, nheads, hdim)
// dq: (batch_size, seqlen_q, nheads, hdim)
// dk_expanded: (batch_size, seqlen_k, nheads, hdim)
// dv_expanded: (batch_size, seqlen_k, nheads, hdim)
// do: (batch_size, seqlen_q, nheads, hdim)
// alibi_slopes:(batch_size, nheads) or (nhead)
// lse: (batch_size, nheads, seqlen_q)
// d: (batch_size, nheads, seqlen_q)
ck_tile::index_t stride_q = q.stride(1);
ck_tile::index_t stride_k = k.stride(1);
ck_tile::index_t stride_v = v.stride(1);
ck_tile::index_t stride_o = out.stride(1);
ck_tile::index_t stride_do = dout.stride(1);
ck_tile::index_t stride_dk = dk.stride(1);
ck_tile::index_t stride_dv = dv.stride(1);
ck_tile::index_t nhead_stride_q = q.stride(2);
ck_tile::index_t nhead_stride_k = k.stride(2);
ck_tile::index_t nhead_stride_v = v.stride(2);
ck_tile::index_t nhead_stride_o = out.stride(2);
ck_tile::index_t nhead_stride_do = dout.stride(2);
ck_tile::index_t nhead_stride_lse = softmax_lse.stride(1);
ck_tile::index_t batch_stride_q = q.stride(0);
ck_tile::index_t batch_stride_k = k.stride(0);
ck_tile::index_t batch_stride_v = v.stride(0);
ck_tile::index_t batch_stride_o = out.stride(0);
ck_tile::index_t batch_stride_do = dout.stride(0);
ck_tile::index_t batch_stride_lse = softmax_lse.stride(0);
ck_tile::index_t batch_stride_dk = dk.stride(0);
ck_tile::index_t batch_stride_dv = dv.stride(0);
float p_undrop = 1.0 - p_dropout;
void *alibi_slopes_ptr = nullptr;
ck_tile::index_t stride_alibi_slopes = 0;
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
alibi_slopes_ptr = alibi_slopes.data_ptr();
stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
}
return fmha_bwd_args{q.data_ptr(),
k.data_ptr(),
v.data_ptr(),
alibi_slopes_ptr, // bias
out.data_ptr(),
softmax_lse.data_ptr(),
dout.data_ptr(),
d.data_ptr(),
nullptr, // rand_val
dq.data_ptr(),
dk.data_ptr(),
dv.data_ptr(),
nullptr, // dbias
nullptr, // seqstart_q
nullptr, // seqstart_k
nullptr, // seqlen_k_ptr
seqlen_q,
seqlen_k,
b,
seqlen_q, // max_seqlen_q
seqlen_k, // max_seqlen_k
hdim, // hdim_q
hdim, // hdim_v
h, // nhead
h_k, // nhead_k
softmax_scale,
stride_q,
stride_k,
stride_v,
stride_alibi_slopes,
stride_o,
0, // stride_randval
stride_do,
stride_dk,
stride_dv,
0, // stride_dbias, FA without bias
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
0, // nhead_stride_bias, FA without bias
nhead_stride_o,
0, // nhead_stride_randval
nhead_stride_do,
nhead_stride_lse,
0, // nhead_stride_dbias, FA without dbias
batch_stride_q,
batch_stride_k,
batch_stride_v,
0 , // batch_stride_bias, FA without bias
batch_stride_o,
0, // batch_stride_randval
batch_stride_do,
batch_stride_lse,
batch_stride_dk,
batch_stride_dv,
0 , // batch_stride_dbias, FA without dbias
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
p_dropout,
p_undrop,
false, // s_randval
{drop_seed, drop_offset}};
}
std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &softmax_lse, // b x h x seqlen_q
c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout, // probability to drop
const float softmax_scale,
const bool is_causal,
int window_size_left,
int window_size_right,
const float /*softcap*/,
const bool deterministic,
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state)
{
#ifdef FLASHATTENTION_DISABLE_BACKWARD
TORCH_CHECK(false, "This flash attention build does not support backward.");
#endif
if (is_causal) { window_size_right = 0; }
bool is_dropout = p_dropout > 0.0;
auto stream = at::cuda::getCurrentHIPStream().stream();
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
"FlashAttention only support fp16 and bf16 data type");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
const auto sizes = q.sizes();
const int batch_size = sizes[0];
const int seqlen_q = sizes[1];
const int num_heads = sizes[2];
const int head_size_og = dout.size(3); // unpadded hdim
const int head_size_8x = sizes[3];
const int seqlen_k = k.size(1);
const int num_heads_k = k.size(2);
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size_8x % 8 == 0, "head_size_8x should be a multiple of 8");
TORCH_CHECK(head_size_8x <= 128, "CK FlashAttention backward only supports head dimension at most 128");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
TORCH_CHECK(head_size_8x == round_multiple(head_size_og, 8), "head_size_8x must be head_size_og rounded to a multiple of 8");
if (window_size_left >= seqlen_k) { window_size_left = -1; }
if (window_size_right >= seqlen_k) { window_size_right = -1; }
mask_info mask;
if (is_causal) {
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // causal
}
else if (window_size_left == -1 && window_size_right == -1) {
mask = mask_info::decode("0", seqlen_q, seqlen_k); // no mask
}
else {
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local
}
// q, k, v, out had been padded in mha_fwd
// dq_, dk_, dv_ are also padded tensor
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_8x);
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_8x);
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_8x);
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_8x);
CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);
at::Tensor dq, dk, dv;
if (dq_.has_value()) {
dq = dq_.value();
TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
CHECK_DEVICE(dq);
TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size_8x);
} else {
dq = torch::empty_like(q);
}
if (dk_.has_value()) {
dk = dk_.value();
TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
CHECK_DEVICE(dk);
TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size_8x);
} else {
dk = torch::empty_like(k);
}
if (dv_.has_value()) {
dv = dv_.value();
TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
CHECK_DEVICE(dv);
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size_8x);
} else {
dv = torch::empty_like(v);
}
at::Tensor dout_padded;
if (head_size_og % 8 != 0) {
dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
} else {
dout_padded = dout;
}
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
// TODO - CK does not support dq_accum
at::Tensor dk_expanded, dv_expanded;
if (num_heads_k != num_heads) { // MQA / GQA
dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size_8x}, opts);
dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size_8x}, opts);
} else {
dk_expanded = dk;
dv_expanded = dv;
}
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
uint64_t drop_seed = 1, drop_offset = 0;
int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
if (rng_state.has_value()) {
uint64_t* d = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
drop_seed = d[0];
drop_offset = d[1];
} else if(is_dropout) {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
auto philox_args = gen->philox_cuda_state(counter_offset);
std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
}
if (seqlen_q > 0) {
ck_tile::stream_config stream_config{stream};
dq.zero_(); // ck use atomic operation on dq
auto traits =
get_ck_fmha_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value());
auto args =
get_ck_fmha_bwd_args(
mask,
batch_size,
seqlen_q,
seqlen_k,
num_heads,
num_heads_k,
head_size_8x,
q,
k,
v,
alibi_slopes_,
out,
softmax_lse,
dout_padded,
softmax_d,
dq,
dk_expanded,
dv_expanded,
softmax_scale,
p_dropout,
drop_seed,
drop_offset);
fmha_bwd(traits, args, stream_config);
} else {
// If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
dk_expanded.zero_();
dv_expanded.zero_();
softmax_d.zero_();
}
// For MQA/GQA we need to sum dK and dV across the groups
if (num_heads_k != num_heads) {
at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {3});
at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {3});
}
if (head_size_og % 8 != 0) {
dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
}
return { dq, dk, dv, softmax_d };
}
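With `mha_bwd` in place, gradients flow through the regular autograd path. A short sketch, assuming the standard `flash_attn_func` wrapper; note the backward pass on this backend is limited to head dimensions of at most 128.
```python
# Backward-pass sketch; assumes the standard flash_attn Python API.
import torch
from flash_attn import flash_attn_func

q = torch.randn(2, 512, 8, 128, dtype=torch.float16, device="cuda", requires_grad=True)
k = torch.randn(2, 512, 8, 128, dtype=torch.float16, device="cuda", requires_grad=True)
v = torch.randn(2, 512, 8, 128, dtype=torch.float16, device="cuda", requires_grad=True)

out = flash_attn_func(q, k, v, dropout_p=0.1, causal=True)
out.sum().backward()  # dispatches to mha_bwd under the hood
print(q.grad.shape, k.grad.shape, v.grad.shape)
```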

@@ -0,0 +1,348 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#include "flash_common.hpp"
#include "fmha_fwd.hpp"
#include "mask.hpp"
fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask,
std::string dtype,
int head_size,
bool has_dropout,
bool has_lse,
bool enable_alibi)
{
return fmha_fwd_traits{head_size,
head_size,
dtype,
false, // is_group_mode
true, // is_v_rowmajor
mask.type,
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
has_lse,
has_dropout,
false}; // do_fp8_static_quant
}
fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
bool has_dropout_randval,
const mask_info &mask,
// sizes
const int b,
const int seqlen_q,
const int seqlen_k,
const int h,
const int h_k,
const int d,
// device pointers
const at::Tensor q,
const at::Tensor k,
const at::Tensor v,
c10::optional<at::Tensor> &alibi_slopes_,
at::Tensor out,
at::Tensor softmax_lse,
at::Tensor dropout_randval,
float softmax_scale,
float p_dropout,
uint64_t drop_seed,
uint64_t drop_offset)
{
// q: (batch_size, seqlen_q, nheads, d)
// k: (batch_size, seqlen_k, nheads_k, d)
// v: (batch_size, seqlen_k, nheads_k, d)
// o: (batch_size, seqlen_q, nheads, d)
// alibi_slopes:(batch_size, nheads) or (nhead)
// lse: (batch_size, nheads, seqlen_q)
// randval: (batch_size, nheads, seqlen_q, seqlen_k)
ck_tile::index_t stride_q = q.stride(1);
ck_tile::index_t stride_k = k.stride(1);
ck_tile::index_t stride_v = v.stride(1);
ck_tile::index_t stride_o = out.stride(1);
ck_tile::index_t stride_randval = has_dropout_randval ? dropout_randval.stride(2) : 0;
ck_tile::index_t nhead_stride_q = q.stride(2);
ck_tile::index_t nhead_stride_k = k.stride(2);
ck_tile::index_t nhead_stride_v = v.stride(2);
ck_tile::index_t nhead_stride_o = out.stride(2);
ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(1) : 0;
ck_tile::index_t nhead_stride_randval = has_dropout_randval ? dropout_randval.stride(1) : 0;
ck_tile::index_t batch_stride_q = q.stride(0);
ck_tile::index_t batch_stride_k = k.stride(0);
ck_tile::index_t batch_stride_v = v.stride(0);
ck_tile::index_t batch_stride_o = out.stride(0);
ck_tile::index_t batch_stride_lse = has_lse ? softmax_lse.stride(0) : 0;
ck_tile::index_t batch_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0;
void *alibi_slopes_ptr = nullptr;
ck_tile::index_t stride_alibi_slopes = 0;
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
alibi_slopes_ptr = alibi_slopes.data_ptr();
stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
}
return fmha_fwd_args{q.data_ptr(),
k.data_ptr(),
v.data_ptr(),
alibi_slopes_ptr, // bias
has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
nullptr, // lse_acc
nullptr, // o_acc
has_lse ? softmax_lse.data_ptr() : nullptr,
out.data_ptr(),
nullptr, // seqstart_q
nullptr, // seqstart_k
nullptr,
seqlen_q,
seqlen_k,
b,
seqlen_q, // max_seqlen_q
d, // hdim_q
d, // hdim_v
h, // nhead
h_k, // nhead_k
1, // num_splits
softmax_scale, // scale_s
1, // scale_p
1, // scale_o
stride_q,
stride_k,
stride_v,
stride_alibi_slopes,
stride_randval,
0, // stride_o_acc,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
0, // nhead_stride_bias, FA without bias
nhead_stride_randval,
nhead_stride_lse,
0, // nhead_stride_lse_acc
0, // nhead_stride_o_acc
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
0, // batch_stride_bias, FA without bias
batch_stride_randval,
batch_stride_lse,
0, // batch_stride_lse_acc
0, // batch_stride_o_acc
batch_stride_o,
0, // split_stride_lse_acc
0, // split_stride_o_acc
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
p_dropout,
has_dropout_randval,
{drop_seed, drop_offset}};
}
std::vector<at::Tensor>
mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout,
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const float /*softcap*/,
const bool return_dropout_randval,
c10::optional<at::Generator> gen_)
{
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
"FlashAttention only support fp16 and bf16 data type");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
const auto sizes = q.sizes();
const int batch_size = sizes[0];
int seqlen_q = sizes[1];
int num_heads = sizes[2];
const int head_size_og = sizes[3];
const int seqlen_k = k.size(1);
const int num_heads_k = k.size(2);
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size_og <= 256, "CK only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (window_size_left >= seqlen_k) { window_size_left = -1; }
if (window_size_right >= seqlen_k) { window_size_right = -1; }
// causal=true is the same as causal=false in this case
if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
mask_info mask;
if (is_causal) {
// Causal is the special case where window_size_right == 0 and window_size_left < 0.
window_size_right = 0;
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // causal
}
else if (window_size_left == -1 && window_size_right == -1) {
mask = mask_info::decode("0", seqlen_q, seqlen_k); // no mask
}
else {
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local
}
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
const int ngroups = num_heads / num_heads_k;
if (seqlenq_ngroups_swapped) {
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
seqlen_q = ngroups;
num_heads = num_heads_k;
}
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
at::Tensor q_padded, k_padded, v_padded;
if (head_size_og % 8 != 0) {
q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
}
else {
q_padded = q;
k_padded = k;
v_padded = v;
}
at::Tensor out;
if (out_.has_value()) {
out = out_.value();
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og);
if (seqlenq_ngroups_swapped) {
out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
}
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
}
else {
out = torch::empty_like(q_padded);
}
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_8x = round_multiple(head_size_og, 8);
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
bool has_lse = true;
bool has_dropout = p_dropout > 0.0f;
at::Tensor softmax_lse;
// TODO - check gradient, only training require lse
softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(torch::kFloat32));
at::Tensor p;
if (return_dropout_randval) {
TORCH_CHECK(has_dropout, "return_dropout_randval require p_dropout > 0");
p = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts.dtype(torch::kUInt8));
}
uint64_t drop_seed = 1, drop_offset = 0;
int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
if (p_dropout > 0.0) {
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
auto philox_args = gen->philox_cuda_state(counter_offset);
std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
}
rng_state[0] = *(reinterpret_cast<int64_t*>(&drop_seed));
rng_state[1] = *(reinterpret_cast<int64_t*>(&drop_offset));
if (seqlen_k > 0) {
auto stream = at::cuda::getCurrentHIPStream().stream();
ck_tile::stream_config stream_config{stream};
auto traits =
get_ck_fmha_fwd_traits(mask, q_dtype_str, head_size_8x, has_dropout, has_lse, alibi_slopes_.has_value());
auto args =
get_ck_fmha_fwd_args(
has_lse,
return_dropout_randval,
mask,
batch_size,
seqlen_q,
seqlen_k,
num_heads,
num_heads_k,
head_size_8x,
q_padded,
k_padded,
v_padded,
alibi_slopes_,
out,
softmax_lse,
p,
softmax_scale,
p_dropout,
drop_seed,
drop_offset);
fmha_fwd(traits, args, stream_config);
}
else {
// If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
out.zero_();
softmax_lse.fill_(std::numeric_limits<float>::infinity());
}
at::Tensor out_padded = out;
if (head_size_og % 8 != 0) {
out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
if (out_.has_value()) { out_.value().copy_(out); }
}
if (seqlenq_ngroups_swapped) {
out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
out_padded = out_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
}
return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
}
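The `seqlenq_ngroups_swapped` path above reshapes a single-query MQA/GQA call so that the head groups are treated as a fake sequence dimension, giving the kernel more parallelism at decode time. A small PyTorch sketch of the equivalent tensor manipulation (shapes are illustrative only):
```python
# Illustration of the seqlen_q == 1 MQA/GQA transpose trick used in mha_fwd.
import torch

batch, nheads_k, ngroups, headdim = 4, 8, 4, 128
nheads = nheads_k * ngroups

# Decode-time query: (batch, seqlen_q=1, nheads, headdim)
q = torch.randn(batch, 1, nheads, headdim)

# Reinterpret the head groups as a sequence dimension:
# (batch, 1, nheads_k * ngroups, d) -> (batch, ngroups, nheads_k, d)
q_swapped = q.reshape(batch, nheads_k, ngroups, headdim).transpose(1, 2)
assert q_swapped.shape == (batch, ngroups, nheads_k, headdim)

# After the kernel runs, the output is transposed back:
# (batch, ngroups, nheads_k, d) -> (batch, 1, nheads, d)
out = q_swapped.transpose(1, 2).reshape(batch, 1, nheads_k * ngroups, headdim)
```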

@@ -0,0 +1,406 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#include "flash_common.hpp"
#include "fmha_bwd.hpp"
#include "mask.hpp"
fmha_bwd_traits get_ck_fmha_varlen_bwd_traits(const mask_info &mask,
std::string dtype,
int head_size,
bool has_dropout,
bool enable_alibi)
{
return fmha_bwd_traits{head_size,
head_size,
dtype,
true, // is_group_mode
mask.type,
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
false, // has_dbias
has_dropout};
}
fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
// sizes
const int b,
const int max_seqlen_q,
const int max_seqlen_k,
const int h,
const int h_k,
const int hdim,
// device pointers
const at::Tensor q,
const at::Tensor k,
const at::Tensor v,
const at::Tensor seqlens_q,
const at::Tensor seqlens_k,
c10::optional<at::Tensor> &alibi_slopes_,
const at::Tensor out,
const at::Tensor softmax_lse,
const at::Tensor dout,
at::Tensor d,
at::Tensor dq,
at::Tensor dk,
at::Tensor dv,
float softmax_scale,
float p_dropout,
uint64_t drop_seed,
uint64_t drop_offset)
{
// q: (total_q, nheads, hdim)
// k: (total_k, nheads_k, hdim)
// v: (total_k, nheads_k, hdim)
// o: (total_q, nheads, hdim)
// dq: (total_q, nheads, hdim)
// dk_expanded: (total_k, nheads, hdim)
// dv_expanded: (total_k, nheads, hdim)
// do: (total_q, nheads, hdim)
// alibi_slopes:(batch_size, nheads) or (nhead)
// lse: (batch_size, nheads, max_seqlen_q)
// d: (batch_size, nheads, max_seqlen_q)
ck_tile::index_t total_q = q.size(0);
ck_tile::index_t total_k = k.size(0);
ck_tile::index_t stride_q = q.stride(0);
ck_tile::index_t stride_k = k.stride(0);
ck_tile::index_t stride_v = v.stride(0);
ck_tile::index_t stride_o = out.stride(0);
ck_tile::index_t stride_do = dout.stride(0);
ck_tile::index_t stride_dk = dk.stride(0);
ck_tile::index_t stride_dv = dv.stride(0);
ck_tile::index_t nhead_stride_q = q.stride(1);
ck_tile::index_t nhead_stride_k = k.stride(1);
ck_tile::index_t nhead_stride_v = v.stride(1);
ck_tile::index_t nhead_stride_o = out.stride(1);
ck_tile::index_t nhead_stride_do = dout.stride(1);
ck_tile::index_t nhead_stride_lse = softmax_lse.stride(1);
ck_tile::index_t batch_stride_q = 0;
ck_tile::index_t batch_stride_k = 0;
ck_tile::index_t batch_stride_v = 0;
ck_tile::index_t batch_stride_o = 0;
ck_tile::index_t batch_stride_do = 0;
ck_tile::index_t batch_stride_lse = softmax_lse.stride(0);
ck_tile::index_t batch_stride_dk = 0;
ck_tile::index_t batch_stride_dv = 0;
float p_undrop = 1.0 - p_dropout;
void *alibi_slopes_ptr = nullptr;
ck_tile::index_t stride_alibi_slopes = 0;
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
alibi_slopes_ptr = alibi_slopes.data_ptr();
stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
}
return fmha_bwd_args{q.data_ptr(),
k.data_ptr(),
v.data_ptr(),
alibi_slopes_ptr, // bias
out.data_ptr(),
softmax_lse.data_ptr(),
dout.data_ptr(),
d.data_ptr(),
nullptr, // rand_val
dq.data_ptr(),
dk.data_ptr(),
dv.data_ptr(),
nullptr, // dbias
seqlens_q.data_ptr(), // seqstart_q
seqlens_k.data_ptr(), // seqstart_k
nullptr, // seqlen_k_ptr
total_q,
total_k,
b,
max_seqlen_q, // max_seqlen_q
max_seqlen_k, // max_seqlen_k
hdim, // hdim_q
hdim, // hdim_v
h, // nhead
h_k, // nhead_k
softmax_scale,
stride_q,
stride_k,
stride_v,
stride_alibi_slopes,
stride_o,
0, // stride_randval
stride_do,
stride_dk,
stride_dv,
0, // stride_dbias, FA without bias
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
0, // nhead_stride_bias, FA without bias
nhead_stride_o,
0, // nhead_stride_randval
nhead_stride_do,
nhead_stride_lse,
0, // nhead_stride_dbias, FA without dbias
batch_stride_q,
batch_stride_k,
batch_stride_v,
0 , // batch_stride_bias, FA without bias
batch_stride_o,
0, // batch_stride_randval
batch_stride_do,
batch_stride_lse,
batch_stride_dk,
batch_stride_dv,
0 , // batch_stride_dbias, FA without dbias
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
p_dropout,
p_undrop,
false, // s_randval
{drop_seed, drop_offset}};
}
std::vector<at::Tensor>
mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads x head_size
const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &out, // total_q x num_heads x head_size
const at::Tensor &softmax_lse, // b x h x s softmax logsumexp
c10::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
c10::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
const int max_seqlen_q,
const int max_seqlen_k, // max sequence length to choose the kernel
const float p_dropout, // probability to drop
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
int window_size_left,
int window_size_right,
const float /*softcap*/,
const bool deterministic,
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state)
{
#ifdef FLASHATTENTION_DISABLE_BACKWARD
TORCH_CHECK(false, "This flash attention build does not support backward.");
#endif
if (is_causal) { window_size_right = 0; }
bool is_dropout = p_dropout > 0.0;
auto stream = at::cuda::getCurrentCUDAStream().stream();
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
"FlashAttention only support fp16 and bf16 data type");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
CHECK_CONTIGUOUS(cu_seqlens_q);
CHECK_CONTIGUOUS(cu_seqlens_k);
const auto sizes = q.sizes();
const int total_q = sizes[0];
const int batch_size = cu_seqlens_q.numel() - 1;
const int num_heads = sizes[1];
const int head_size_og = dout.size(2);
const int head_size_8x = sizes[2];
const int total_k = k.size(0);
const int num_heads_k = k.size(1);
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size_8x % 8 == 0, "head_size should be a multiple of 8");
TORCH_CHECK(head_size_8x <= 128, "CK FlashAttention backward only supports head dimension at most 128");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
TORCH_CHECK(head_size_8x == round_multiple(head_size_og, 8), "head_size_8x must be head_size_og rounded to a multiple of 8");
if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
mask_info mask;
if (is_causal) {
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // causal
}
else if (window_size_left == -1 && window_size_right == -1) {
mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask
}
else {
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local
}
// q, k, v, out had been padded in mha_fwd
// dq_, dk_, dv_ are also padded tensor
CHECK_SHAPE(q, total_q, num_heads, head_size_8x);
CHECK_SHAPE(k, total_k, num_heads_k, head_size_8x);
CHECK_SHAPE(v, total_k, num_heads_k, head_size_8x);
CHECK_SHAPE(out, total_q, num_heads, head_size_8x);
CHECK_SHAPE(dout, total_q, num_heads, head_size_og);
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
at::Tensor dq, dk, dv;
if (dq_.has_value()) {
dq = dq_.value();
TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
CHECK_DEVICE(dq);
TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
CHECK_SHAPE(dq, total_q, num_heads, head_size_8x);
} else {
dq = torch::empty_like(q);
}
if (dk_.has_value()) {
dk = dk_.value();
TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
CHECK_DEVICE(dk);
TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
CHECK_SHAPE(dk, total_k, num_heads_k, head_size_8x);
} else {
dk = torch::empty_like(k);
}
if (dv_.has_value()) {
dv = dv_.value();
TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
CHECK_DEVICE(dv);
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
CHECK_SHAPE(dv, total_k, num_heads_k, head_size_8x);
} else {
dv = torch::empty_like(v);
}
at::Tensor dout_padded;
if (head_size_og % 8 != 0) {
dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
} else {
dout_padded = dout;
}
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
auto softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
// TODO - CK does not support dq_accum
at::Tensor dk_expanded, dv_expanded;
if (num_heads_k != num_heads) { // MQA / GQA
dk_expanded = torch::empty({total_k, num_heads, head_size_8x}, opts);
dv_expanded = torch::empty({total_k, num_heads, head_size_8x}, opts);
} else {
dk_expanded = dk;
dv_expanded = dv;
}
if(zero_tensors) {
dq.zero_();
dk_expanded.zero_();
dv_expanded.zero_();
softmax_d.zero_();
}
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
uint64_t drop_seed = 1, drop_offset = 0;
int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
if (rng_state.has_value()) {
uint64_t* d = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
drop_seed = d[0];
drop_offset = d[1];
} else if(is_dropout) {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
auto philox_args = gen->philox_cuda_state(counter_offset);
std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
}
if (max_seqlen_q > 0) {
ck_tile::stream_config stream_config{stream};
dq.zero_(); // ck use atomic operation on dq
auto traits =
get_ck_fmha_varlen_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value());
auto args =
get_ck_fmha_varlen_bwd_args(
mask,
batch_size,
max_seqlen_q,
max_seqlen_k,
num_heads,
num_heads_k,
head_size_8x,
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
alibi_slopes_,
out,
softmax_lse,
dout_padded,
softmax_d,
dq,
dk_expanded,
dv_expanded,
softmax_scale,
p_dropout,
drop_seed,
drop_offset);
fmha_bwd(traits, args, stream_config);
} else {
// If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
dk_expanded.zero_();
dv_expanded.zero_();
softmax_d.zero_();
}
// For MQA/GQA we need to sum dK and dV across the groups
if (num_heads_k != num_heads) {
at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {2});
at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {2});
}
if (head_size_og % 8 != 0) {
dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
}
return { dq, dk, dv, softmax_d };
}
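For MQA/GQA the kernel writes `dk`/`dv` per query head and the wrapper then sums over the group dimension (the `at::sum_out` calls above). The equivalent reduction expressed in PyTorch, as an illustration:
```python
# Illustration of the MQA/GQA dK reduction done after the varlen backward kernel.
import torch

total_k, num_heads_k, ngroups, headdim = 2048, 4, 8, 128
num_heads = num_heads_k * ngroups

# Kernel output: one dK slice per query head.
dk_expanded = torch.randn(total_k, num_heads, headdim)

# Sum across the query-head groups that share each KV head.
dk = dk_expanded.reshape(total_k, num_heads_k, ngroups, headdim).sum(dim=2)
assert dk.shape == (total_k, num_heads_k, headdim)
```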

@@ -0,0 +1,371 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#include "flash_common.hpp"
#include "fmha_fwd.hpp"
#include "mask.hpp"
fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask,
std::string dtype,
int head_size,
bool has_dropout,
bool has_lse,
bool enable_alibi)
{
return fmha_fwd_traits{head_size,
head_size,
dtype,
true, // is_group_mode
true, // is_v_rowmajor
mask.type,
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
has_lse,
has_dropout,
false}; // do_fp8_static_quant
}
fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
bool has_dropout_randval,
const mask_info &mask,
// sizes
const int b,
const int max_seqlen_q,
const int h,
const int h_k,
const int d,
// device pointers
const at::Tensor q,
const at::Tensor k,
const at::Tensor v,
const at::Tensor seqlens_q,
const at::Tensor seqlens_k,
c10::optional<at::Tensor> &alibi_slopes_,
at::Tensor out,
at::Tensor softmax_lse,
at::Tensor dropout_randval,
float softmax_scale,
float p_dropout,
uint64_t drop_seed,
uint64_t drop_offset)
{
// q: (total_q, nheads, d)
// k: (total_k, nheads_k, d)
// v: (total_k, nheads_k, d)
// o: (total_q, nheads, d)
// alibi_slopes:(batch, nheads) or (nhead)
// lse: (batch, nheads, max_seqlen_q)
// randval: (nheads, total_q, max_seqlen_k)
ck_tile::index_t total_q = q.size(0);
ck_tile::index_t total_k = k.size(0);
ck_tile::index_t stride_q = q.stride(0);
ck_tile::index_t stride_k = k.stride(0);
ck_tile::index_t stride_v = v.stride(0);
ck_tile::index_t stride_o = out.stride(0);
ck_tile::index_t stride_randval = has_dropout_randval ? dropout_randval.stride(1) : 0;
ck_tile::index_t nhead_stride_q = q.stride(1);
ck_tile::index_t nhead_stride_k = k.stride(1);
ck_tile::index_t nhead_stride_v = v.stride(1);
ck_tile::index_t nhead_stride_o = out.stride(1);
ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(1) : 0;
ck_tile::index_t nhead_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0;
ck_tile::index_t batch_stride_q = 0;
ck_tile::index_t batch_stride_k = 0;
ck_tile::index_t batch_stride_v = 0;
ck_tile::index_t batch_stride_o = 0;
ck_tile::index_t batch_stride_lse = has_lse ? softmax_lse.stride(0) : 0;
ck_tile::index_t batch_stride_randval = 0;
void *alibi_slopes_ptr = nullptr;
ck_tile::index_t stride_alibi_slopes = 0;
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
alibi_slopes_ptr = alibi_slopes.data_ptr();
stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
}
return fmha_fwd_args{q.data_ptr(),
k.data_ptr(),
v.data_ptr(),
alibi_slopes_ptr, // bias
has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
nullptr, // lse_acc
nullptr, // o_acc
has_lse ? softmax_lse.data_ptr() : nullptr,
out.data_ptr(),
seqlens_q.data_ptr(), // seqstart_q
seqlens_k.data_ptr(), // seqstart_k
nullptr, // seqlen_kpads
total_q,
total_k,
b,
max_seqlen_q,
d, // hdim_q
d, // hdim_v
h, // nhead
h_k, // nhead_k
1, // num_splits
softmax_scale, // scale_s
1, // scale_p
1, // scale_o
stride_q,
stride_k,
stride_v,
stride_alibi_slopes,
stride_randval,
0, // stride_o_acc,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
0, // nhead_stride_bias, FA without bias
nhead_stride_randval,
nhead_stride_lse,
0, // nhead_stride_lse_acc
0, // nhead_stride_o_acc
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
0, // batch_stride_bias, FA without bias
batch_stride_randval,
batch_stride_lse,
0, // batch_stride_lse_acc
0, // batch_stride_o_acc
batch_stride_o,
0, // split_stride_lse_acc
0, // split_stride_o_acc
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
p_dropout,
has_dropout_randval,
{drop_seed, drop_offset}};
}
std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> & /*seqused_k*/,
c10::optional<const at::Tensor> &/*leftpad_k_*/, // batch_size
c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
const float softmax_scale,
const bool zero_tensors,
bool is_causal,
int window_size_left,
int window_size_right,
const float /*softcap*/,
const bool return_dropout_randval,
c10::optional<at::Generator> gen_)
{
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
"FlashAttention only support fp16 and bf16 data type");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
CHECK_DEVICE(cu_seqlens_q);
CHECK_DEVICE(cu_seqlens_k);
// TODO - Support paged_KV
const bool paged_KV = block_table_.has_value();
TORCH_CHECK(!paged_KV, "CK does not support paged_KV yet");
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
CHECK_CONTIGUOUS(cu_seqlens_q);
CHECK_CONTIGUOUS(cu_seqlens_k);
const auto sizes = q.sizes();
const int batch_size = cu_seqlens_q.numel() - 1;
int num_heads = sizes[1];
const int head_size_og = sizes[2];
const int num_heads_k = k.size(1);
const int max_num_blocks_per_seq = 0;
const int num_blocks = 0;
if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case
// TODO
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
const int total_q = q.size(0);
const int total_k = k.size(0);
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size_og <= 256, "CK only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
mask_info mask;
if (is_causal) {
// Causal is the special case where window_size_right == 0 and window_size_left < 0.
window_size_right = 0;
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // causal
}
else if (window_size_left == -1 && window_size_right == -1) {
mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask
}
else {
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local
}
CHECK_SHAPE(q, total_q, num_heads, head_size_og);
CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
at::Tensor q_padded, k_padded, v_padded;
if (head_size_og % 8 != 0) {
q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
}
else {
q_padded = q;
k_padded = k;
v_padded = v;
}
at::Tensor out;
if (out_.has_value()) {
out = out_.value();
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
CHECK_SHAPE(out, total_q, num_heads, head_size_og);
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
}
else {
out = torch::empty_like(q_padded);
}
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_8x = round_multiple(head_size_og, 8);
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
bool has_lse = true;
bool has_dropout = p_dropout > 0.0f;
at::Tensor softmax_lse;
// TODO - check gradient, only training require lse
softmax_lse = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(torch::kFloat32));
at::Tensor p;
if (return_dropout_randval) {
TORCH_CHECK(has_dropout, "return_dropout_randval require p_dropout > 0");
p = torch::empty({num_heads, total_q, max_seqlen_k}, opts.dtype(torch::kUInt8));
}
if (zero_tensors)
{
out.zero_();
softmax_lse.fill_(-std::numeric_limits<float>::infinity());
if (return_dropout_randval) {p.zero_();}
}
uint64_t drop_seed = 1, drop_offset = 0;
int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
if (p_dropout > 0.0) {
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
auto philox_args = gen->philox_cuda_state(counter_offset);
std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
}
rng_state[0] = *(reinterpret_cast<int64_t*>(&drop_seed));
rng_state[1] = *(reinterpret_cast<int64_t*>(&drop_offset));
if (max_seqlen_k > 0) {
auto stream = at::cuda::getCurrentHIPStream().stream();
ck_tile::stream_config stream_config{stream};
auto traits =
get_ck_fmha_varlen_fwd_traits(mask, q_dtype_str, head_size_8x, has_dropout, has_lse, alibi_slopes_.has_value());
auto args =
get_ck_fmha_varlen_fwd_args(
has_lse,
return_dropout_randval,
mask,
batch_size,
max_seqlen_q,
num_heads,
num_heads_k,
head_size_8x,
q_padded,
k_padded,
v_padded,
cu_seqlens_q,
cu_seqlens_k,
alibi_slopes_,
out,
softmax_lse,
p,
softmax_scale,
p_dropout,
drop_seed,
drop_offset);
fmha_fwd(traits, args, stream_config);
}
else {
// If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
out.zero_();
softmax_lse.fill_(std::numeric_limits<float>::infinity());
}
at::Tensor out_padded = out;
if (head_size_og % 8 != 0) {
out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
if (out_.has_value()) { out_.value().copy_(out); }
}
return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
}
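For orientation, here is a minimal sketch (not part of the diff) of how this varlen forward is typically reached from the Python side through `flash_attn_varlen_func`; the tensor shapes mirror the CHECK_SHAPE calls above, and the device/dtype values are illustrative assumptions. The C++ entry point pads the head dimension to a multiple of 8 internally, so head sizes that are not multiples of 8 can still be passed here.

```python
import torch
from flash_attn import flash_attn_varlen_func

# Two variable-length sequences packed along one "total" dimension.
seqlens = [5, 7]
total, nheads, d = sum(seqlens), 4, 64
cu_seqlens = torch.tensor([0, 5, 12], device="cuda", dtype=torch.int32)  # (batch_size + 1,)

q = torch.randn(total, nheads, d, device="cuda", dtype=torch.float16)
k = torch.randn(total, nheads, d, device="cuda", dtype=torch.float16)
v = torch.randn(total, nheads, d, device="cuda", dtype=torch.float16)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens), max_seqlen_k=max(seqlens),
    dropout_p=0.0, causal=True,
)
print(out.shape)  # torch.Size([12, 4, 64]) == (total, nheads, d)
```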

setup.py

@ -5,6 +5,8 @@ import warnings
import os
import re
import ast
import glob
import shutil
from pathlib import Path
from packaging.version import parse, Version
import platform
@ -22,6 +24,8 @@ from torch.utils.cpp_extension import (
CppExtension,
CUDAExtension,
CUDA_HOME,
ROCM_HOME,
IS_HIP_EXTENSION,
)
@ -32,6 +36,19 @@ with open("README.md", "r", encoding="utf-8") as fh:
# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))
BUILD_TARGET = os.environ.get("BUILD_TARGET", "auto")
if BUILD_TARGET == "auto":
if IS_HIP_EXTENSION:
IS_ROCM = True
else:
IS_ROCM = False
else:
if BUILD_TARGET == "cuda":
IS_ROCM = False
elif BUILD_TARGET == "rocm":
IS_ROCM = True
PACKAGE_NAME = "flash_attn"
BASE_WHEEL_URL = (
@ -82,19 +99,47 @@ def check_if_cuda_home_none(global_option: str) -> None:
)
def check_if_rocm_home_none(global_option: str) -> None:
if ROCM_HOME is not None:
return
# warn instead of error because the user could be downloading prebuilt wheels, so hipcc won't be necessary
# in that case.
warnings.warn(
f"{global_option} was requested, but hipcc was not found."
)
def append_nvcc_threads(nvcc_extra_args):
nvcc_threads = os.getenv("NVCC_THREADS") or "4"
return nvcc_extra_args + ["--threads", nvcc_threads]
def rename_cpp_to_cu(cpp_files):
for entry in cpp_files:
shutil.copy(entry, os.path.splitext(entry)[0] + ".cu")
def validate_and_update_archs(archs):
# List of allowed architectures
allowed_archs = ["native", "gfx90a", "gfx940", "gfx941", "gfx942"]
# Validate if each element in archs is in allowed_archs
assert all(
arch in allowed_archs for arch in archs
), f"One of the GPU archs in {archs} is invalid or not supported by Flash-Attention"
cmdclass = {}
ext_modules = []
# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp
# files included in the source distribution, in case the user compiles from source.
subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
if IS_ROCM:
subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"])
else:
subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
if not SKIP_CUDA_BUILD:
if not SKIP_CUDA_BUILD and not IS_ROCM:
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
@ -250,6 +295,95 @@ if not SKIP_CUDA_BUILD:
],
)
)
elif not SKIP_CUDA_BUILD and IS_ROCM:
ck_dir = "csrc/composable_kernel"
# use codegen to generate the kernel dispatch code
if not os.path.exists("./build"):
os.makedirs("build")
os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd --output_dir build --receipt 2")
os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d bwd --output_dir build --receipt 2")
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
# Check if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
# See https://github.com/pytorch/pytorch/pull/70650
generator_flag = []
torch_dir = torch.__path__[0]
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
generator_flag = ["-DOLD_GENERATOR_PATH"]
check_if_rocm_home_none("flash_attn")
cc_flag = []
archs = os.getenv("GPU_ARCHS", "native").split(";")
validate_and_update_archs(archs)
cc_flag = [f"--offload-arch={arch}" for arch in archs]
# HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
# torch._C._GLIBCXX_USE_CXX11_ABI
# https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
if FORCE_CXX11_ABI:
torch._C._GLIBCXX_USE_CXX11_ABI = True
sources = ["csrc/flash_attn_ck/flash_api.cpp",
"csrc/flash_attn_ck/mha_bwd.cpp",
"csrc/flash_attn_ck/mha_fwd.cpp",
"csrc/flash_attn_ck/mha_varlen_bwd.cpp",
"csrc/flash_attn_ck/mha_varlen_fwd.cpp"] + glob.glob(
f"build/fmha_*wd*.cpp"
)
rename_cpp_to_cu(sources)
renamed_sources = ["csrc/flash_attn_ck/flash_api.cu",
"csrc/flash_attn_ck/mha_bwd.cu",
"csrc/flash_attn_ck/mha_fwd.cu",
"csrc/flash_attn_ck/mha_varlen_bwd.cu",
"csrc/flash_attn_ck/mha_varlen_fwd.cu"] + glob.glob(f"build/fmha_*wd*.cu")
extra_compile_args = {
"cxx": ["-O3", "-std=c++17"] + generator_flag,
"nvcc":
[
"-O3","-std=c++17",
"-mllvm", "-enable-post-misched=0",
"-DCK_TILE_FMHA_FWD_FAST_EXP2=1",
"-fgpu-flush-denormals-to-zero",
"-DCK_ENABLE_BF16",
"-DCK_ENABLE_BF8",
"-DCK_ENABLE_FP16",
"-DCK_ENABLE_FP32",
"-DCK_ENABLE_FP64",
"-DCK_ENABLE_FP8",
"-DCK_ENABLE_INT8",
"-DCK_USE_XDL",
"-DUSE_PROF_API=1",
"-D__HIP_PLATFORM_HCC__=1",
# "-DFLASHATTENTION_DISABLE_BACKWARD",
]
+ generator_flag
+ cc_flag
,
}
include_dirs = [
Path(this_dir) / "csrc" / "composable_kernel" / "include",
Path(this_dir) / "csrc" / "composable_kernel" / "library" / "include",
Path(this_dir) / "csrc" / "composable_kernel" / "example" / "ck_tile" / "01_fmha",
]
ext_modules.append(
CUDAExtension(
name="flash_attn_2_cuda",
sources=renamed_sources,
extra_compile_args=extra_compile_args,
include_dirs=include_dirs,
)
)
def get_package_version():
@ -264,25 +398,33 @@ def get_package_version():
def get_wheel_url():
# Determine the version numbers that will be used to determine the correct wheel
# We're using the CUDA version used to build torch, not the one currently installed
# _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
torch_cuda_version = parse(torch.version.cuda)
torch_version_raw = parse(torch.__version__)
# For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
# to save CI time. Minor versions should be compatible.
torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3")
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
platform_name = get_platform()
flash_version = get_package_version()
# cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
# Determine wheel URL based on CUDA version, torch version, python version and OS
wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
if IS_ROCM:
torch_hip_version = parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+'))
hip_version = f"{torch_hip_version.major}{torch_hip_version.minor}"
wheel_filename = f"{PACKAGE_NAME}-{flash_version}+rocm{hip_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
else:
# Determine the version numbers that will be used to determine the correct wheel
# We're using the CUDA version used to build torch, not the one currently installed
# _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
torch_cuda_version = parse(torch.version.cuda)
# For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
# to save CI time. Minor versions should be compatible.
torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3")
# cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
# Determine wheel URL based on CUDA version, torch version, python version and OS
wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
return wheel_url, wheel_filename
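As a quick illustration of the arch-selection logic in the ROCm build path above (a sketch with example values, not part of the diff): GPU_ARCHS is split on ';', validated against the allowed gfx targets, and expanded into --offload-arch flags.

```python
import os

# Example: restrict the build to MI200 (gfx90a) and MI300 (gfx942).
os.environ["GPU_ARCHS"] = "gfx90a;gfx942"

archs = os.getenv("GPU_ARCHS", "native").split(";")
allowed_archs = ["native", "gfx90a", "gfx940", "gfx941", "gfx942"]
assert all(arch in allowed_archs for arch in archs)

cc_flag = [f"--offload-arch={arch}" for arch in archs]
print(cc_flag)  # ['--offload-arch=gfx90a', '--offload-arch=gfx942']
```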

tests/test_flash_attn_ck.py Normal file

@ -0,0 +1,754 @@
import math
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn import (
flash_attn_func,
flash_attn_kvpacked_func,
flash_attn_qkvpacked_func,
flash_attn_varlen_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
)
from test_flash_attn import (
attn_bias_from_alibi_slopes,
convert_flash_attn_S_to_softmax,
generate_qkv,
generate_random_padding_mask,
attention_ref,
attention_kvpacked_ref,
attention_qkvpacked_ref,
)
def is_bwd_hdim_supported(d):
return d <= 128 and d % 2 == 0
def ck_randval_to_dropout_mask(randval, p):
# For example, if p = 0.3, randval in 255 * (0.7, 1.0] will be dropped,
# while randval in 255 * [0, 0.7] will be kept.
# If the returned value is >= 0, the element is kept.
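# Worked example (illustrative): with p = 0.3 the keep threshold is 255 * 0.7 = 178.5,
# so randval = 200 gives floor(178.5 - 200) = -22 (dropped) while randval = 100 gives
# floor(178.5 - 100) = 78 (kept).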
return torch.floor(255.0 * (1 - p) - randval)
def pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens_q, seqlen_q_rounded, seqlen_k_rounded):
""" pad + rearrange [nheads, total_q, max_seqlen_k] into [b, nheads, seqlen_q_rounded, seqlen_k_rounded]
Arguments:
S_dmask: (nheads, total_q, max_seqlen_k)
cu_seqlens_q: (b + 1)
Output:
S_dmask: (b, nheads, seqlen_q_rounded, seqlen_k_rounded)
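Shape example (illustrative values): cu_seqlens_q = [0, 3, 8], nheads = 3,
seqlen_q_rounded = seqlen_k_rounded = 8 -> the (3, 8, 8) input is split into
per-sequence chunks of lengths 3 and 5, each padded to (3, 8, 8), stacked to
(3, 2, 8, 8) and transposed to (2, 3, 8, 8) = (b, nheads, 8, 8).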
"""
batch_size = cu_seqlens_q.numel() - 1
seqlens_q = torch.roll(cu_seqlens_q, shifts = -1) - cu_seqlens_q
seqlens_q = seqlens_q[0:batch_size].tolist()
S_dmask = torch.split(S_dmask, seqlens_q, dim=1)
# [(nheads, seqlen_q0, max_seqlen_k), (nheads, seqlen_q1, max_seqlen_k), ..., (nheads, seqlen_qb, max_seqlen_k)]
masks = ()
for mask in S_dmask:
# (nheads, seqlen_qi, max_seqlen_k) -> (nheads, seqlen_q_rounded, seqlen_k_rounded)
mask = F.pad(mask, (0, seqlen_k_rounded - mask.shape[2], 0, seqlen_q_rounded - mask.shape[1], 0, 0)).unsqueeze(1)
masks = masks + (mask, )
S_dmask = torch.cat(masks, dim=1)
S_dmask = S_dmask.transpose(0, 1)
return S_dmask
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("deterministic", [False])
@pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
if d > 256:
pytest.skip()
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 4
nheads = 9
window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
qkv = torch.randn(
batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
)
if alibi:
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal)
else:
alibi_slopes, attn_bias = None, None
out, lse, S_dmask = flash_attn_qkvpacked_func(
qkv,
dropout_p,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)
if dropout_p > 0.0:
# TODO - move to c++ mha_varlen_fwd()
S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)
S_dmask_converted = convert_flash_attn_S_to_softmax(
S_dmask,
seqlen,
seqlen,
None,
None,
d,
dropout_p > 0.0,
causal=causal,
window_size=window_size,
)
dropout_mask = S_dmask_converted >= 0
# CK does not return P. Hence, we don't test the attn here.
else:
dropout_mask = None
out_ref, attn_ref = attention_qkvpacked_ref(
qkv, None, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size
)
out_pt, attn_pt = attention_qkvpacked_ref(
qkv,
None,
attn_bias,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
upcast=False,
reorder_ops=True,
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
g = torch.randn_like(out)
if is_bwd_hdim_supported(d):
(dqkv,) = torch.autograd.grad(out, qkv, g)
(dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
(dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")
# TODO - use a 10x tolerance for now; tighten it once CK changes the dq type to fp32
assert (dqkv - dqkv_ref).abs().max().item() <= 10 * (dqkv_pt - dqkv_ref).abs().max().item()
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("deterministic", [False])
@pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 257, 384, 512, 768, 1025, 2048])
@pytest.mark.parametrize("dropout_p", [0, 0.17])
def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
if d > 256:
pytest.skip()
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 5
nheads = 6
window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
qkv = torch.randn(
batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
)
key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random")
# key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
if alibi:
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
attn_bias = attn_bias_from_alibi_slopes(
alibi_slopes, seqlen, seqlen, key_padding_mask, key_padding_mask, causal=causal
)
else:
alibi_slopes, attn_bias = None, None
qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
*qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True
)
out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func(
qkv_unpad,
cu_seqlens,
max_seqlen,
dropout_p,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)
out = output_pad_fn(out_unpad)
if dropout_p > 0.0:
# TODO - move to c++ mha_varlen_fwd()
S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)
S_dmask = pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens, seqlen, seqlen)
S_dmask_converted = convert_flash_attn_S_to_softmax(
S_dmask,
seqlen,
seqlen,
key_padding_mask,
key_padding_mask,
d,
dropout_p > 0.0,
causal=causal,
window_size=window_size,
)
dropout_mask = S_dmask_converted >= 0
# CK does not return P. Hence, we don't test the attn here.
else:
dropout_mask = None
out_ref, attn_ref = attention_qkvpacked_ref(
qkv,
key_padding_mask,
attn_bias,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
)
out_pt, attn_pt = attention_qkvpacked_ref(
qkv,
key_padding_mask,
attn_bias,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
upcast=False,
reorder_ops=True,
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
g = torch.randn_like(out)
if is_bwd_hdim_supported(d):
(dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)
dqkv = dqkv_pad_fn(dqkv_unpad)
(dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
(dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")
# TODO - use a 10x tolerance for now; tighten it once CK changes the dq type to fp32
assert (dqkv - dqkv_ref).abs().max().item() <= 10 * (dqkv_pt - dqkv_ref).abs().max().item()
@pytest.mark.parametrize("kvpacked", [True, False])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("deterministic", [False])
@pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(113, 203),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(512, 256),
(1024, 1024),
(1023, 1024),
(1024, 1023),
(2048, 2048),
],
)
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
def test_flash_attn_output(
seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
):
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 4
nheads = 9
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
assert nheads % nheads_k == 0
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
if kvpacked:
kv = torch.randn(
batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
else:
k = torch.randn(
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
v = torch.randn(
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
if alibi:
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
else:
alibi_slopes, attn_bias = None, None
if kvpacked:
out, lse, S_dmask = flash_attn_kvpacked_func(
q,
kv,
dropout_p,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)
else:
out, lse, S_dmask = flash_attn_func(
q,
k,
v,
dropout_p,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)
if dropout_p > 0.0:
# TODO - move to c++ mha_varlen_fwd()
S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)
S_dmask_converted = convert_flash_attn_S_to_softmax(
S_dmask,
seqlen_q,
seqlen_k,
None,
None,
d,
dropout_p > 0.0,
causal=causal,
window_size=window_size,
)
dropout_mask = S_dmask_converted >= 0
if kvpacked:
kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
k_rep, v_rep = kv_rep.unbind(dim=2)
else:
k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
# CK does not return P. Hence, we don't test the attn here.
else:
dropout_mask = None
if kvpacked:
out_ref, attn_ref = attention_kvpacked_ref(
q,
kv,
None,
None,
attn_bias,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
)
out_pt, attn_pt = attention_kvpacked_ref(
q,
kv,
None,
None,
attn_bias,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
upcast=False,
reorder_ops=True,
)
else:
out_ref, attn_ref = attention_ref(
q,
k,
v,
None,
None,
attn_bias,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
)
out_pt, attn_pt = attention_ref(
q,
k,
v,
None,
None,
attn_bias,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
upcast=False,
reorder_ops=True,
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
g = torch.randn_like(out)
if is_bwd_hdim_supported(d):
if kvpacked:
(
dq,
dkv,
) = torch.autograd.grad(out, (q, kv), g)
dk, dv = dkv.unbind(2)
(
dq_ref,
dkv_ref,
) = torch.autograd.grad(out_ref, (q, kv), g)
dk_ref, dv_ref = dkv_ref.unbind(2)
(
dq_pt,
dkv_pt,
) = torch.autograd.grad(out_pt, (q, kv), g)
dk_pt, dv_pt = dkv_pt.unbind(2)
else:
(
dq,
dk,
dv,
) = torch.autograd.grad(out, (q, k, v), g)
(
dq_ref,
dk_ref,
dv_ref,
) = torch.autograd.grad(out_ref, (q, k, v), g)
(
dq_pt,
dk_pt,
dv_pt,
) = torch.autograd.grad(out_pt, (q, k, v), g)
print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
# TODO - use a 10x tolerance for now; tighten it once CK changes the dq type to fp32
assert (dq - dq_ref).abs().max().item() <= 10 * (dq_pt - dq_ref).abs().max().item()
assert (dk - dk_ref).abs().max().item() <= 10 * (dk_pt - dk_ref).abs().max().item()
assert (dv - dv_ref).abs().max().item() <= 10 * (dv_pt - dv_ref).abs().max().item()
@pytest.mark.parametrize("kvpacked", [True, False])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(1, 147),
(113, 203),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(512, 256),
(1024, 1024),
(1023, 1024),
(1024, 1023),
(2048, 2048),
],
)
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
def test_flash_attn_varlen_output(
seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
):
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 4
nheads = 9
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
assert nheads % nheads_k == 0
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
if kvpacked:
kv = torch.randn(
batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
else:
k = torch.randn(
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
v = torch.randn(
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
# key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
if alibi:
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
attn_bias = attn_bias_from_alibi_slopes(
alibi_slopes, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, causal=causal
)
else:
alibi_slopes, attn_bias = None, None
if kvpacked:
(
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
kv,
output_pad_fn,
dq_pad_fn,
dkv_pad_fn,
) = generate_qkv(q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True)
out_unpad, sm_lse, S_dmask = flash_attn_varlen_kvpacked_func(
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)
else:
(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
out_unpad, sm_lse, S_dmask = flash_attn_varlen_func(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)
out = output_pad_fn(out_unpad)
if dropout_p > 0.0:
# TODO - move to c++ mha_varlen_fwd()
S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)
S_dmask = pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens_q, seqlen_q, seqlen_k)
S_dmask_converted = convert_flash_attn_S_to_softmax(
S_dmask,
seqlen_q,
seqlen_k,
query_padding_mask,
key_padding_mask,
d,
dropout_p > 0.0,
causal=causal,
window_size=window_size,
)
dropout_mask = S_dmask_converted >= 0
if kvpacked:
kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
k_rep, v_rep = kv_rep.unbind(dim=2)
else:
k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
# CK does not return P. Hence, we don't test the attn here.
else:
dropout_mask = None
if kvpacked:
out_ref, attn_ref = attention_kvpacked_ref(
q,
kv,
query_padding_mask,
key_padding_mask,
attn_bias,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
)
out_pt, attn_pt = attention_kvpacked_ref(
q,
kv,
query_padding_mask,
key_padding_mask,
attn_bias,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
upcast=False,
reorder_ops=True,
)
else:
out_ref, attn_ref = attention_ref(
q,
k,
v,
query_padding_mask,
key_padding_mask,
attn_bias,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
)
out_pt, attn_pt = attention_ref(
q,
k,
v,
query_padding_mask,
key_padding_mask,
attn_bias,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
upcast=False,
reorder_ops=True,
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
# Check that FlashAttention's numerical error is at most 4 times the numerical error
# of a Pytorch implementation.
assert (out - out_ref).abs().max().item() <= 4 * (out_pt - out_ref).abs().max().item()
g = torch.randn_like(out)
if is_bwd_hdim_supported(d):
if kvpacked:
(
dq_unpad,
dkv_unpad,
) = torch.autograd.grad(out, (q_unpad, kv_unpad), g)
dk, dv = dkv_pad_fn(dkv_unpad).unbind(2)
(
dq_ref,
dkv_ref,
) = torch.autograd.grad(out_ref, (q, kv), g)
dk_ref, dv_ref = dkv_ref.unbind(2)
(
dq_pt,
dkv_pt,
) = torch.autograd.grad(out_pt, (q, kv), g)
dk_pt, dv_pt = dkv_pt.unbind(2)
else:
(
dq_unpad,
dk_unpad,
dv_unpad,
) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
dk = dk_pad_fn(dk_unpad)
dv = dk_pad_fn(dv_unpad)
(
dq_ref,
dk_ref,
dv_ref,
) = torch.autograd.grad(out_ref, (q, k, v), g)
(
dq_pt,
dk_pt,
dv_pt,
) = torch.autograd.grad(out_pt, (q, k, v), g)
dq = dq_pad_fn(dq_unpad)
print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
# TODO - use a 10x tolerance for now; tighten it once CK changes the dq type to fp32
assert (dq - dq_ref).abs().max().item() <= 10 * (dq_pt - dq_ref).abs().max().item()
assert (dk - dk_ref).abs().max().item() <= 10 * (dk_pt - dk_ref).abs().max().item()
assert (dv - dv_ref).abs().max().item() <= 10 * (dv_pt - dv_ref).abs().max().item()