[AMD ROCm] Fix KVcache bug and improve performance (#1328)
* update ck * update ck * update ck again * update ck * use pointer as seed and offset * update CK * Remove useless "else" * Fix page-attn block table read out-of-bound --------- Co-authored-by: Po Yen, Chen <PoYen.Chen@amd.com>
This commit is contained in:
parent
284e2c6e5b
commit
88d1657a14
@ -1 +1 @@
|
||||
Subproject commit a9b170b54195ab667ca814f80dd5dfbf4ad772f5
|
||||
Subproject commit 13332998a4ca6dcc8cc5fcd401ca900529e5e65c
|
||||
@ -22,17 +22,17 @@
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
|
||||
namespace flash {
|
||||
// Copy from PyTorch
|
||||
// https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17
|
||||
inline std::tuple<uint64_t, uint64_t> unpack(at::PhiloxCudaState arg) {
|
||||
if (arg.captured_) {
|
||||
// static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
|
||||
// *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
|
||||
// For most threads' reads it will hit in cache, so it shouldn't hurt performance.
|
||||
return std::make_tuple(static_cast<uint64_t>(*arg.seed_.ptr), static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
|
||||
} else {
|
||||
return std::make_tuple(arg.seed_.val, arg.offset_.val);
|
||||
}
|
||||
inline __global__ void ParsePhiloxCudaState(at::PhiloxCudaState arg, uint64_t* rng_state)
|
||||
{
|
||||
// Imitate from PyTorch
|
||||
// https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17
|
||||
if (arg.captured_) {
|
||||
rng_state[0] = static_cast<uint64_t>(*arg.seed_.ptr);
|
||||
rng_state[1] = static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_);
|
||||
} else {
|
||||
rng_state[0] = arg.seed_.val;
|
||||
rng_state[1] = arg.offset_.val;
|
||||
}
|
||||
}
|
||||
|
||||
inline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
|
||||
|
||||
@ -49,8 +49,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
|
||||
at::Tensor dv,
|
||||
float softmax_scale,
|
||||
float p_dropout,
|
||||
uint64_t drop_seed,
|
||||
uint64_t drop_offset)
|
||||
std::pair<uint64_t*, uint64_t*> drop_seed_offset)
|
||||
{
|
||||
// q: (batch_size, seqlen_q, nheads, hdim)
|
||||
ck_tile::index_t batch_stride_q = q.stride(0);
|
||||
@ -191,7 +190,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
|
||||
static_cast<ck_tile::index_t>(mask.type),
|
||||
p_dropout,
|
||||
p_undrop,
|
||||
{drop_seed, drop_offset}};
|
||||
drop_seed_offset};
|
||||
}
|
||||
|
||||
std::vector<at::Tensor>
|
||||
@ -213,7 +212,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
|
||||
const float /*softcap*/,
|
||||
const bool deterministic,
|
||||
c10::optional<at::Generator> gen_,
|
||||
c10::optional<at::Tensor> &rng_state)
|
||||
c10::optional<at::Tensor> &rng_state_)
|
||||
{
|
||||
#ifdef FLASHATTENTION_DISABLE_BACKWARD
|
||||
TORCH_CHECK(false, "This flash attention build does not support backward.");
|
||||
@ -337,21 +336,24 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
||||
gen_, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
|
||||
uint64_t drop_seed = 1, drop_offset = 0;
|
||||
int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
|
||||
at::Tensor rng_state;
|
||||
|
||||
if (rng_state.has_value()) {
|
||||
uint64_t* d = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
|
||||
drop_seed = d[0];
|
||||
drop_offset = d[1];
|
||||
if (rng_state_.has_value()) {
|
||||
rng_state = rng_state_.value();
|
||||
} else if(is_dropout) {
|
||||
rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
auto philox_args = gen->philox_cuda_state(counter_offset);
|
||||
std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
|
||||
hipLaunchKernelGGL(
|
||||
flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0,
|
||||
philox_args, reinterpret_cast<uint64_t*>(rng_state.data_ptr()));
|
||||
}
|
||||
|
||||
if (seqlen_q > 0) {
|
||||
auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
|
||||
auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
|
||||
ck_tile::stream_config stream_config{stream};
|
||||
|
||||
auto traits =
|
||||
@ -380,8 +382,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
|
||||
dv_expanded,
|
||||
softmax_scale,
|
||||
p_dropout,
|
||||
drop_seed,
|
||||
drop_offset);
|
||||
drop_seed_offset);
|
||||
|
||||
float t = fmha_bwd(traits, args, stream_config);
|
||||
TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd");
|
||||
|
||||
@ -46,8 +46,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
|
||||
at::Tensor dropout_randval,
|
||||
float softmax_scale,
|
||||
float p_dropout,
|
||||
uint64_t drop_seed,
|
||||
uint64_t drop_offset)
|
||||
std::pair<uint64_t*, uint64_t*> drop_seed_offset)
|
||||
{
|
||||
// q: (batch_size, seqlen_q, nheads, d)
|
||||
// k: (batch_size, seqlen_k, nheads_k, d)
|
||||
@ -137,7 +136,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
|
||||
static_cast<ck_tile::index_t>(mask.type),
|
||||
p_dropout,
|
||||
has_dropout_randval,
|
||||
{drop_seed, drop_offset}};
|
||||
drop_seed_offset};
|
||||
}
|
||||
|
||||
std::vector<at::Tensor>
|
||||
@ -255,10 +254,9 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num
|
||||
p = torch::empty({ 0 }, opts);
|
||||
}
|
||||
|
||||
uint64_t drop_seed = 1, drop_offset = 0;
|
||||
int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
|
||||
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
|
||||
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
|
||||
auto rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
|
||||
auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
|
||||
|
||||
if (p_dropout > 0.0) {
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
||||
@ -266,13 +264,12 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
auto philox_args = gen->philox_cuda_state(counter_offset);
|
||||
std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
|
||||
hipLaunchKernelGGL(
|
||||
flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0, philox_args, rng_state_ptr);
|
||||
}
|
||||
|
||||
rng_state[0] = *(reinterpret_cast<int64_t*>(&drop_seed));
|
||||
rng_state[1] = *(reinterpret_cast<int64_t*>(&drop_offset));
|
||||
|
||||
if (seqlen_k > 0) {
|
||||
auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
|
||||
auto stream = at::cuda::getCurrentHIPStream().stream();
|
||||
ck_tile::stream_config stream_config{stream};
|
||||
|
||||
@ -305,8 +302,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num
|
||||
p,
|
||||
softmax_scale,
|
||||
p_dropout,
|
||||
drop_seed,
|
||||
drop_offset);
|
||||
drop_seed_offset);
|
||||
|
||||
float t = fmha_fwd(traits, args, stream_config);
|
||||
TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
|
||||
|
||||
@ -51,8 +51,7 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
|
||||
at::Tensor dv,
|
||||
float softmax_scale,
|
||||
float p_dropout,
|
||||
uint64_t drop_seed,
|
||||
uint64_t drop_offset)
|
||||
std::pair<uint64_t*, uint64_t*> drop_seed_offset)
|
||||
{
|
||||
ck_tile::index_t total_q = q.size(0);
|
||||
ck_tile::index_t total_k = k.size(0);
|
||||
@ -197,7 +196,7 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
|
||||
static_cast<ck_tile::index_t>(mask.type),
|
||||
p_dropout,
|
||||
p_undrop,
|
||||
{drop_seed, drop_offset}};
|
||||
drop_seed_offset};
|
||||
}
|
||||
|
||||
std::vector<at::Tensor>
|
||||
@ -224,7 +223,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
|
||||
const float /*softcap*/,
|
||||
const bool deterministic,
|
||||
c10::optional<at::Generator> gen_,
|
||||
c10::optional<at::Tensor> &rng_state)
|
||||
c10::optional<at::Tensor> &rng_state_)
|
||||
{
|
||||
#ifdef FLASHATTENTION_DISABLE_BACKWARD
|
||||
TORCH_CHECK(false, "This flash attention build does not support backward.");
|
||||
@ -362,21 +361,26 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
||||
gen_, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
|
||||
uint64_t drop_seed = 1, drop_offset = 0;
|
||||
int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
|
||||
at::Tensor rng_state;
|
||||
|
||||
if (rng_state.has_value()) {
|
||||
uint64_t* d = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
|
||||
drop_seed = d[0];
|
||||
drop_offset = d[1];
|
||||
if (rng_state_.has_value()) {
|
||||
rng_state = rng_state_.value();
|
||||
} else if(is_dropout) {
|
||||
rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
auto philox_args = gen->philox_cuda_state(counter_offset);
|
||||
std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
|
||||
hipLaunchKernelGGL(
|
||||
flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0,
|
||||
philox_args, reinterpret_cast<uint64_t*>(rng_state.data_ptr()));
|
||||
} else {
|
||||
rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
|
||||
}
|
||||
|
||||
if (max_seqlen_q > 0) {
|
||||
auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
|
||||
auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
|
||||
ck_tile::stream_config stream_config{stream};
|
||||
|
||||
auto traits =
|
||||
@ -407,8 +411,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
|
||||
dv_expanded,
|
||||
softmax_scale,
|
||||
p_dropout,
|
||||
drop_seed,
|
||||
drop_offset);
|
||||
drop_seed_offset);
|
||||
|
||||
float t = fmha_bwd(traits, args, stream_config);
|
||||
TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd");
|
||||
|
||||
@ -47,8 +47,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
|
||||
at::Tensor dropout_randval,
|
||||
float softmax_scale,
|
||||
float p_dropout,
|
||||
uint64_t drop_seed,
|
||||
uint64_t drop_offset)
|
||||
std::pair<uint64_t*, uint64_t*> drop_seed_offset)
|
||||
{
|
||||
// q: (total_q, nheads, d)
|
||||
// k: (total_k, nheads_k, d)
|
||||
@ -140,7 +139,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
|
||||
static_cast<ck_tile::index_t>(mask.type),
|
||||
p_dropout,
|
||||
has_dropout_randval,
|
||||
{drop_seed, drop_offset}};
|
||||
drop_seed_offset};
|
||||
}
|
||||
|
||||
std::vector<at::Tensor>
|
||||
@ -281,10 +280,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
|
||||
if (return_dropout_randval) {p.zero_();}
|
||||
}
|
||||
|
||||
uint64_t drop_seed = 1, drop_offset = 0;
|
||||
int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
|
||||
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
|
||||
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
|
||||
auto rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
|
||||
auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
|
||||
|
||||
if (p_dropout > 0.0) {
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
||||
@ -292,13 +290,12 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
auto philox_args = gen->philox_cuda_state(counter_offset);
|
||||
std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
|
||||
hipLaunchKernelGGL(
|
||||
flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0, philox_args, rng_state_ptr);
|
||||
}
|
||||
|
||||
rng_state[0] = *(reinterpret_cast<int64_t*>(&drop_seed));
|
||||
rng_state[1] = *(reinterpret_cast<int64_t*>(&drop_offset));
|
||||
|
||||
if (max_seqlen_k > 0) {
|
||||
auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
|
||||
auto stream = at::cuda::getCurrentHIPStream().stream();
|
||||
ck_tile::stream_config stream_config{stream};
|
||||
|
||||
@ -332,8 +329,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
|
||||
p,
|
||||
softmax_scale,
|
||||
p_dropout,
|
||||
drop_seed,
|
||||
drop_offset);
|
||||
drop_seed_offset);
|
||||
|
||||
float t = fmha_fwd(traits, args, stream_config);
|
||||
TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
|
||||
|
||||
Loading…
Reference in New Issue
Block a user