/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

// Standard headers for std::tuple, std::vector, std::min, and ceil used below.
#include <algorithm>
#include <cmath>
#include <tuple>
#include <vector>

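// Assumption: OLD_GENERATOR_PATH is defined by the build when the older ATen header
// layout (ATen/CUDAGeneratorImpl.h) is present, so the matching generator header is
// selected across PyTorch versions.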
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif

#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

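// Hypothetical usage sketch (tensor names are placeholders) of the validation macros,
// applied to an attention input before dispatching to a kernel:
//   CHECK_DEVICE(q);
//   CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
//   CHECK_CONTIGUOUS(softmax_lse);
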
namespace flash {

// Copied from PyTorch:
// https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17
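// Unpacks a PhiloxCudaState into a (seed, offset) pair for the dropout RNG. When the
// state was captured for a CUDA graph the values live behind pointers and are
// dereferenced here; otherwise the immediate values are returned.
//
// Hypothetical usage sketch (gen and counter_offset are placeholders):
//   at::PhiloxCudaState philox_state = gen->philox_cuda_state(counter_offset);
//   auto seeds = unpack(philox_state);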
inline std::tuple<uint64_t, uint64_t> unpack(at::PhiloxCudaState arg) {
    if (arg.captured_) {
        // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
        // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
        // For most threads' reads it will hit in cache, so it shouldn't hurt performance.
        return std::make_tuple(static_cast<uint64_t>(*arg.seed_.ptr), static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
    } else {
        return std::make_tuple(arg.seed_.val, arg.offset_.val);
    }
}

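// Heuristic for the split-KV kernel: pick how many splits of the key/value blocks to
// use so that the launched grid keeps the SMs busy. A candidate split count's
// efficiency is measured as n_waves / ceil(n_waves); the smallest eligible candidate
// whose efficiency is within 85% of the best is returned, and 1 is used when the grid
// already nearly fills the SMs.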
inline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
    // If we have enough to almost fill the SMs, then just use 1 split
    if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
    max_splits = std::min({max_splits, num_SMs, num_n_blocks});
    float max_efficiency = 0.f;
    std::vector<float> efficiency;
    efficiency.reserve(max_splits);
    auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
    // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
    // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
    // (i.e. it's 11 splits anyway).
    // So we check if the number of blocks per split is the same as the previous num_splits.
    auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
        return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
    };
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (!is_split_eligible(num_splits)) {
            efficiency.push_back(0.f);
        } else {
            float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
            float eff = n_waves / ceil(n_waves);
            // printf("num_splits = %d, eff = %f\n", num_splits, eff);
            if (eff > max_efficiency) { max_efficiency = eff; }
            efficiency.push_back(eff);
        }
    }
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (!is_split_eligible(num_splits)) { continue; }
        if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
            // printf("num_splits chosen = %d\n", num_splits);
            return num_splits;
        }
    }
    return 1;
}

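// Declaration only: the definition lives in the accompanying .cpp. Based on its
// signature it is expected to take the caller-provided num_splits and, when
// appropriate, replace it with an adaptive value derived from the problem shape
// (batch, nhead, max_seqlen_q, hdim_v) and the dropout probability.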
int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits);

} // namespace flash