cutlass/examples/42_fused_multi_head_attention/fused_multihead_attention.cu
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief CUTLASS Attention Example.
This workload computes fused multi-head attention.
Because it keeps the attention matrix in shared memory, it is both faster and
uses less global memory.
This is based on `"Self-Attention Does Not Need O(n^2) Memory" <http://arxiv.org/abs/2112.05682>`_,
and very similar to `"FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" <https://arxiv.org/abs/2205.14135>`_.
Algorithm:
In short, we can compute the output incrementally in blocks of size B;
we just need to divide the final result by the sum of all coefficients in
the softmax (which we compute incrementally), with the following pseudo-code:
```
s_prime = torch.zeros([num_queries])
O = torch.zeros([num_queries, head_size_v])
for i in range(0, K.shape[0], B):
  si = exp((Q . K[i * B:(i+1) * B].t) * scale)
  s_prime += si.sum(-1)
  O += si . V[i * B:(i+1) * B]
O = O / s_prime
```
In practice, and for numerical stability reasons,
we also subtract the maximum so far (`mi`) before doing
the exponential. When we encounter new keys, the maximum
used to compute O so far (`m_prime`) can differ from the
current maximum, so we update O before accumulating with
```
O       = O * exp(m_prime - mi)
m_prime = mi
```
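Putting the two updates together, a minimal sketch of the streaming
computation (for illustration only; variable names are chosen for exposition
and do not correspond one-to-one to the kernel code) looks like:
```
m_prime = torch.full([num_queries], -inf)  # running row-wise maximum
s_prime = torch.zeros([num_queries])       # running softmax denominator
O = torch.zeros([num_queries, head_size_v])
for i in range(0, K.shape[0], B):
  si = (Q . K[i * B:(i+1) * B].t) * scale  # logits for this block of keys
  mi = max(m_prime, si.max(-1))            # updated row-wise maximum
  si = exp(si - mi)
  # rescale what was accumulated with the old maximum m_prime
  O = O * exp(m_prime - mi)
  s_prime = s_prime * exp(m_prime - mi) + si.sum(-1)
  O += si . V[i * B:(i+1) * B]
  m_prime = mi
O = O / s_prime
```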
Implementation details:
- `si` is stored in shared memory between the two back-to-back GEMMs
- we keep and accumulate the output
directly in registers if we can (`head_size_v <= 128`).
Otherwise, we store it & accumulate in global memory (slower)
- blocks are parallelized across the batch dimension, the number
of heads, and the query sequence size (see the illustrative grid sketch below)
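As a rough picture only (the exact mapping is given by the kernel's
`getBlocksGrid()` / `getThreadsGrid()` helpers used at launch time), one can
think of one threadblock per (query block, head, batch) triple:
```
grid ~ (ceil_div(num_queries, kQueriesPerBlock), num_heads, num_batches)
the block for (q, h, b) produces
O[b, h, q * kQueriesPerBlock : (q+1) * kQueriesPerBlock, :]
```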
Examples:
# Run an attention example with default setup
$ ./examples/42_fused_multi_head_attention/42_fused_multi_head_attention
# Run an attention example with custom setup
$ ./examples/42_fused_multi_head_attention/42_fused_multi_head_attention --head_number=2 --batch_size=3 --head_size=32 --head_size_v=64 --seq_length=512 --seq_length_kv=1024 --causal=true
*/
/////////////////////////////////////////////////////////////////////////////////////////////////
#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <map>
#include <unordered_map>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm_complex.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/kernel/gemm_grouped.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/kernel/default_gemm_complex.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h"
#include "cutlass/fast_math.h"
#include "kernel_forward.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Result structure
struct Result {
double runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
//
// Methods
//
Result(
double runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess
):
runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
bool error;
bool reference_check;
bool use_mask;
bool causal;
std::vector<cutlass::gemm::GemmCoord> problem_sizes0;
std::vector<cutlass::gemm::GemmCoord> problem_sizes1;
std::vector<cutlass::gemm::GemmCoord> problem_sizes0_real;
std::vector<cutlass::gemm::GemmCoord> problem_sizes1_real;
int alignment;
int head_number;
int batch_size;
int head_size;
int head_size_v;
int seq_length;
int seq_length_kv;
int iterations;
// alpha0, alpha1 and beta are fixed
// in this multi-head attention example
float alpha0;
float alpha1;
float beta;
//
// Methods
//
Options():
help(false),
error(false),
alignment(1),
reference_check(true),
head_number(12),
batch_size(16),
head_size(64),
head_size_v(64),
seq_length(1024),
seq_length_kv(1024),
use_mask(false),
iterations(20),
causal(false)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("alignment", alignment, 1);
cmd.get_cmd_line_argument("head_number", head_number, 12);
cmd.get_cmd_line_argument("batch_size", batch_size, 16);
cmd.get_cmd_line_argument("head_size", head_size, 64);
cmd.get_cmd_line_argument("head_size_v", head_size_v, head_size);
cmd.get_cmd_line_argument("seq_length", seq_length, 1024);
cmd.get_cmd_line_argument("seq_length_kv", seq_length_kv, seq_length);
cmd.get_cmd_line_argument("use_mask", use_mask, false);
cmd.get_cmd_line_argument("iterations", iterations, 20);
cmd.get_cmd_line_argument("reference-check", reference_check, true);
cmd.get_cmd_line_argument("causal", causal, true);
randomize_problems();
}
void randomize_problems() {
int problem_count = head_number * batch_size;
problem_sizes0.reserve(problem_count);
problem_sizes1.reserve(problem_count);
// When using a mask, the original inputs are not padded
// and we need to record their real sizes.
if (use_mask) {
problem_sizes0_real.reserve(problem_count);
problem_sizes1_real.reserve(problem_count);
}
for (int i = 0; i < batch_size; ++i) {
// problems belonging to the same batch share the same seq len
int m_real = seq_length; // (rand() % seq_length);
int mkv_real = seq_length_kv; // (rand() % seq_length_kv);
int m = (m_real + alignment - 1) / alignment * alignment;
int mkv = (mkv_real + alignment - 1) / alignment * alignment;
int k0 = head_size;
int k1 = head_size_v;
for (int j = 0; j < head_number; ++j) {
cutlass::gemm::GemmCoord problem0(m, mkv, k0);
cutlass::gemm::GemmCoord problem1(m, k1, mkv);
problem_sizes0.push_back(problem0);
problem_sizes1.push_back(problem1);
if (use_mask) {
cutlass::gemm::GemmCoord problem0_real(m_real, mkv_real, k0);
cutlass::gemm::GemmCoord problem1_real(m_real, k1, mkv_real);
problem_sizes0_real.push_back(problem0_real);
problem_sizes1_real.push_back(problem1_real);
}
}
}
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "42_fused_multi_head_attention\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --head_number=<int> Head number in multi-head attention (default: --head_number=12)\n"
<< " --batch_size=<int> Batch size in multi-head attention (default: --batch_size=16)\n"
<< " --head_size=<int> Head size in multi-head attention (default: --head_size=64)\n"
<< " --head_size_v=<int> Head size in multi-head attention for V (default: --head_size_v=head_size)\n"
<< " --seq_length=<int> Sequence length in multi-head attention for Q (default: --seq_length=1024)\n"
<< " --seq_length_kv=<int> Sequence length in multi-head attention for K/V(default: --seq_length_kv=seq_length)\n"
<< " --use_mask=<bool> If true, performs padding-like masking in softmax.\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n"
<< " --reference-check=<bool> If true, performs reference check.\n"
<< " --causal=<bool> If true, uses causal masking.\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Number of real-valued multiply-adds
int64_t fops = int64_t();
for (int i = 0; i < problem_sizes0.size(); ++i) {
auto const& problem0 = problem_sizes0[i];
auto const& problem1 = problem_sizes1[i];
for (int row = 0; row < problem0.m(); ++row) {
int num_cols0 = problem0.n();
if (causal) {
num_cols0 = std::min(row + 1, num_cols0);
}
// P <- Q . K_t
fops += 2 * num_cols0 * problem0.k();
// P <- exp(P - max(P))
fops += 2 * num_cols0;
// S <- sum(P)
fops += num_cols0 - 1;
// O <- P . V
fops += 2 * num_cols0 * problem1.n();
// O <- O / S
fops += num_cols0 * problem1.n();
}
}
return double(fops) / double(1.0e9) / runtime_s;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Attention>
class TestbedAttention {
public:
//
// Type definitions
//
using ElementQ = typename Attention::scalar_t;
using ElementK = typename Attention::scalar_t;
using ElementP = typename Attention::accum_t;
using ElementAccumulator = typename Attention::accum_t;
using ElementV = typename Attention::scalar_t;
using ElementO = typename Attention::output_t;
using ElementCompute = typename Attention::accum_t;
using ElementNorm = typename Attention::accum_t;
using ElementSum = typename Attention::accum_t;
using ElementSoftmaxCompute = typename Attention::accum_t;
using LayoutQ = cutlass::layout::RowMajor;
using LayoutK = cutlass::layout::RowMajor;
using LayoutK_T = cutlass::layout::ColumnMajor; // transposed
using LayoutP = cutlass::layout::RowMajor;
using LayoutV = cutlass::layout::RowMajor;
using LayoutO = cutlass::layout::RowMajor;
using MatrixCoord = typename LayoutP::TensorCoord;
private:
//
// Data members
//
Options & options;
/// Initialization
cutlass::Distribution::Kind init_Q;
cutlass::Distribution::Kind init_K;
cutlass::Distribution::Kind init_P;
cutlass::Distribution::Kind init_V;
cutlass::Distribution::Kind init_O;
uint32_t seed;
cutlass::DeviceAllocation<cutlass::gemm::GemmCoord> problem_sizes_device0;
cutlass::DeviceAllocation<cutlass::gemm::GemmCoord> problem_sizes_device1;
cutlass::DeviceAllocation<cutlass::gemm::GemmCoord> problem_sizes_device0_real;
std::vector<int64_t> offset_Q;
std::vector<int64_t> offset_K;
std::vector<int64_t> offset_P;
std::vector<int64_t> offset_V;
std::vector<int64_t> offset_O;
std::vector<int64_t> ldq_host;
std::vector<int64_t> ldk_host;
std::vector<int64_t> ldp_host;
std::vector<int64_t> ldv_host;
std::vector<int64_t> ldo_host;
std::vector<int64_t> seqlen_host;
cutlass::DeviceAllocation<int64_t> ldq;
cutlass::DeviceAllocation<int64_t> ldk;
cutlass::DeviceAllocation<int64_t> ldp;
cutlass::DeviceAllocation<int64_t> ldv;
cutlass::DeviceAllocation<int64_t> ldo;
cutlass::DeviceAllocation<int64_t> seqlen;
cutlass::DeviceAllocation<ElementQ> block_Q;
cutlass::DeviceAllocation<ElementK> block_K;
cutlass::DeviceAllocation<ElementP> block_P;
cutlass::DeviceAllocation<ElementV> block_V;
cutlass::DeviceAllocation<ElementO> block_O;
cutlass::DeviceAllocation<ElementNorm> block_Norm;
cutlass::DeviceAllocation<ElementSum> block_Sum;
cutlass::DeviceAllocation<int64_t> offset_P_Device;
cutlass::DeviceAllocation<ElementQ *> ptr_Q;
cutlass::DeviceAllocation<ElementK *> ptr_K;
cutlass::DeviceAllocation<ElementP *> ptr_P;
cutlass::DeviceAllocation<ElementV *> ptr_V;
cutlass::DeviceAllocation<ElementO *> ptr_O;
public:
//
// Methods
//
TestbedAttention(
Options &options_,
cutlass::Distribution::Kind init_Q_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_K_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_P_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_V_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_O_ = cutlass::Distribution::Uniform,
uint32_t seed_ = 3080
):
options(options_), init_Q(init_Q_), init_K(init_K_), init_P(init_P_), init_V(init_V_), init_O(init_O_), seed(seed_) { }
int problem_count() const {
return (options.head_number * options.batch_size);
}
private:
/// Helper to initialize a tensor view
template <typename Element>
void initialize_tensor_(
Element *ptr,
size_t capacity,
cutlass::Distribution::Kind dist_kind,
uint32_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
Element scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
int bits_output = cutlass::sizeof_bits<ElementP>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
} else if (bits_output == 16) {
scope_max = 8;
scope_min = -8;
} else {
scope_max = 8;
scope_min = -8;
}
cutlass::reference::device::BlockFillRandomUniform(
ptr, capacity, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::device::BlockFillRandomGaussian(
ptr, capacity, seed, Element(), Element(0.5f));
}
else if (dist_kind == cutlass::Distribution::Sequential) {
// Fill with increasing elements
cutlass::reference::device::BlockFillSequential(
ptr, capacity, Element(1), Element());
}
else {
// Fill with all 1s
cutlass::reference::device::BlockFillSequential(
ptr, capacity, Element(), Element(1));
}
}
/// Initializes data structures
void initialize_() {
//
// Set scalars for the mha example
//
options.alpha0 = 1.0f / sqrt(float(options.head_size));
options.alpha1 = 1.0f;
options.beta = 0;
//
// Choose random problem sizes
//
// construct a few problems of random sizes
srand(seed);
int64_t total_elements_Q = 0;
int64_t total_elements_K = 0;
int64_t total_elements_P = 0;
int64_t total_elements_V = 0;
int64_t total_elements_O = 0;
ldq_host.resize(problem_count());
ldk_host.resize(problem_count());
ldp_host.resize(problem_count());
ldv_host.resize(problem_count());
ldo_host.resize(problem_count());
seqlen_host.resize(problem_count());
for (int32_t i = 0; i < problem_count(); ++i) {
auto problem0 = options.problem_sizes0.at(i);
auto problem1 = options.problem_sizes1.at(i);
ldq_host.at(i) = LayoutQ::packed({problem0.m(), problem0.k()}).stride(0);
ldk_host.at(i) = LayoutK::packed({problem0.n(), problem0.k()}).stride(0);
ldp_host.at(i) = LayoutP::packed({problem0.m(), problem0.n()}).stride(0);
ldv_host.at(i) = LayoutV::packed({problem1.k(), problem1.n()}).stride(0);
ldo_host.at(i) = LayoutO::packed({problem1.m(), problem1.n()}).stride(0);
// m = n for attention problems.
seqlen_host.at(i) = problem0.m();
offset_Q.push_back(total_elements_Q);
offset_K.push_back(total_elements_K);
offset_P.push_back(total_elements_P);
offset_V.push_back(total_elements_V);
offset_O.push_back(total_elements_O);
int64_t elements_Q = problem0.m() * problem0.k();
int64_t elements_K = problem0.k() * problem0.n();
int64_t elements_P = problem0.m() * problem0.n();
int64_t elements_V = problem1.k() * problem1.n();
int64_t elements_O = problem1.m() * problem1.n();
total_elements_Q += elements_Q;
total_elements_K += elements_K;
total_elements_P += elements_P;
total_elements_V += elements_V;
total_elements_O += elements_O;
}
problem_sizes_device0.reset(problem_count());
problem_sizes_device1.reset(problem_count());
problem_sizes_device0.copy_from_host(options.problem_sizes0.data());
problem_sizes_device1.copy_from_host(options.problem_sizes1.data());
if (options.use_mask) {
problem_sizes_device0_real.reset(problem_count());
problem_sizes_device0_real.copy_from_host(options.problem_sizes0_real.data());
}
ldq.reset(problem_count());
ldk.reset(problem_count());
ldp.reset(problem_count());
ldv.reset(problem_count());
ldo.reset(problem_count());
seqlen.reset(problem_count());
ldq.copy_from_host(ldq_host.data());
ldk.copy_from_host(ldk_host.data());
ldp.copy_from_host(ldp_host.data());
ldv.copy_from_host(ldv_host.data());
ldo.copy_from_host(ldo_host.data());
seqlen.copy_from_host(seqlen_host.data());
//
// Assign pointers
//
block_Q.reset(total_elements_Q);
block_K.reset(total_elements_K);
block_P.reset(total_elements_P);
block_V.reset(total_elements_V);
block_O.reset(total_elements_O);
offset_P_Device.reset(problem_count());
// sync offset with device
cutlass::device_memory::copy_to_device(offset_P_Device.get(), offset_P.data(), offset_P.size());
std::vector<ElementQ *> ptr_Q_host(problem_count());
std::vector<ElementK *> ptr_K_host(problem_count());
std::vector<ElementP *> ptr_P_host(problem_count());
std::vector<ElementV *> ptr_V_host(problem_count());
std::vector<ElementO *> ptr_O_host(problem_count());
std::vector<ElementNorm *> ptr_norm_host(problem_count());
std::vector<ElementSum *> ptr_sum_host(problem_count());
for (int32_t i = 0; i < problem_count(); ++i) {
ptr_Q_host.at(i) = block_Q.get() + offset_Q.at(i);
ptr_K_host.at(i) = block_K.get() + offset_K.at(i);
ptr_P_host.at(i) = block_P.get() + offset_P.at(i);
ptr_V_host.at(i) = block_V.get() + offset_V.at(i);
ptr_O_host.at(i) = block_O.get() + offset_O.at(i);
}
ptr_Q.reset(problem_count());
ptr_Q.copy_from_host(ptr_Q_host.data());
ptr_K.reset(problem_count());
ptr_K.copy_from_host(ptr_K_host.data());
ptr_P.reset(problem_count());
ptr_P.copy_from_host(ptr_P_host.data());
ptr_V.reset(problem_count());
ptr_V.copy_from_host(ptr_V_host.data());
ptr_O.reset(problem_count());
ptr_O.copy_from_host(ptr_O_host.data());
//
// Initialize the problems of the workspace
//
initialize_tensor_(block_Q.get(), total_elements_Q, init_Q, seed + 1);
initialize_tensor_(block_K.get(), total_elements_K, init_K, seed + 2);
initialize_tensor_(block_V.get(), total_elements_V, init_V, seed + 3);
}
template<typename Element>
bool verify_tensor_(std::vector<Element> vector_Input, \
std::vector<Element> vector_Input_Ref,
int64_t verify_length = -1) {
int64_t size = (vector_Input.size() < vector_Input_Ref.size()) ? vector_Input.size() : vector_Input_Ref.size();
size = (verify_length == -1) ? size : verify_length;
// 0.05 for absolute error
float abs_tol = 5e-2f;
// 10% for relative error
float rel_tol = 1e-1f;
for (int64_t i = 0; i < size; ++i) {
float diff = (float)(vector_Input.at(i) - vector_Input_Ref.at(i));
float abs_diff = fabs(diff);
float abs_ref = fabs((float)vector_Input_Ref.at(i) + 1e-5f);
float relative_diff = abs_diff / abs_ref;
if ( (isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > abs_tol && relative_diff > rel_tol)) {
printf("[%d/%d] diff = %f, rel_diff = %f, {computed=%f, ref=%f}.\n", int(i), int(size), abs_diff, relative_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i)));
return false;
}
}
return true;
}
/// Verifies the result against a host-side reference computation
bool verify_() {
bool passed = true;
for (int32_t i = 0; i < problem_count(); ++i) {
cutlass::gemm::GemmCoord problem0 = options.problem_sizes0.at(i);
cutlass::gemm::GemmCoord problem1 = options.problem_sizes1.at(i);
LayoutQ layout_Q(ldq_host.at(i));
LayoutK_T layout_K(ldk_host.at(i));
LayoutP layout_P(ldp_host.at(i));
LayoutV layout_V(ldv_host.at(i));
LayoutO layout_O(ldo_host.at(i));
MatrixCoord extent_Q{problem0.m(), problem0.k()};
MatrixCoord extent_K{problem0.n(), problem0.k()};
MatrixCoord extent_P{problem0.m(), problem0.n()};
MatrixCoord extent_V{problem1.k(), problem1.n()};
MatrixCoord extent_O{problem1.m(), problem1.n()};
cutlass::TensorView<ElementQ, LayoutQ> view_Q(block_Q.get() + offset_Q.at(i), layout_Q, extent_Q);
cutlass::TensorView<ElementK, LayoutK_T> view_K(block_K.get() + offset_K.at(i), layout_K, extent_K);
cutlass::TensorView<ElementP, LayoutP> view_P(block_P.get() + offset_P.at(i), layout_P, extent_P);
cutlass::TensorView<ElementV, LayoutV> view_V(block_V.get() + offset_V.at(i), layout_V, extent_V);
cutlass::DeviceAllocation<ElementP> block_Ref(layout_P.capacity(extent_P));
cutlass::TensorView<ElementP, LayoutP> view_Ref_device(block_Ref.get(), layout_P, extent_P);
cutlass::DeviceAllocation<ElementO> block_Ref_O(layout_O.capacity(extent_O));
cutlass::TensorView<ElementO, LayoutO> view_Ref_O_device(block_Ref_O.get(), layout_O, extent_O);
// Reference GEMM
cutlass::reference::device::GemmComplex<
ElementQ, LayoutQ,
ElementK, LayoutK_T,
ElementP, LayoutP,
ElementCompute, ElementAccumulator
>(
problem0,
ElementAccumulator(options.alpha0),
view_Q,
Attention::MM0::Mma::kTransformA,
view_K,
Attention::MM0::Mma::kTransformB,
ElementAccumulator(options.beta),
view_P,
view_Ref_device,
ElementAccumulator(0)
);
// Compute softmax for P. We need to explicitly compute softmax
// over P because softmax is fused to the second GEMM in the
// profiled implementation.
std::vector<ElementP> matrix_Ref(layout_P.capacity(extent_P));
cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_Ref.size());
cutlass::TensorView<ElementP, LayoutP> view_Ref_host(matrix_Ref.data(), layout_P, extent_P);
std::vector<ElementNorm> vector_Norm_Ref(problem0.m());
std::vector<ElementSum> vector_Sum_Ref(problem0.m());
int n_dim = options.use_mask ? options.problem_sizes0_real.at(i).n() : problem0.n();
// Compute softmax for the reference matrix
// Assumed a row-major storage
for (int m = 0; m < problem0.m(); m++) {
int n_dim_row = n_dim;
if (options.causal) {
n_dim_row = std::min(m + 1, n_dim);
}
ElementSoftmaxCompute max = ElementSoftmaxCompute(view_Ref_host.ref().at({m, 0}));
for (int n = 1; n < n_dim_row; n++) {
max = std::max(max, ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})));
}
vector_Norm_Ref.at(m) = ElementNorm(max);
ElementSoftmaxCompute sum = ElementSoftmaxCompute();
for (int n = 0; n < n_dim_row; n++) {
sum += std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max );
}
ElementSoftmaxCompute inv_sum = ElementSoftmaxCompute(1.0f / sum);
vector_Sum_Ref.at(m) = ElementSum(inv_sum);
for (int n = 0; n < n_dim_row; n++) {
view_Ref_host.ref().at({m, n}) = ElementP(
std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max ) * inv_sum
);
}
// Mask out the rest of the attention matrix
for (int n = n_dim_row; n < n_dim; ++n) {
view_Ref_host.ref().at({m, n}) = ElementP(0);
}
}
// when not using mask, problem_real and problem share the same sizes
if (options.use_mask) {
for (int m = 0; m < problem0.m(); m++) {
for (int n = n_dim; n < problem0.n(); n++) {
view_Ref_host.ref().at({m, n}) = ElementP(0);
}
}
}
cutlass::device_memory::copy_to_device(block_P.get() + offset_P.at(i), matrix_Ref.data(), matrix_Ref.size());
// Reference GEMM
cutlass::reference::device::GemmComplex<
ElementP, LayoutP,
ElementV, LayoutV,
ElementO, LayoutO,
ElementCompute, ElementAccumulator
>(
problem1,
ElementAccumulator(options.alpha1),
view_P,
Attention::MM0::Mma::kTransformA,
view_V,
Attention::MM0::Mma::kTransformB,
ElementAccumulator(options.beta),
view_Ref_O_device,
view_Ref_O_device,
ElementAccumulator(0)
);
// Copy to host memory
cutlass::TensorView<ElementP, LayoutP> view_Ref(matrix_Ref.data(), layout_P, extent_P);
std::vector<ElementO> matrix_O(layout_O.capacity(extent_O));
cutlass::device_memory::copy_to_host(matrix_O.data(), block_O.get() + offset_O.at(i), matrix_O.size());
std::vector<ElementO> matrix_Ref_O(layout_O.capacity(extent_O));
cutlass::device_memory::copy_to_host(matrix_Ref_O.data(), block_Ref_O.get(), matrix_Ref_O.size());
// printf("Pb %d: \n Q=(offset=%d, ldq=%d)\n K=(offset=%d, ldk=%d)\n O=(offset=%d, ldo=%d)\n",
// int(i), int(offset_Q[i]), int(ldq_host[i]), int(offset_K[i]), int(ldk_host[i]), int(offset_O[i]), int(ldo_host[i]));
bool verified_O = verify_tensor_<ElementO>(matrix_O, matrix_Ref_O);
passed = passed && verified_O;
if (!passed) {
std::cerr << "\n***\nError - problem " << i << " failed the QA check\n***\n" << std::endl;
if (!verified_O) {
std::cout << "Final matrix output is incorrect" << std::endl;
}
return passed;
}
}
return passed;
}
public:
/// Executes a CUTLASS Attention kernel and measures runtime.
Result profile_grouped() {
Result result;
result.passed = false;
// Initialize the problem
initialize_();
typename Attention::Params p;
{ // set parameters
p.query_ptr = block_Q.get();
p.key_ptr = block_K.get();
p.value_ptr = block_V.get();
p.logsumexp_ptr = nullptr; // Only needed for the backward pass
p.output_accum_ptr = nullptr;
if (Attention::kNeedsOutputAccumulatorBuffer) {
cudaMalloc(&p.output_accum_ptr, block_O.size() * sizeof(typename Attention::output_accum_t));
}
p.output_ptr = block_O.get();
// TODO: support arbitrary seq lengths
// if (cu_seqlens_q.has_value()) {
// p.cu_seqlens_q_ptr = (int32_t*)cu_seqlens_q->data_ptr();
// p.cu_seqlens_k_ptr = (int32_t*)cu_seqlens_k->data_ptr();
// }
p.num_heads = options.head_number;
p.num_batches = options.batch_size;
p.head_dim = options.head_size;
p.head_dim_value = options.head_size_v;
p.num_queries = options.seq_length;
p.num_keys = options.seq_length_kv;
p.causal = options.causal;
// TODO: This might overflow for big tensors
p.q_strideM = int32_t(ldq_host[0]);
p.k_strideM = int32_t(ldk_host[0]);
p.v_strideM = int32_t(ldv_host[0]);
p.q_strideH = p.q_strideM * options.seq_length;
p.k_strideH = p.k_strideM * options.seq_length_kv;
p.v_strideH = p.v_strideM * options.seq_length_kv;
p.o_strideH = options.head_size_v * options.seq_length;
p.q_strideB = p.q_strideH * options.head_number;
p.k_strideB = p.k_strideH * options.head_number;
p.v_strideB = p.v_strideH * options.head_number;
p.o_strideB = options.head_size_v * options.seq_length * options.head_number;
}
// launch kernel :)
constexpr auto kernel_fn = attention_kernel_batched_impl<Attention>;
int smem_bytes = sizeof(typename Attention::SharedStorage);
if (smem_bytes > 0xc000) {
cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
}
if (!Attention::check_supported(p)) {
std::cerr << "Kernel does not support these inputs" << std::endl;
return result;
}
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
// Wait for completion
result.error = cudaDeviceSynchronize();
if (result.error != cudaSuccess) {
std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error);
return result;
}
//
// Verify correctness
//
result.passed = true;
if (options.reference_check) {
result.passed = verify_();
}
//
// Warm-up run of the attention kernel
//
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
result.error = cudaGetLastError();
if (result.error != cudaSuccess) {
std::cerr << "Failed to launch CUTLASS Attention kernel: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
//
// Construct events
//
cudaEvent_t events[2];
for (auto & event : events) {
result.error = cudaEventCreate(&event);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
}
// Record an event at the start of the series of attention kernel launches
result.error = cudaEventRecord(events[0]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
//
// Run profiling loop
//
for (int iter = 0; iter < options.iterations; ++iter) {
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
}
//
// Stop profiling loop
//
// Record an event when the attention kernels have been launched.
result.error = cudaEventRecord(events[1]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
// Wait for work on the device to complete.
result.error = cudaEventSynchronize(events[1]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
// Measure elapsed runtime
float runtime_ms = 0;
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
// Compute average runtime and GFLOPs.
result.runtime_ms = double(runtime_ms) / double(options.iterations);
result.gflops = options.gflops(result.runtime_ms / 1000.0);
//
// Cleanup
//
for (auto event : events) {
(void)cudaEventDestroy(event);
}
std::cout << std::endl;
std::cout << "CUTLASS Attention:\n"
<< "====================================================" << std::endl;
std::cout << " " << " {seq length Q, seq length KV, head size, head size V, head number, batch size} = {" << options.seq_length \
<< ", " << options.seq_length_kv << ", " << options.head_size << ", " << options.head_size_v << ", " << options.head_number\
<< ", " << options.batch_size << "}." << std::endl;
std::cout << std::endl;
std::cout << " " << "Runtime: " << result.runtime_ms << " ms" << std::endl;
std::cout << " " << "GFLOPs: " << result.gflops << std::endl;
return result;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
//
// This example uses mma.sync to directly access Tensor Cores to achieve peak performance.
//
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) {
//
// This example requires an NVIDIA Ampere-architecture GPU.
//
std::cout
<< "CUTLASS's CUTLASS Attention example requires a GPU of NVIDIA's Ampere Architecture or "
<< "later (compute capability 80 or greater).\n";
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
if (options.error) {
std::cerr << "Aborting execution." << std::endl;
return -1;
}
if (options.use_mask) {
std::cerr << "--use_mask is not supported at the moment\n";
return -2;
}
if (options.alignment != 1) {
std::cerr << "--alignment=1 is the only supported value\n";
return -2;
}
using ArchTag = cutlass::arch::Sm80;
constexpr bool kIs64x64 = true;
// Set grid size
constexpr int64_t kQueriesPerBlock = kIs64x64 ? 64 : 32;
constexpr int64_t kKeysPerBlock = kIs64x64 ? 64 : 128;
if (kIs64x64 && options.head_size_v > kKeysPerBlock) {
std::cerr << "WARNING: you will get better performance with `kIs64x64=false`\n";
}
constexpr bool kSingleValueIteration = true;
if (kSingleValueIteration && options.head_size_v > kKeysPerBlock) {
std::cerr << "ERROR : Use kSingleValueIteration to keep output in RF. " \
"This requires to have `head_size <= kKeysPerBlock` " \
"but head_size_v=" << options.head_size_v << " and kKeysPerBlock=" << kKeysPerBlock << "\n";
return -2;
}
if (!kSingleValueIteration && options.head_size_v <= kKeysPerBlock) {
std::cerr << "WARNING: you will get better performance with `kSingleValueIteration=true` (keeps the output in RF rather than GMEM)\n";
}
using Attention = AttentionKernel<
cutlass::half_t, // scalar_t
ArchTag,
true, // memory is aligned
kQueriesPerBlock,
kKeysPerBlock,
kSingleValueIteration
>;
//
// Test and profile
//
TestbedAttention<Attention> testbed(options);
Result result = testbed.profile_grouped();
if (!result.passed) {
std::cout << "Profiling CUTLASS attention has failed.\n";
std::cout << "\nFailed\n";
return -1;
}
std::cout << "\nPassed\n";
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////