xFormer updates to fMHA FW (#773)
* xFormer updates to fMHA FW
* Convert format to BMHK for '41_fused_multi_head_attention_fixed_seqlen'
* Add missing files
* Remove xFormers specific code
* Update fused_multihead_attention_fixed_seqlen.cu
* rebase and solve conflicts
* remove white space

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
parent 5ff5209ed5
commit 2e10404d26
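The main functional change is that example 41 now lays Q/K/V/O out in BMHK order (batch, sequence position, head, per-head embedding) and derives the kernel's per-head, per-row and per-batch strides from that layout, as the diff below does with q_strideH / q_strideM / q_strideB. A minimal illustrative sketch of that stride arithmetic (standalone C++, names hypothetical, not part of the commit):

// Sketch only: element strides of a contiguous BMHK tensor
// [batch, seq_len, num_heads, head_size], mirroring the diff's
// strideH / strideM / strideB assignments.
#include <cstdint>
#include <cstdio>

struct BmhkStrides {
  int64_t strideH; // step to the next head     = head_size
  int64_t strideM; // step to the next position = num_heads * head_size
  int64_t strideB; // step to the next batch    = seq_len * num_heads * head_size
};

static BmhkStrides make_bmhk_strides(int64_t seq_len, int64_t num_heads, int64_t head_size) {
  BmhkStrides s;
  s.strideH = head_size;
  s.strideM = num_heads * head_size;
  s.strideB = s.strideM * seq_len;
  return s;
}

int main() {
  // e.g. 8 heads of size 64, sequence length 1024
  BmhkStrides q = make_bmhk_strides(1024, 8, 64);
  std::printf("strideH=%lld strideM=%lld strideB=%lld\n",
              (long long)q.strideH, (long long)q.strideM, (long long)q.strideB);
  return 0;
}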
@@ -47,19 +47,52 @@
}

// Print on the first thread of the first block
#if 0
#if 1
#define PRINT_WARP_ID 0
#define PRINT_LANE_ID 0
#define PRINT_T0_L0(msg, ...) \
if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && \
threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \
threadIdx.z == 0) { \
printf(msg "\n", __VA_ARGS__); \
printf(msg "\n", ##__VA_ARGS__); \
}
#define PRINT_TX_LX(msg, ...) \
for (int bx = 0; bx < gridDim.x; ++bx) { \
for (int by = 0; by < gridDim.y; ++by) { \
for (int bz = 0; bz < gridDim.z; ++bz) { \
for (int tx = 0; tx < blockDim.x; ++tx) { \
for (int ty = 0; ty < blockDim.y; ++ty) { \
for (int tz = 0; tz < blockDim.z; ++tz) { \
__syncthreads(); \
if (blockIdx.x == bx && blockIdx.y == by && blockIdx.z == bz && \
threadIdx.x == tx && threadIdx.y == ty && \
threadIdx.z == tz) { \
printf( \
"[%d,%d,%d][%d,%d,%d]" msg "\n", \
bx, \
by, \
bz, \
tx, \
ty, \
tz, \
##__VA_ARGS__); \
} \
} \
} \
} \
} \
} \
}
#else
#define PRINT_T0_L0
#define PRINT_TX_LX
#endif

struct __string_view {
char const* data;
std::size_t size;
};
#if __cplusplus >= 201402L
template <class T>
constexpr __string_view __get_type_name() {
char const* p = __PRETTY_FUNCTION__;
@@ -83,7 +116,10 @@ constexpr __string_view __get_type_name() {
return {};
}
#else
#define PRINT_T0_L0
template <class T>
constexpr __string_view __get_type_name() {
return {"unsupported", 11};
}
#endif

// Print a given array

@@ -168,6 +168,9 @@ public:
typename LayoutP::Stride::LongIndex *ldv;
typename LayoutO::Stride::LongIndex *ldo;

// Scale
ElementAccumulator scale;

// Whether causal masking is to be performed
bool causal;

@@ -193,6 +196,7 @@ public:
ldk(nullptr),
ldv(nullptr),
ldo(nullptr),
scale(0),
causal(false),
host_problem_sizes(nullptr)
{
@@ -218,6 +222,7 @@ public:
typename LayoutV::Stride::LongIndex *ldv,
typename LayoutO::Stride::LongIndex *ldo,
bool causal,
ElementAccumulator scale,
GemmCoord *host_problem_sizes=nullptr
):
problem_sizes0(problem_sizes0),
@@ -235,6 +240,7 @@ public:
ldv(ldv),
ldo(ldo),
causal(causal),
scale(scale),
host_problem_sizes(host_problem_sizes)
{

@@ -273,6 +279,7 @@ public:
typename LayoutP::Stride::LongIndex *ldv;
typename LayoutO::Stride::LongIndex *ldo;

ElementAccumulator scale;
bool causal;

//
@@ -291,7 +298,8 @@ public:
ldk(nullptr),
ldv(nullptr),
ldo(nullptr),
causal(false)
causal(false),
scale(0)
{ }

CUTLASS_HOST_DEVICE
@@ -310,8 +318,9 @@ public:
ldk(args.ldk),
ldv(args.ldv),
ldo(args.ldo),
causal(args.causal)
{
causal(args.causal),
scale(args.scale)
{

}

@@ -337,6 +346,7 @@ public:
ldv = args.ldv;
ldo = args.ldo;
causal = args.causal;
scale = args.scale;
}
};

@@ -649,7 +659,7 @@ public:
warp_id(),
num_keys - iter_key_start,
iteratorC_tile_offset,
1.0f / cutlass::fast_sqrt(float(problem_size0.k())));
params.scale);
}));
}));

@@ -504,37 +504,51 @@ private:
ldo_host.resize(problem_count());
seqlen_host.resize(problem_count());

for (int32_t i = 0; i < problem_count(); ++i) {
// Create tensors in BMHK format, where
// B = batch_size
// M = sequence length
// H = num_heads
// K = embedding size per head
int64_t batch_offset_Q, batch_offset_K, batch_offset_V, batch_offset_O;

auto problem0 = options.problem_sizes0.at(i);
auto problem1 = options.problem_sizes1.at(i);
for (int32_t b = 0; b < options.batch_size; ++b) {
batch_offset_Q = total_elements_Q;
batch_offset_K = total_elements_K;
batch_offset_V = total_elements_V;
batch_offset_O = total_elements_O;
for (int32_t h = 0; h < options.head_number; ++h) {
int32_t i = h + b * options.head_number;

ldq_host.at(i) = LayoutQ::packed({problem0.m(), problem0.k()}).stride(0);
ldk_host.at(i) = LayoutK::packed({problem0.k(), problem0.n()}).stride(0);
ldp_host.at(i) = LayoutP::packed({problem0.m(), problem0.n()}).stride(0);
ldv_host.at(i) = LayoutV::packed({problem1.k(), problem1.n()}).stride(0);
ldo_host.at(i) = LayoutO::packed({problem1.m(), problem1.n()}).stride(0);
auto problem0 = options.problem_sizes0.at(i);
auto problem1 = options.problem_sizes1.at(i);

// m = n for attention problems.
seqlen_host.at(i) = problem0.m();
ldq_host.at(i) = LayoutQ::packed({problem0.m(), options.head_number * problem0.k()}).stride(0);
ldk_host.at(i) = LayoutK::packed({options.head_number * problem0.k(), problem0.n()}).stride(0);
ldp_host.at(i) = LayoutP::packed({problem0.m(), problem0.n()}).stride(0);
ldv_host.at(i) = LayoutV::packed({problem1.k(), options.head_number * problem1.n()}).stride(0);
ldo_host.at(i) = LayoutO::packed({problem1.m(), options.head_number * problem1.n()}).stride(0);

offset_Q.push_back(total_elements_Q);
offset_K.push_back(total_elements_K);
offset_P.push_back(total_elements_P);
offset_V.push_back(total_elements_V);
offset_O.push_back(total_elements_O);
// m = n for attention problems.
seqlen_host.at(i) = problem0.m();

int64_t elements_Q = problem0.m() * problem0.k();
int64_t elements_K = problem0.k() * problem0.n();
int64_t elements_P = problem0.m() * problem0.n();
int64_t elements_V = problem1.k() * problem1.n();
int64_t elements_O = problem1.m() * problem1.n();
offset_Q.push_back(batch_offset_Q + h * problem0.k());
offset_K.push_back(batch_offset_K + h * problem0.k());
offset_P.push_back(total_elements_P);
offset_V.push_back(batch_offset_V + h * problem0.k());
offset_O.push_back(batch_offset_O + h * problem1.n());

total_elements_Q += elements_Q;
total_elements_K += elements_K;
total_elements_P += elements_P;
total_elements_V += elements_V;
total_elements_O += elements_O;
int64_t elements_Q = problem0.m() * problem0.k();
int64_t elements_K = problem0.k() * problem0.n();
int64_t elements_P = problem0.m() * problem0.n();
int64_t elements_V = problem1.k() * problem1.n();
int64_t elements_O = problem1.m() * problem1.n();

total_elements_Q += elements_Q;
total_elements_K += elements_K;
total_elements_P += elements_P;
total_elements_V += elements_V;
total_elements_O += elements_O;
}
}

problem_sizes_device0.reset(problem_count());
@@ -649,15 +663,11 @@ private:

bool passed = true;

for (int32_t i = 0; i < problem_count(); ++i) {
cutlass::gemm::GemmCoord problem0 = options.problem_sizes0.at(i);
cutlass::gemm::GemmCoord problem1 = options.problem_sizes1.at(i);

LayoutQ layout_Q(ldq_host.at(i));
LayoutK layout_K(ldk_host.at(i));
LayoutP layout_P(ldp_host.at(i));
LayoutV layout_V(ldv_host.at(i));
LayoutO layout_O(ldo_host.at(i));
for (int32_t b = 0; b < options.batch_size; ++b) {
int32_t i = b * options.head_number;
// Problem size is the same for all heads
cutlass::gemm::GemmCoord problem0 = options.problem_sizes0.at(b * options.head_number);
cutlass::gemm::GemmCoord problem1 = options.problem_sizes1.at(b * options.head_number);

MatrixCoord extent_Q{problem0.m(), problem0.k()};
MatrixCoord extent_K{problem0.k(), problem0.n()};
@@ -665,114 +675,121 @@ private:
MatrixCoord extent_V{problem1.k(), problem1.n()};
MatrixCoord extent_O{problem1.m(), problem1.n()};

cutlass::TensorView<ElementQ, LayoutQ> view_Q(block_Q.get() + offset_Q.at(i), layout_Q, extent_Q);
cutlass::TensorView<ElementK, LayoutK> view_K(block_K.get() + offset_K.at(i), layout_K, extent_K);
cutlass::TensorView<ElementP, LayoutP> view_P(block_P.get() + offset_P.at(i), layout_P, extent_P);
cutlass::TensorView<ElementV, LayoutV> view_V(block_V.get() + offset_V.at(i), layout_V, extent_V);

cutlass::DeviceAllocation<ElementP> block_Ref(layout_P.capacity(extent_P));
cutlass::TensorView<ElementP, LayoutP> view_Ref_device(block_Ref.get(), layout_P, extent_P);

LayoutO layout_O(ldo_host.at(i));
std::vector<ElementO> matrix_O(layout_O.capacity(extent_O));
cutlass::device_memory::copy_to_host(matrix_O.data(), block_O.get() + offset_O.at(i), matrix_O.size());
cutlass::DeviceAllocation<ElementO> block_Ref_O(layout_O.capacity(extent_O));
cutlass::TensorView<ElementO, LayoutO> view_Ref_O_device(block_Ref_O.get(), layout_O, extent_O);

// Reference GEMM
cutlass::reference::device::GemmComplex<
ElementQ, LayoutQ,
ElementK, LayoutK,
ElementP, LayoutP,
ElementCompute, ElementAccumulator
>(
problem0,
ElementAccumulator(options.alpha0),
view_Q,
Attention::MM0::Mma::kTransformA,
view_K,
Attention::MM0::Mma::kTransformB,
ElementAccumulator(options.beta),
view_P,
view_Ref_device,
ElementAccumulator(0)
);
for (int32_t h = 0; h < options.head_number; ++h) {
i = h + b * options.head_number;

// Compute softmax for P. We need to explicitly compute softmax
// over P because softmax is fused to the second GEMM in the
// profiled implementation.
std::vector<ElementP> matrix_Ref(layout_P.capacity(extent_P));
cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_Ref.size());
cutlass::TensorView<ElementP, LayoutP> view_Ref_host(matrix_Ref.data(), layout_P, extent_P);
std::vector<ElementNorm> vector_Norm_Ref(problem0.m());
std::vector<ElementSum> vector_Sum_Ref(problem0.m());
LayoutQ layout_Q(ldq_host.at(i));
LayoutK layout_K(ldk_host.at(i));
LayoutP layout_P(ldp_host.at(i));
LayoutV layout_V(ldv_host.at(i));

int n_dim = options.use_mask ? options.problem_sizes0_real.at(i).n() : problem0.n();
cutlass::TensorView<ElementQ, LayoutQ> view_Q(block_Q.get() + offset_Q.at(i), layout_Q, extent_Q);
cutlass::TensorView<ElementK, LayoutK> view_K(block_K.get() + offset_K.at(i), layout_K, extent_K);
cutlass::TensorView<ElementV, LayoutV> view_V(block_V.get() + offset_V.at(i), layout_V, extent_V);
cutlass::TensorView<ElementO, LayoutO> view_Ref_O_device(block_Ref_O.get() + offset_O.at(i) - offset_O.at(b * options.head_number), layout_O, extent_O);

// Compute softmax for reference matrix
for (int m = 0; m < problem0.m(); m++) {
int n_dim_row = n_dim;
if (options.causal) {
n_dim_row = std::min(m + 1, n_dim);
}
ElementSoftmaxCompute max = ElementSoftmaxCompute(view_Ref_host.ref().at({m, 0}));
for (int n = 1; n < n_dim_row; n++) {
max = std::max(max, ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})));
}
cutlass::DeviceAllocation<ElementP> block_Ref_P(layout_P.capacity(extent_P));
cutlass::TensorView<ElementP, LayoutP> view_Ref_P_device(block_Ref_P.get(), layout_P, extent_P);

vector_Norm_Ref.at(m) = ElementNorm(max);
// Reference GEMM
cutlass::reference::device::GemmComplex<
ElementQ, LayoutQ,
ElementK, LayoutK,
ElementP, LayoutP,
ElementCompute, ElementAccumulator
>(
problem0,
ElementAccumulator(options.alpha0),
view_Q,
Attention::MM0::Mma::kTransformA,
view_K,
Attention::MM0::Mma::kTransformB,
ElementAccumulator(options.beta),
view_Ref_P_device,
view_Ref_P_device,
ElementAccumulator(0)
);

ElementSoftmaxCompute sum = ElementSoftmaxCompute();
for (int n = 0; n < n_dim_row; n++) {
sum += std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max );
}
ElementSoftmaxCompute inv_sum = ElementSoftmaxCompute(1.0f / sum);
// Compute softmax for P. We need to explicitly compute softmax
// over P because softmax is fused to the second GEMM in the
// profiled implementation.
std::vector<ElementP> matrix_Ref(layout_P.capacity(extent_P));
cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref_P.get(), matrix_Ref.size());
cutlass::TensorView<ElementP, LayoutP> view_Ref_host(matrix_Ref.data(), layout_P, extent_P);
std::vector<ElementNorm> vector_Norm_Ref(problem0.m());
std::vector<ElementSum> vector_Sum_Ref(problem0.m());

vector_Sum_Ref.at(m) = ElementSum(inv_sum);
int n_dim = options.use_mask ? options.problem_sizes0_real.at(i).n() : problem0.n();

for (int n = 0; n < n_dim_row; n++) {
view_Ref_host.ref().at({m, n}) = ElementP(
std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max ) * inv_sum
);
}
// Mask out the rest of the attention matrix
for (int n = n_dim_row; n < n_dim; ++n) {
view_Ref_host.ref().at({m, n}) = ElementP(0);
}
}

// when not using mask, problem_real and problem share the same sizes
if (options.use_mask) {
// Compute softmax for reference matrix
for (int m = 0; m < problem0.m(); m++) {
for (int n = n_dim; n < problem0.n(); n++) {
int n_dim_row = n_dim;
if (options.causal) {
n_dim_row = std::min(m + 1, n_dim);
}
ElementSoftmaxCompute max = ElementSoftmaxCompute(view_Ref_host.ref().at({m, 0}));
for (int n = 1; n < n_dim_row; n++) {
max = std::max(max, ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})));
}

vector_Norm_Ref.at(m) = ElementNorm(max);

ElementSoftmaxCompute sum = ElementSoftmaxCompute();
for (int n = 0; n < n_dim_row; n++) {
sum += std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max );
}
ElementSoftmaxCompute inv_sum = ElementSoftmaxCompute(1.0f / sum);

vector_Sum_Ref.at(m) = ElementSum(inv_sum);

for (int n = 0; n < n_dim_row; n++) {
view_Ref_host.ref().at({m, n}) = ElementP(
std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max ) * inv_sum
);
}
// Mask out the rest of the attention matrix
for (int n = n_dim_row; n < n_dim; ++n) {
view_Ref_host.ref().at({m, n}) = ElementP(0);
}
}

// when not using mask, problem_real and problem share the same sizes
if (options.use_mask) {
for (int m = 0; m < problem0.m(); m++) {
for (int n = n_dim; n < problem0.n(); n++) {
view_Ref_host.ref().at({m, n}) = ElementP(0);
}
}
}

cutlass::device_memory::copy_to_device(block_Ref_P.get(), matrix_Ref.data(), matrix_Ref.size());

// Reference GEMM
cutlass::reference::device::GemmComplex<
ElementP, LayoutP,
ElementV, LayoutV,
ElementO, LayoutO,
ElementCompute, ElementAccumulator
>(
problem1,
ElementAccumulator(options.alpha1),
view_Ref_P_device,
Attention::MM0::Mma::kTransformA,
view_V,
Attention::MM0::Mma::kTransformB,
ElementAccumulator(options.beta),
view_Ref_O_device,
view_Ref_O_device,
ElementAccumulator(0)
);
}

cutlass::device_memory::copy_to_device(block_P.get() + offset_P.at(i), matrix_Ref.data(), matrix_Ref.size());

// Reference GEMM
cutlass::reference::device::GemmComplex<
ElementP, LayoutP,
ElementV, LayoutV,
ElementO, LayoutO,
ElementCompute, ElementAccumulator
>(
problem1,
ElementAccumulator(options.alpha1),
view_P,
Attention::MM0::Mma::kTransformA,
view_V,
Attention::MM0::Mma::kTransformB,
ElementAccumulator(options.beta),
view_Ref_O_device,
view_Ref_O_device,
ElementAccumulator(0)
);

// Copy to host memory
cutlass::TensorView<ElementP, LayoutP> view_Ref(matrix_Ref.data(), layout_P, extent_P);

std::vector<ElementO> matrix_O(layout_O.capacity(extent_O));
cutlass::device_memory::copy_to_host(matrix_O.data(), block_O.get() + offset_O.at(i), matrix_O.size());
std::vector<ElementO> matrix_Ref_O(layout_O.capacity(extent_O));
cutlass::device_memory::copy_to_host(matrix_Ref_O.data(), block_Ref_O.get(), matrix_Ref_O.size());

@@ -788,7 +805,7 @@ private:
passed = passed && verified_O;

if (!passed) {
std::cerr << "\n***\nError - problem " << i << " failed the QA check\n***\n" << std::endl;
std::cerr << "\n***\nError - problem " << i << " (batch " << b << ") failed the QA check\n***\n" << std::endl;

if (!verified_O) {
std::cout << "Final matrix output is incorrect" << std::endl;
@@ -831,6 +848,8 @@ public:
// p.cu_seqlens_k_ptr = (int32_t*)cu_seqlens_k->data_ptr();
// }

p.scale = options.alpha0;

p.num_heads = options.head_number;
p.num_batches = options.batch_size;
p.head_dim = options.head_size;
@@ -839,18 +858,16 @@ public:
p.num_keys = options.seq_length_kv;
p.causal = options.causal;

// TODO: This might overflow for big tensors
// All tensors are in BMHK shapes
p.q_strideH = options.head_size;
p.k_strideH = options.head_size;
p.v_strideH = options.head_size_v;
p.q_strideM = int32_t(ldq_host[0]);
p.k_strideM = int32_t(ldk_host[0]);
p.v_strideM = int32_t(ldv_host[0]);
p.q_strideH = p.q_strideM * options.seq_length;
p.k_strideH = p.k_strideM * options.seq_length_kv;
p.v_strideH = p.v_strideM * options.seq_length_kv;
p.o_strideH = options.head_size_v * options.seq_length;
p.q_strideB = p.q_strideH * options.head_number;
p.k_strideB = p.k_strideH * options.head_number;
p.v_strideB = p.v_strideH * options.head_number;
p.o_strideB = options.head_size_v * options.seq_length * options.head_number;
p.q_strideB = p.q_strideM * options.seq_length;
p.k_strideB = p.k_strideM * options.seq_length_kv;
p.v_strideB = p.v_strideM * options.seq_length_kv;
}

// launch kernel :)

@@ -921,6 +921,7 @@ public:
ldv.get(),
ldo.get(),
options.causal,
options.alpha0,
options.problem_sizes1.data()
);

@@ -36,20 +36,20 @@
////////////////////////////////////////////////////////////////////////////////
// Some helper functions
////////////////////////////////////////////////////////////////////////////////
#define DISPATCH_TYPES(tensor, func) \
{ \
if (query.scalar_type() == at::ScalarType::Float) { \
using scalar_t = float; \
func(); \
} else if (query.scalar_type() == at::ScalarType::Half) { \
using scalar_t = cutlass::half_t; \
func(); \
} else if (query.scalar_type() == at::ScalarType::BFloat16) { \
using scalar_t = cutlass::bfloat16_t; \
func(); \
} else { \
TORCH_CHECK(false, "Only fp32, half & bf16 supported at the moment"); \
} \
#define DISPATCH_TYPES(tensor, func) \
{ \
if (query.scalar_type() == at::ScalarType::Float) { \
using scalar_t = float; \
func(); \
} else if (query.scalar_type() == at::ScalarType::Half) { \
using scalar_t = cutlass::half_t; \
func(); \
} else if (query.scalar_type() == at::ScalarType::BFloat16) { \
using scalar_t = cutlass::bfloat16_t; \
func(); \
} else { \
XFORMERS_CHECK(false, "Only fp32, half & bf16 supported at the moment"); \
} \
}

#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \
@@ -77,26 +77,27 @@
using ArchTag = cutlass::arch::Sm50; \
func(); \
} else { \
TORCH_CHECK( \
XFORMERS_CHECK( \
false, \
"Your device is too old. We require compute capability >= 50"); \
} \
}

#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \
TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \
TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
TORCH_CHECK(TENSOR.is_contiguous());
#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \
XFORMERS_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \
XFORMERS_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
XFORMERS_CHECK(TENSOR.is_contiguous());

#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \
TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \
TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
TORCH_CHECK( \
#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \
XFORMERS_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \
XFORMERS_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
XFORMERS_CHECK( \
TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous");

#ifdef HAS_PYTORCH
#ifdef TORCH_CHECK
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
TORCH_CHECK(uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned")
XFORMERS_CHECK( \
uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned")
#define XFORMERS_CHECK TORCH_CHECK
#elif defined(__CUDACC_RTC__)
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
@@ -108,6 +109,7 @@
return false; \
}
#else
#include <iostream>
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \
std::cerr << #PTR " is not correctly aligned\n"; \
@@ -120,74 +122,25 @@
}
#endif

#define ASSIGN_CHECK_OVERFLOW(A, B) \
{ \
A = B; \
TORCH_CHECK( \
B < cutlass::platform::numeric_limits<decltype(A)>::max(), \
#B " overflows"); \
#define ASSIGN_CHECK_OVERFLOW(A, B) \
{ \
A = B; \
XFORMERS_CHECK( \
B < std::numeric_limits<decltype(A)>::max(), #B " overflows"); \
}

namespace gemm_kernel_utils {

#ifdef HAS_PYTORCH
template <typename scalar_t>
struct TypeTraits;

template <>
struct TypeTraits<cutlass::half_t> {
using scalar_t = cutlass::half_t;

static constexpr __host__ at::ScalarType atScalarType() {
return at::ScalarType::Half;
}
template <int nDim>
static __host__ at::PackedTensorAccessor32<scalar_t, nDim> packed_accessor(
at::Tensor const& tensor) {
return at::PackedTensorAccessor32<scalar_t, nDim>(
(scalar_t*)(tensor.data_ptr()),
tensor.sizes().data(),
tensor.strides().data());
}
};

template <>
struct TypeTraits<cutlass::bfloat16_t> {
using scalar_t = cutlass::bfloat16_t;

static constexpr __host__ at::ScalarType atScalarType() {
return at::ScalarType::BFloat16;
}
template <int nDim>
static __host__ at::PackedTensorAccessor32<scalar_t, nDim> packed_accessor(
at::Tensor const& tensor) {
return at::PackedTensorAccessor32<scalar_t, nDim>(
(scalar_t*)(tensor.data_ptr()),
tensor.sizes().data(),
tensor.strides().data());
}
};

template <>
struct TypeTraits<float> {
using scalar_t = float;

static constexpr __host__ at::ScalarType atScalarType() {
return at::ScalarType::Float;
}
template <int nDim>
static __host__ at::PackedTensorAccessor32<scalar_t, nDim> packed_accessor(
at::Tensor const& tensor) {
return tensor.packed_accessor32<scalar_t, nDim>();
}
};
#endif

template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) {
return (n + m - 1) / m;
}

template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m) {
return ((n + m - 1) / m) * m;
}

////////////////////////////////////////////////////////////////////////////////
// Determine the type of GEMM we do (TensorCores or not, Shapes ...)
// TODO: Maybe we could rely on Cutlass's DefaultGemm templates

@@ -311,9 +311,9 @@ class PredicatedTileIteratorPrefetch {
// on windows using unsigned long here gives the error
// error: asm operand type size(4) does not match
// type/size implied by constraint 'l'
uint64_t addr = (uint64_t)(
(void*)&memory_pointer
[column * ThreadMap::Delta::kColumn / kElementsPerAccess]);
uint64_t addr = (uint64_t)((void*)&memory_pointer
[column * ThreadMap::Delta::kColumn /
kElementsPerAccess]);
asm volatile("prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr));
}

@@ -0,0 +1,53 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "warp_iterator_from_smem.h"

template <typename WarpIterator>
struct TransposeWarpIterator {
using Iterator = char;
static bool constexpr kSupportsTranspose = false;
};

template <
/// Operand identity
cutlass::gemm::Operand Operand,
/// Data type of A elements
typename Element,
bool kTranspose>
struct TransposeWarpIterator<
cutlass::gemm::warp::WarpIteratorFromSmem<Operand, Element, kTranspose>> {
using Iterator =
cutlass::gemm::warp::WarpIteratorFromSmem<Operand, Element, !kTranspose>;
static bool constexpr kSupportsTranspose = true;
};
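For context, a brief usage sketch (not part of the commit) of the trait defined above: a consumer selects the transposed iterator only when transposition is both requested and supported, which is essentially what DefaultMmaFromSharedMemory does further down in this diff. The names SelectWarpIteratorA and kWantTranspose are hypothetical.

// Sketch only: compile-time selection of a transposed warp iterator.
template <typename WarpIteratorA, bool kWantTranspose>
struct SelectWarpIteratorA {
  using Transposer = TransposeWarpIterator<WarpIteratorA>;
  // Fall back to the untransposed iterator when transposition is unsupported.
  static constexpr bool kIsTransposed =
      Transposer::kSupportsTranspose && kWantTranspose;
  using Iterator = typename cutlass::platform::conditional<
      kIsTransposed, typename Transposer::Iterator, WarpIteratorA>::type;
};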
@@ -0,0 +1,278 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
 *reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 *POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
\brief Inspired from
"cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" Loads tiles of GEMM
operands from a RowMajor shared-memory layout into registers to use by A100
TensorCores.

The difference with "mma_tensor_op_tile_access_iterator.h" is that:
(1) We use "ldmatrix" to load tiles, rather than manual loads (slightly
faster) (2) We support transposing the operand (eg read `A.transpose()` when
the shared memory holds `A`)

This is only implemented for the specific shapes.
*/
#pragma once

#include <cutlass/gemm/gemm.h>

////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace warp {

template <
/// Operand identity
Operand Operand_,
/// Data type of A elements
typename Element_,
bool kTranspose = false>
class WarpIteratorFromSmem {
public:
/// Shape of tile to load (concept: MatrixShape)
using Shape = cutlass::MatrixShape<32, 32>;

/// Operand tag
static Operand const kOperand = Operand_;

/// Basic check
static_assert(
kOperand == Operand::kA || kOperand == Operand::kB,
"WarpIteratorFromSmem may only be instantiated for A or B operands to warp-level Mma.");

/// Element type
using Element = Element_;
static_assert(sizeof_bits<Element>::value == 16, "Only supported for half");

/// Layout of source tile
using Layout = cutlass::layout::RowMajor;

/// Shape of one matrix product operation (concept: MatrixShape)
using InstructionShape = cutlass::MatrixShape<16, 8>;

/// Delta between *MMA operations (in units of *MMA operations, concept:
/// MatrixShape)
static int const kOpDelta = 1;

/// Number of participating threads
static int const kThreads = 32;

/// TensorRef type for loading element from a tensor
using TensorRef = TensorRef<Element, Layout>;

/// Index type
using Index = typename TensorRef::Index;

/// Long Index type
using LongIndex = typename TensorRef::LongIndex;

/// Coordinate for an element in the tensor
using TensorCoord = typename TensorRef::TensorCoord;

/// Number of elements accessed per Shared Memory load
static int const kElementsPerAccess =
(sizeof_bits<Element>::value >= 32 ? 1
: 32 / sizeof_bits<Element>::value);

using InstructionCount = MatrixShape<
Shape::kRow / InstructionShape::kRow,
Shape::kColumn / InstructionShape::kColumn>;

static int const kIterations = (kOperand == Operand::kA)
? InstructionCount::kColumn
: InstructionCount::kRow;

public:
//
// Derived quantities
//

/// Fragment object holding a thread's part of a tile
using Fragment = Array<
Element,
(kOperand == Operand::kA)
? (Shape::kRow* InstructionShape::kColumn / kThreads)
: (Shape::kColumn* InstructionShape::kRow / kThreads)>;

/// Memory access type
// using AccessType = AlignedArray<Element, kElementsPerAccess>;
using AccessType = Array<unsigned, 4>;

static int constexpr kWarpShapeDivisibleInner =
(kOperand == Operand::kA ? InstructionShape::kColumn
: InstructionShape::kRow);
static int constexpr kAccessesInner =
(kWarpShapeDivisibleInner / kElementsPerAccess) / 4;
static int const kTilesPerInstruction = InstructionShape::kRow / 8;

private:
/// Underlying tensor reference
TensorRef ref_;

/// Origin
MatrixCoord origin_;

/// Iterations in a tile
int iterations_;

public:
/// Constructor from TensorRef
CUTLASS_HOST_DEVICE
WarpIteratorFromSmem(TensorRef const& ref, int lane_id)
: WarpIteratorFromSmem(ref, {Shape::kRow, Shape::kColumn}, lane_id) {}
CUTLASS_HOST_DEVICE
WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id)
: ref_(ref), iterations_(0) {
int ldsm_vec_num = (lane_id >> 3);
if (kOperand == Operand::kA) {
origin_ = MatrixCoord(lane_id % 8, 0);
static_assert(
InstructionCount::kRow * kAccessesInner * kTilesPerInstruction == 4,
"");
CUTLASS_PRAGMA_UNROLL
for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow;
++inst_m_idx) {
CUTLASS_PRAGMA_UNROLL
for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
CUTLASS_PRAGMA_UNROLL
for (int access_m_idx = 0; access_m_idx < kTilesPerInstruction;
++access_m_idx) {
int access_idx = access_m_idx +
kTilesPerInstruction *
(inner_idx + kAccessesInner * inst_m_idx);

MatrixCoord offset(
access_m_idx * 8 + inst_m_idx * InstructionShape::kRow,
inner_idx * 4 * kElementsPerAccess);

if (access_idx == ldsm_vec_num) {
if (kTranspose) {
offset = MatrixCoord(offset.column(), offset.row());
}
origin_ += offset;
}
}
}
}
} else {
origin_ = MatrixCoord(0, lane_id % 8);
static_assert(InstructionCount::kColumn * kAccessesInner == 4, "");
CUTLASS_PRAGMA_UNROLL
for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn;
++inst_n_idx) {
CUTLASS_PRAGMA_UNROLL
for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
int access_idx = inner_idx + kAccessesInner * inst_n_idx;

MatrixCoord offset(
inner_idx * 4 * kElementsPerAccess, inst_n_idx * 8);

if (access_idx == ldsm_vec_num) {
if (kTranspose) {
offset = MatrixCoord(offset.column(), offset.row());
}
origin_ += offset;
}
}
}
}

ref_.add_coord_offset(origin_);
}

/// Advances an iterator along logical dimensions of matrix in units of whole
/// tiles
CUTLASS_HOST_DEVICE
WarpIteratorFromSmem& add_tile_offset(TensorCoord const& tile_offset) {
TensorCoord coord_offset(
tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn);
if (kTranspose) {
coord_offset = TensorCoord{coord_offset.column(), coord_offset.row()};
}
origin_ += coord_offset;

ref_.add_coord_offset(coord_offset);

return *this;
}

/// Advances the iterator along the advance dimension
CUTLASS_DEVICE
void advance() {
if (kOperand == Operand::kA) {
add_tile_offset({0, 1});
} else {
add_tile_offset({1, 0});
}

iterations_ = 0;
}

/// increase iterations in a tile
CUTLASS_HOST_DEVICE
WarpIteratorFromSmem& operator++() {
iterations_++;

if (iterations_ >= kIterations)
advance();

return *this;
}

/// Loads a fragment from memory at the location pointed to by the iterator.
CUTLASS_DEVICE
void load(Fragment& frag) const {
AccessType* access_ptr = reinterpret_cast<AccessType*>(&frag);
using LoadLayout = typename platform::
conditional<kTranspose, layout::ColumnMajor, layout::RowMajor>::type;

MatrixCoord offset;
if (kOperand == Operand::kA) {
offset = MatrixCoord(0, iterations_ * InstructionShape::kColumn);
} else {
offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0);
}
if (kTranspose) {
offset = MatrixCoord(offset.column(), offset.row());
}
cutlass::arch::ldsm<LoadLayout, 4>(
access_ptr[0], ref_.data() + ref_.offset(offset));
}
};

////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////
@@ -29,15 +29,6 @@
 *
 **************************************************************************************************/

#pragma once

#ifdef HAS_PYTORCH
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>
#endif

#include <cmath>
#include <vector>

@@ -137,6 +128,9 @@ struct AttentionKernel {
output_accum_ptr; // [num_queries, num_heads, head_dim_value]
lse_scalar_t* logsumexp_ptr; // [num_heads, num_queries] - can be null

// Scale
accum_t scale;

// Dimensions/strides
int32_t head_dim;
int32_t head_dim_value;
@@ -154,18 +148,15 @@ struct AttentionKernel {
int32_t q_strideH;
int32_t k_strideH;
int32_t v_strideH;
int32_t o_strideH;
int64_t q_strideB;
int64_t k_strideB;
int64_t v_strideB;
int64_t o_strideB;
int32_t num_batches;
int32_t num_heads;

CUTLASS_HOST_DEVICE int32_t o_strideM() const {
return head_dim_value;
return head_dim_value * num_heads;
}

// Moves pointers to what we should process
// Returns "false" if there is no work to do
CUTLASS_DEVICE bool advance_to_block() {
@@ -195,9 +186,9 @@ struct AttentionKernel {
query_ptr += batch_id * q_strideB;
key_ptr += batch_id * k_strideB;
value_ptr += batch_id * v_strideB;
output_ptr += batch_id * o_strideB;
output_ptr += int64_t(batch_id * num_queries) * o_strideM();
if (output_accum_ptr != nullptr) {
output_accum_ptr += batch_id * o_strideB;
output_accum_ptr += int64_t(batch_id * num_queries) * o_strideM();
}
q_start = 0;
k_start = 0;
@@ -208,11 +199,11 @@ struct AttentionKernel {
key_ptr += k_start * k_strideM + head_id * k_strideH;
value_ptr += k_start * v_strideM + head_id * v_strideH;
output_ptr += int64_t(q_start + query_start) * o_strideM() +
head_id * o_strideH;
head_id * head_dim_value;

if (output_accum_ptr != nullptr) {
output_accum_ptr += int64_t(q_start + query_start) * o_strideM() +
head_id * o_strideH;
head_id * head_dim_value;
} else {
// Accumulate directly in the destination buffer (eg for f32)
output_accum_ptr = (accum_t*)output_ptr;
@@ -652,7 +643,7 @@ struct AttentionKernel {
warp_id(),
p.num_keys - iter_key_start,
iteratorC_tile_offset,
1.0f / cutlass::fast_sqrt(float(p.head_dim)));
p.scale);
}));
}));

@@ -858,93 +849,3 @@ __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
}
AK::attention_kernel(p);
}

template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
attention_kernel_batched(typename AK::Params params);

#define _ATTENTION_KERNEL_FORWARD_BEGIN(...) \
template <> \
__global__ void __launch_bounds__( \
__VA_ARGS__::kNumThreads, __VA_ARGS__::kMinBlocksPerSm) \
attention_kernel_batched<__VA_ARGS__>(typename __VA_ARGS__::Params p) { \
using Kernel = __VA_ARGS__;
#define _ATTENTION_KERNEL_FORWARD_END() }

#ifdef __CUDA_ARCH__
#define __CUDA_ARCH_OR_ZERO__ __CUDA_ARCH__
#else
#define __CUDA_ARCH_OR_ZERO__ 0
#endif

#define INSTANTIATE_ATTENTION_KERNEL_FORWARD( \
ARCH, \
SCALAR_T, \
IS_ALIGNED, \
QUERIES_PER_BLOCK, \
KEYS_PER_BLOCK, \
SINGLE_VALUE_ITER) \
_ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel< \
SCALAR_T, \
cutlass::arch::Sm##ARCH, \
IS_ALIGNED, \
QUERIES_PER_BLOCK, \
KEYS_PER_BLOCK, \
SINGLE_VALUE_ITER>) \
if (!p.advance_to_block()) { \
return; \
} \
Kernel::attention_kernel(p); \
_ATTENTION_KERNEL_FORWARD_END();

#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED( \
ARCH, \
SCALAR_T, \
IS_ALIGNED, \
QUERIES_PER_BLOCK, \
KEYS_PER_BLOCK, \
SINGLE_VALUE_ITER) \
_ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel< \
SCALAR_T, \
cutlass::arch::Sm##ARCH, \
IS_ALIGNED, \
QUERIES_PER_BLOCK, \
KEYS_PER_BLOCK, \
SINGLE_VALUE_ITER>) \
printf( \
"FATAL: this function is for sm%d, but was built for sm%d\n", \
int(ARCH), \
int(__CUDA_ARCH_OR_ZERO__)); \
_ATTENTION_KERNEL_FORWARD_END();

// All kernels are disabled by default
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(50, __VA_ARGS__)
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(70, __VA_ARGS__)
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(75, __VA_ARGS__)
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(80, __VA_ARGS__)

// Enable the right one based on __CUDA_ARCH__
#ifndef __CUDA_ARCH__
#elif __CUDA_ARCH__ < 500
#error "Need cuda arch at least 5.0"
#elif __CUDA_ARCH__ < 700
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD(50, __VA_ARGS__)
#elif __CUDA_ARCH__ < 750
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD(70, __VA_ARGS__)
#elif __CUDA_ARCH__ < 800
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD(75, __VA_ARGS__)
#elif __CUDA_ARCH__ >= 800
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD(80, __VA_ARGS__)
#endif

@@ -57,8 +57,8 @@
#include "epilogue_thread_apply_logsumexp.h"
#include "gemm_kernel_utils.h"
#include "iterators/make_residual_last.h"

/////////////////////////////////////////////////////////////////////////////////////////////////
#include "iterators/transpose_warp_iterator.h"
#include "iterators/warp_iterator_from_smem.h"

namespace cutlass {
namespace gemm {
@@ -560,20 +560,14 @@ template <
typename Policy1_,
/// Number of stages,
int Stages_,
int kMaxK_,
/// Used for partial specialization
typename Enable = bool>
class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory<
Shape1_,
AccumulatorSharedStorage::Shape::kN,
Policy1_,
Stages_> {
class MmaMultistageFromSharedMemory
: public MmaBaseFromSharedMemory<Shape1_, kMaxK_, Policy1_, Stages_> {
public:
///< Base class
using Base = MmaBaseFromSharedMemory<
Shape1_,
AccumulatorSharedStorage::Shape::kN,
Policy1_,
Stages_>;
using Base = MmaBaseFromSharedMemory<Shape1_, kMaxK_, Policy1_, Stages_>;

///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape1 = Shape1_;
@@ -1035,16 +1029,39 @@ template <
typename WarpShape,
typename InstructionShape,
typename RegularWarpIterator,
typename Policy>
typename Policy,
typename Enable = void>
struct DefaultWarpIteratorAFromSharedMemory {};

// TensorOp - Ampere
// TensorOp - Ampere half
template <typename RegularWarpIterator, typename Policy>
struct DefaultWarpIteratorAFromSharedMemory<
cutlass::gemm::GemmShape<32, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
RegularWarpIterator,
Policy,
typename platform::enable_if<(
sizeof_bits<typename RegularWarpIterator::Element>::value == 16 &&
Policy::Operator::Policy::OpDelta::kRow == 1)>::type> {
static constexpr auto kWarpSize = 32;
using OpDelta = typename Policy::Operator::Policy::OpDelta;
using WarpShape = cutlass::MatrixShape<32, 32>;

using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem<
cutlass::gemm::Operand::kA,
typename RegularWarpIterator::Element>;
};

// TensorOp - Ampere f32
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
struct DefaultWarpIteratorAFromSharedMemory<
WarpShape,
cutlass::gemm::GemmShape<16, 8, 8>,
RegularWarpIterator,
Policy> {
Policy,
typename platform::enable_if<(
sizeof_bits<typename RegularWarpIterator::Element>::value != 16 ||
Policy::Operator::Policy::OpDelta::kRow != 1)>::type> {
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
static constexpr auto kWarpSize = 32;
using OpDelta = typename Policy::Operator::Policy::OpDelta;
@@ -1099,7 +1116,10 @@ struct DefaultWarpIteratorAFromSharedMemory<
};

// Converts a "regular" Mma into their counterpart from shared memory
template <typename Mma_, typename AccumulatorSharedStorage>
template <
typename Mma_,
typename AccumulatorSharedStorage,
bool kTransposeA = false>
struct DefaultMmaFromSharedMemory;

// Mma pipelined
@@ -1130,7 +1150,8 @@ template <
typename TransformA_,
/// Transformation applied to B operand
typename TransformB_,
typename AccumulatorSharedStorage_>
typename AccumulatorSharedStorage_,
bool kTransposeA>
struct DefaultMmaFromSharedMemory<
MmaPipelined<
Shape_,
@@ -1143,7 +1164,8 @@ struct DefaultMmaFromSharedMemory<
Policy_,
TransformA_,
TransformB_>,
AccumulatorSharedStorage_> {
AccumulatorSharedStorage_,
kTransposeA> {
static constexpr int kWarpSize = 32;
using SmemAccumulatorLayout = cutlass::layout::RowMajor;

@@ -1163,6 +1185,7 @@ struct DefaultMmaFromSharedMemory<
using InstructionShape = typename Policy_::Operator::InstructionShape;
using ArchMmaOperator = typename Policy_::Operator;

static constexpr bool kIsTransposedA = false;
using WarpIteratorA = typename DefaultWarpIteratorAFromSharedMemory<
WarpShape,
InstructionShape,
@@ -1214,7 +1237,8 @@ template <
int Stages,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear,
typename AccumulatorSharedStorage_>
typename AccumulatorSharedStorage_,
bool kTransposeA>
struct DefaultMmaFromSharedMemory<
MmaMultistage<
Shape_,
@@ -1229,7 +1253,8 @@ struct DefaultMmaFromSharedMemory<
Policy_,
Stages,
SharedMemoryClear>,
AccumulatorSharedStorage_> {
AccumulatorSharedStorage_,
kTransposeA> {
static constexpr int kWarpSize = 32;

using RegularMma = MmaMultistage<
@@ -1248,13 +1273,22 @@ struct DefaultMmaFromSharedMemory<

using WarpShape = typename Policy_::Operator::Shape;
using InstructionShape = typename Policy_::Operator::InstructionShape;
using WarpIteratorA = typename DefaultWarpIteratorAFromSharedMemory<
using WarpIteratorA_ = typename DefaultWarpIteratorAFromSharedMemory<
WarpShape,
InstructionShape,
typename RegularMma::Operator::IteratorA,
Policy_>::WarpIterator;
using WarpIteratorTranspose = TransposeWarpIterator<WarpIteratorA_>;
static constexpr bool kIsTransposedA =
WarpIteratorTranspose::kSupportsTranspose && kTransposeA;
using WarpIteratorA = typename platform::conditional<
kIsTransposedA,
typename WarpIteratorTranspose::Iterator,
WarpIteratorA_>::type;

static int constexpr kMaxK = AccumulatorSharedStorage_::Shape::kN;
static int constexpr kMaxK = kIsTransposedA
? AccumulatorSharedStorage_::Shape::kM
: AccumulatorSharedStorage_::Shape::kN;
// Reduce the number of stages if we don't need that many
static int constexpr kStagesMax =
(kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK);
@@ -1274,7 +1308,8 @@ struct DefaultMmaFromSharedMemory<
ElementC_,
LayoutC_,
Policy_,
kStages>;
kStages,
kMaxK>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////
