* Support CK in FMHA
* Add CK submodule
* Do not return LSE if return_softmax == false
* Use receipt to speed up CK compile time
* Integrate new version of ck_tile
* Support dropout for mha_fwd()
* Add dropout to mha_varlen_fwd()
* Update CK to develop
* Extract padding function for dropout randval
* Extract randval transformation function
* Sync the code structure and coding style with FA
* Remove this line; the C++ API will handle it. Sync with test_flash_attn.py
* Fix compile error
* Add mha_bwd
* Generate dropout seed and offset from the user generator
* Update CK
* Add mha_varlen_bwd
* Use the same Python as the flash-attn build to generate CK kernels
* Fix bug in group-mode fwd when returning softmax LSE
* Increase the test tolerance
* Add test_flash_attn_output() and test_flash_attn_varlen_output()
* Always fill softmax_lse
* Remove duplicate benchmark script, since mha_bwd is already implemented
* Refine getting values from tuples
* Use default parameter for stream_config
* Unblock all platforms
* Add comment
* Refine the test code
* Refine naming
* Add unpack to namespace
* Do not hardcode the warp size of 64
* Add more targets
* Add README
* Optimize mha_fwd when seqlen_q == 1
* Support get_wheel_url for ROCm
* Detect the ROCm environment via PyTorch's IS_HIP_EXTENSION
* Update to latest CK
* Add necessary compile flags
* Sync the API with upstream FA

---------

Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: Yichen Yan <wenji.yyc@alibaba-inc.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: Yichen Yan <oraluben@outlook.com>
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif

#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
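
// Illustrative sketch (not part of the original header): these macros are
// typically used at the top of an API entry point to validate tensor
// arguments before launching a kernel. The names q, batch_size, seqlen_q,
// num_heads, head_size, and softmax_lse below are assumptions for the
// example only:
//
//   CHECK_DEVICE(q);
//   CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
//   CHECK_CONTIGUOUS(softmax_lse);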

namespace flash {
// Copied from PyTorch
// https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17
static std::tuple<uint64_t, uint64_t> unpack(at::PhiloxCudaState arg) {
    if (arg.captured_) {
        // static_cast avoids "warning: invalid narrowing conversion from 'long' to 'unsigned long'".
        // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
        // For most threads' reads it will hit in cache, so it shouldn't hurt performance.
        return std::make_tuple(static_cast<uint64_t>(*arg.seed_.ptr), static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
    } else {
        return std::make_tuple(arg.seed_.val, arg.offset_.val);
    }
}
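
// Illustrative sketch (an assumption about typical usage, not part of the
// original header): callers such as mha_fwd generally obtain the Philox state
// from a PyTorch CUDA generator and unpack it into the seed/offset pair used
// by the dropout kernel. gen_ and counter_offset are placeholder names:
//
//   auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
//       gen_, at::cuda::detail::getDefaultCUDAGenerator());
//   std::lock_guard<std::mutex> lock(gen->mutex_);
//   at::PhiloxCudaState philox_state = gen->philox_cuda_state(counter_offset);
//   uint64_t drop_seed, drop_offset;
//   std::tie(drop_seed, drop_offset) = flash::unpack(philox_state);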

} // namespace flash