[AMD ROCm] Fix KVcache bug and improve performance (#1328)

* update ck

* update ck

* update ck again

* update ck

* use pointer as seed and offset

* update CK

* Remove useless "else"

* Fix page-attn block table read out-of-bound

---------

Co-authored-by: Po Yen, Chen <PoYen.Chen@amd.com>
Authored by rocking on 2024-11-13 03:32:11 +08:00, committed by GitHub
parent 284e2c6e5b
commit 88d1657a14
6 changed files with 56 additions and 60 deletions

@@ -1 +1 @@
-Subproject commit a9b170b54195ab667ca814f80dd5dfbf4ad772f5
+Subproject commit 13332998a4ca6dcc8cc5fcd401ca900529e5e65c


@@ -22,17 +22,17 @@
 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 namespace flash {
-// Copy from PyTorch
-// https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17
-inline std::tuple<uint64_t, uint64_t> unpack(at::PhiloxCudaState arg) {
-    if (arg.captured_) {
-        // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
-        // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
-        // For most threads' reads it will hit in cache, so it shouldn't hurt performance.
-        return std::make_tuple(static_cast<uint64_t>(*arg.seed_.ptr), static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
-    } else {
-        return std::make_tuple(arg.seed_.val, arg.offset_.val);
-    }
+inline __global__ void ParsePhiloxCudaState(at::PhiloxCudaState arg, uint64_t* rng_state)
+{
+    // Imitate from PyTorch
+    // https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17
+    if (arg.captured_) {
+        rng_state[0] = static_cast<uint64_t>(*arg.seed_.ptr);
+        rng_state[1] = static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_);
+    } else {
+        rng_state[0] = arg.seed_.val;
+        rng_state[1] = arg.offset_.val;
+    }
 }
 inline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {

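Note on the hunk above: the host-side flash::unpack helper is replaced by a ParsePhiloxCudaState kernel, so when the Philox state is graph-captured (seed and offset live in device memory) the values are written into a small device rng_state buffer instead of being dereferenced on the host. The following standalone HIP sketch only illustrates that pattern; PhiloxStateSketch and parse_rng_state are made-up names, not repository code, and the real kernel takes at::PhiloxCudaState and is launched with dim3(1), dim3(64).

// sketch_parse_rng_state.cpp -- standalone illustration, NOT repository code.
// PhiloxStateSketch / parse_rng_state are invented names for this sketch.
// Build: hipcc sketch_parse_rng_state.cpp -o sketch_parse_rng_state
#include <hip/hip_runtime.h>
#include <cstdint>
#include <cstdio>

// Simplified stand-in for at::PhiloxCudaState: either plain values, or device
// pointers when the generator state was captured by a HIP/CUDA graph.
struct PhiloxStateSketch {
    bool captured_;
    uint64_t seed_val, offset_val;   // valid when captured_ == false
    const int64_t* seed_ptr;         // valid when captured_ == true
    const int64_t* offset_ptr;
    uint64_t offset_intragraph_;
};

// A tiny kernel writes seed/offset into a 2-element device buffer; later
// kernels read them through pointers instead of receiving host uint64_t values.
__global__ void parse_rng_state(PhiloxStateSketch arg, uint64_t* rng_state) {
    if (arg.captured_) {
        rng_state[0] = static_cast<uint64_t>(*arg.seed_ptr);
        rng_state[1] = static_cast<uint64_t>(*arg.offset_ptr) + arg.offset_intragraph_;
    } else {
        rng_state[0] = arg.seed_val;
        rng_state[1] = arg.offset_val;
    }
}

int main() {
    uint64_t* rng_state = nullptr;
    hipMalloc(&rng_state, 2 * sizeof(uint64_t));

    PhiloxStateSketch arg{false, /*seed=*/1234, /*offset=*/0, nullptr, nullptr, 0};
    hipLaunchKernelGGL(parse_rng_state, dim3(1), dim3(1), 0, 0, arg, rng_state);

    uint64_t host[2];
    hipMemcpy(host, rng_state, sizeof(host), hipMemcpyDeviceToHost);
    std::printf("seed=%llu offset=%llu\n",
                (unsigned long long)host[0], (unsigned long long)host[1]);
    hipFree(rng_state);
    return 0;
}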

@@ -49,8 +49,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
                                    at::Tensor dv,
                                    float softmax_scale,
                                    float p_dropout,
-                                   uint64_t drop_seed,
-                                   uint64_t drop_offset)
+                                   std::pair<uint64_t*, uint64_t*> drop_seed_offset)
 {
     // q: (batch_size, seqlen_q, nheads, hdim)
     ck_tile::index_t batch_stride_q = q.stride(0);
@@ -191,7 +190,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
                         static_cast<ck_tile::index_t>(mask.type),
                         p_dropout,
                         p_undrop,
-                        {drop_seed, drop_offset}};
+                        drop_seed_offset};
 }
 std::vector<at::Tensor>
@@ -213,7 +212,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
         const float /*softcap*/,
         const bool deterministic,
         c10::optional<at::Generator> gen_,
-        c10::optional<at::Tensor> &rng_state)
+        c10::optional<at::Tensor> &rng_state_)
 {
 #ifdef FLASHATTENTION_DISABLE_BACKWARD
     TORCH_CHECK(false, "This flash attention build does not support backward.");
@@ -337,21 +336,24 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());
-    uint64_t drop_seed = 1, drop_offset = 0;
     int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
+    at::Tensor rng_state;
-    if (rng_state.has_value()) {
-        uint64_t* d = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
-        drop_seed = d[0];
-        drop_offset = d[1];
+    if (rng_state_.has_value()) {
+        rng_state = rng_state_.value();
     } else if(is_dropout) {
+        rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
         auto philox_args = gen->philox_cuda_state(counter_offset);
-        std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
+        hipLaunchKernelGGL(
+            flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0,
+            philox_args, reinterpret_cast<uint64_t*>(rng_state.data_ptr()));
     }
     if (seqlen_q > 0) {
+        auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
         ck_tile::stream_config stream_config{stream};
         auto traits =
@@ -380,8 +382,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
                 dv_expanded,
                 softmax_scale,
                 p_dropout,
-                drop_seed,
-                drop_offset);
+                drop_seed_offset);
         float t = fmha_bwd(traits, args, stream_config);
         TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd");


@@ -46,8 +46,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
                                    at::Tensor dropout_randval,
                                    float softmax_scale,
                                    float p_dropout,
-                                   uint64_t drop_seed,
-                                   uint64_t drop_offset)
+                                   std::pair<uint64_t*, uint64_t*> drop_seed_offset)
 {
     // q: (batch_size, seqlen_q, nheads, d)
     // k: (batch_size, seqlen_k, nheads_k, d)
@@ -137,7 +136,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
                         static_cast<ck_tile::index_t>(mask.type),
                         p_dropout,
                         has_dropout_randval,
-                        {drop_seed, drop_offset}};
+                        drop_seed_offset};
 }
 std::vector<at::Tensor>
@@ -255,10 +254,9 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num
         p = torch::empty({ 0 }, opts);
     }
-    uint64_t drop_seed = 1, drop_offset = 0;
     int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
-    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
-    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+    auto rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
+    auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
     if (p_dropout > 0.0) {
         auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
@@ -266,13 +264,12 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
         auto philox_args = gen->philox_cuda_state(counter_offset);
-        std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
+        hipLaunchKernelGGL(
+            flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0, philox_args, rng_state_ptr);
     }
-    rng_state[0] = *(reinterpret_cast<int64_t*>(&drop_seed));
-    rng_state[1] = *(reinterpret_cast<int64_t*>(&drop_offset));
     if (seqlen_k > 0) {
+        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
         auto stream = at::cuda::getCurrentHIPStream().stream();
         ck_tile::stream_config stream_config{stream};
@@ -305,8 +302,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num
                 p,
                 softmax_scale,
                 p_dropout,
-                drop_seed,
-                drop_offset);
+                drop_seed_offset);
         float t = fmha_fwd(traits, args, stream_config);
         TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");

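Note on the forward path above: the CK argument struct now receives a std::pair of device pointers (rng_state_ptr, rng_state_ptr + 1) rather than copied uint64_t values, so the seed and offset are read from the rng_state buffer only when the kernel actually runs, and the same buffer can later be handed to the backward pass (which now takes an optional rng_state_). A minimal sketch of that pointer-pair contract; DropArgsSketch and read_drop_state are hypothetical names, not the real fmha_fwd_args.

// sketch_pointer_pair.cpp -- standalone illustration, NOT repository code.
// Build: hipcc sketch_pointer_pair.cpp -o sketch_pointer_pair
#include <hip/hip_runtime.h>
#include <cstdint>
#include <cstdio>
#include <utility>

struct DropArgsSketch {
    const uint64_t* seed_ptr;    // == rng_state_ptr
    const uint64_t* offset_ptr;  // == rng_state_ptr + 1
};

__global__ void read_drop_state(DropArgsSketch args, uint64_t* out) {
    // Values are read on the device at execution time, not when the
    // argument struct is built on the host.
    out[0] = *args.seed_ptr;
    out[1] = *args.offset_ptr;
}

int main() {
    uint64_t host_state[2] = {42, 7};   // pretend the forward pass stored these
    uint64_t *rng_state = nullptr, *out = nullptr;
    hipMalloc(&rng_state, 2 * sizeof(uint64_t));
    hipMalloc(&out, 2 * sizeof(uint64_t));
    hipMemcpy(rng_state, host_state, sizeof(host_state), hipMemcpyHostToDevice);

    // Same pointer arithmetic the forward/backward wrappers now use.
    auto drop_seed_offset = std::make_pair(rng_state, rng_state + 1);
    DropArgsSketch args{drop_seed_offset.first, drop_seed_offset.second};
    hipLaunchKernelGGL(read_drop_state, dim3(1), dim3(1), 0, 0, args, out);

    uint64_t result[2];
    hipMemcpy(result, out, sizeof(result), hipMemcpyDeviceToHost);
    std::printf("seed=%llu offset=%llu\n",
                (unsigned long long)result[0], (unsigned long long)result[1]);
    hipFree(rng_state);
    hipFree(out);
    return 0;
}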

@@ -51,8 +51,7 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
                                           at::Tensor dv,
                                           float softmax_scale,
                                           float p_dropout,
-                                          uint64_t drop_seed,
-                                          uint64_t drop_offset)
+                                          std::pair<uint64_t*, uint64_t*> drop_seed_offset)
 {
     ck_tile::index_t total_q = q.size(0);
     ck_tile::index_t total_k = k.size(0);
@@ -197,7 +196,7 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
                         static_cast<ck_tile::index_t>(mask.type),
                         p_dropout,
                         p_undrop,
-                        {drop_seed, drop_offset}};
+                        drop_seed_offset};
 }
 std::vector<at::Tensor>
@@ -224,7 +223,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
                const float /*softcap*/,
                const bool deterministic,
                c10::optional<at::Generator> gen_,
-               c10::optional<at::Tensor> &rng_state)
+               c10::optional<at::Tensor> &rng_state_)
 {
 #ifdef FLASHATTENTION_DISABLE_BACKWARD
     TORCH_CHECK(false, "This flash attention build does not support backward.");
@@ -362,21 +361,26 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());
-    uint64_t drop_seed = 1, drop_offset = 0;
     int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
+    at::Tensor rng_state;
-    if (rng_state.has_value()) {
-        uint64_t* d = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
-        drop_seed = d[0];
-        drop_offset = d[1];
+    if (rng_state_.has_value()) {
+        rng_state = rng_state_.value();
     } else if(is_dropout) {
+        rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
         auto philox_args = gen->philox_cuda_state(counter_offset);
-        std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
+        hipLaunchKernelGGL(
+            flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0,
+            philox_args, reinterpret_cast<uint64_t*>(rng_state.data_ptr()));
+    } else {
+        rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
     }
     if (max_seqlen_q > 0) {
+        auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
         ck_tile::stream_config stream_config{stream};
         auto traits =
@@ -407,8 +411,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
                 dv_expanded,
                 softmax_scale,
                 p_dropout,
-                drop_seed,
-                drop_offset);
+                drop_seed_offset);
         float t = fmha_bwd(traits, args, stream_config);
         TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd");

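Note on the backward paths above: mha_bwd and mha_varlen_bwd now accept the rng_state tensor saved by the forward pass (rng_state_) and only regenerate it when it is absent; the varlen variant also allocates the buffer when there is no dropout, keeping rng_state defined before its data_ptr() is taken. A host-only C++ sketch of that branching, with a std::vector standing in for the 2-element kInt64 tensor and a made-up helper name (prepare_drop_seed_offset):

// sketch_rng_handoff.cpp -- host-only illustration, NOT repository code.
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

using RngState = std::vector<uint64_t>;  // stand-in for the {2}-element kInt64 tensor

// Mirrors the branching in mha_varlen_bwd: reuse the forward pass's buffer when
// provided, fill a fresh one when dropout is active, otherwise still allocate it
// so the pointer pair handed to the kernel arguments is well defined.
std::pair<const uint64_t*, const uint64_t*>
prepare_drop_seed_offset(const std::optional<RngState>& rng_state_,
                         bool is_dropout,
                         RngState& rng_state /*owned by the caller*/) {
    if (rng_state_.has_value()) {
        rng_state = *rng_state_;                    // replay exactly what forward used
    } else if (is_dropout) {
        rng_state = {/*seed*/ 1234, /*offset*/ 0};  // placeholder for the Philox unpack kernel
    } else {
        rng_state = {0, 0};                         // unused values, but data() stays valid
    }
    return {rng_state.data(), rng_state.data() + 1};
}

int main() {
    RngState saved = {42, 7};   // pretend this came back from the forward pass
    RngState scratch;
    auto drop_seed_offset = prepare_drop_seed_offset(saved, /*is_dropout=*/true, scratch);
    return *drop_seed_offset.first == 42 ? 0 : 1;
}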

@@ -47,8 +47,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
                                          at::Tensor dropout_randval,
                                          float softmax_scale,
                                          float p_dropout,
-                                         uint64_t drop_seed,
-                                         uint64_t drop_offset)
+                                         std::pair<uint64_t*, uint64_t*> drop_seed_offset)
 {
     // q: (total_q, nheads, d)
     // k: (total_k, nheads_k, d)
@@ -140,7 +139,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
                         static_cast<ck_tile::index_t>(mask.type),
                         p_dropout,
                         has_dropout_randval,
-                        {drop_seed, drop_offset}};
+                        drop_seed_offset};
 }
 std::vector<at::Tensor>
@@ -281,10 +280,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
         if (return_dropout_randval) {p.zero_();}
     }
-    uint64_t drop_seed = 1, drop_offset = 0;
     int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
-    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
-    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+    auto rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
+    auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
     if (p_dropout > 0.0) {
         auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
@@ -292,13 +290,12 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
         // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
         auto philox_args = gen->philox_cuda_state(counter_offset);
-        std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
+        hipLaunchKernelGGL(
+            flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0, philox_args, rng_state_ptr);
     }
-    rng_state[0] = *(reinterpret_cast<int64_t*>(&drop_seed));
-    rng_state[1] = *(reinterpret_cast<int64_t*>(&drop_offset));
     if (max_seqlen_k > 0) {
+        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
         auto stream = at::cuda::getCurrentHIPStream().stream();
         ck_tile::stream_config stream_config{stream};
@@ -332,8 +329,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
                 p,
                 softmax_scale,
                 p_dropout,
-                drop_seed,
-                drop_offset);
+                drop_seed_offset);
         float t = fmha_fwd(traits, args, stream_config);
         TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");