// * update ck * update ck * update ck again * update ck * use pointer as seed and offset
// * update CK * Remove useless "else" * Fix page-attn block table read out-of-bound
// --------- Co-authored-by: Po Yen, Chen <PoYen.Chen@amd.com>
// 77 lines, 3.3 KiB, C++
/******************************************************************************
|
|
* Copyright (c) 2024, Tri Dao.
|
|
******************************************************************************/
|
|
|
|
#pragma once
|
|
|
|
// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
|
|
#include <torch/python.h>
|
|
#include <torch/nn/functional.h>
|
|
#include <ATen/cuda/CUDAContext.h>
|
|
#include <c10/cuda/CUDAGuard.h>
|
|
|
|
#ifdef OLD_GENERATOR_PATH
|
|
#include <ATen/CUDAGeneratorImpl.h>
|
|
#else
|
|
#include <ATen/cuda/CUDAGeneratorImpl.h>
|
|
#endif
|
|
|
|
|
|
// Argument-validation helpers built on TORCH_CHECK.
// The macro argument is parenthesized in every expansion. Expressions such as
// `*ptr` or `a ? b : c` are therefore evaluated as a whole before the member
// call, instead of letting `.` bind to only part of the argument. The message
// still uses the original argument text via #x.
#define CHECK_DEVICE(x) TORCH_CHECK((x).is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK((x).sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
|
|
|
|
namespace flash {
|
|
// Unpacks an at::PhiloxCudaState into two raw 64-bit words:
//   rng_state[0] <- seed
//   rng_state[1] <- offset
// When arg.captured_ is set, the seed and offset are read through the stored
// pointers, and offset_intragraph_ is added to the offset. Otherwise the
// values held directly in the state are copied.
// Every launched thread performs the same two writes (no thread indexing).
// NOTE(review): presumably launched with a single thread — confirm at the call site.
// Adapted from PyTorch:
// https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17
inline __global__ void ParsePhiloxCudaState(at::PhiloxCudaState arg, uint64_t* rng_state)
{
    uint64_t seed = 0;
    uint64_t offset = 0;
    if (!arg.captured_) {
        seed = arg.seed_.val;
        offset = arg.offset_.val;
    } else {
        seed = static_cast<uint64_t>(*arg.seed_.ptr);
        offset = static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_);
    }
    rng_state[0] = seed;
    rng_state[1] = offset;
}
|
|
|
|
// Chooses how many splits to divide num_n_blocks into, based on how evenly the
// resulting batch_nheads_mblocks * num_splits work items fill waves of num_SMs.
//
// Parameters:
//   batch_nheads_mblocks - work items before splitting (batch * heads * M-blocks,
//                          going by the name)
//   num_SMs              - work items that fit in one wave (SM/CU count, going by
//                          the name)
//   num_n_blocks         - number of N blocks the splits partition
//   max_splits           - caller-supplied upper bound on the split count
// Returns:
//   1 when the unsplit work already fills at least 80% of one wave. Otherwise,
//   the smallest eligible split count whose wave efficiency is at least 85% of
//   the best eligible efficiency. Returns 1 when there is no candidate,
//   including when the clamped bound is zero or negative.
inline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
    // If we have enough to almost fill the SMs, then just use 1 split
    if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
    max_splits = std::min({max_splits, num_SMs, num_n_blocks});

    auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
    // Some splits are not eligible. For example, with 64 blocks and 11 splits we
    // get 6 * 10 + 4 blocks. With 12 splits we would get 6 * 11 + (-2) blocks,
    // i.e. it is effectively 11 splits anyway. So a split count is only eligible
    // if its blocks-per-split differs from that of num_splits - 1.
    auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
        return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
    };
    // Fraction of the launched waves that is actually occupied (1.0 == no idle
    // slots in the last wave). Returned as float, exactly as the value was
    // previously stored.
    auto efficiency = [&](int num_splits) -> float {
        float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
        return n_waves / ceil(n_waves);
    };

    // Efficiencies are recomputed on demand rather than cached in a std::vector.
    // The old cache called reserve(max_splits). A non-positive bound converted
    // to a huge size_t there and threw std::length_error; with on-demand
    // evaluation both loops simply do not run and 1 is returned.
    float max_efficiency = 0.f;
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (!is_split_eligible(num_splits)) { continue; }
        float eff = efficiency(num_splits);
        if (eff > max_efficiency) { max_efficiency = eff; }
    }
    // Prefer the fewest splits that come within 85% of the best efficiency.
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (!is_split_eligible(num_splits)) { continue; }
        if (efficiency(num_splits) >= 0.85 * max_efficiency) { return num_splits; }
    }
    return 1;
}
|
|
|
|
// Declared here only; the definition lives outside this header.
// NOTE(review): judging by the name and parameters, this returns the split
// count the caller should actually use. It may replace the requested
// `num_splits` depending on batch/nhead/max_seqlen_q/hdim_v and the dropout
// probability `p_drop` — confirm against the definition before relying on it.
int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits);
|
|
|
|
} // namespace flash
|