xFormer updates to fMHA FW (#773)
* xFormer updates to fMHA FW * Convert format to BMHK for '41_fused_multi_head_attention_fixed_seqlen' * Add missing files * Remove xFormers specific code * Update fused_multihead_attention_fixed_seqlen.cu * rebase and solve conflicts * remove white space --------- Co-authored-by: danthe3rd <danthe3rd> Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
parent
5ff5209ed5
commit
2e10404d26
@ -47,19 +47,52 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Print on the first thread of the first block
|
// Print on the first thread of the first block
|
||||||
#if 0
|
#if 1
|
||||||
#define PRINT_WARP_ID 0
|
#define PRINT_WARP_ID 0
|
||||||
#define PRINT_LANE_ID 0
|
#define PRINT_LANE_ID 0
|
||||||
#define PRINT_T0_L0(msg, ...) \
|
#define PRINT_T0_L0(msg, ...) \
|
||||||
if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && \
|
if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && \
|
||||||
threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \
|
threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \
|
||||||
threadIdx.z == 0) { \
|
threadIdx.z == 0) { \
|
||||||
printf(msg "\n", __VA_ARGS__); \
|
printf(msg "\n", ##__VA_ARGS__); \
|
||||||
}
|
}
|
||||||
|
#define PRINT_TX_LX(msg, ...) \
|
||||||
|
for (int bx = 0; bx < gridDim.x; ++bx) { \
|
||||||
|
for (int by = 0; by < gridDim.y; ++by) { \
|
||||||
|
for (int bz = 0; bz < gridDim.z; ++bz) { \
|
||||||
|
for (int tx = 0; tx < blockDim.x; ++tx) { \
|
||||||
|
for (int ty = 0; ty < blockDim.y; ++ty) { \
|
||||||
|
for (int tz = 0; tz < blockDim.z; ++tz) { \
|
||||||
|
__syncthreads(); \
|
||||||
|
if (blockIdx.x == bx && blockIdx.y == by && blockIdx.z == bz && \
|
||||||
|
threadIdx.x == tx && threadIdx.y == ty && \
|
||||||
|
threadIdx.z == tz) { \
|
||||||
|
printf( \
|
||||||
|
"[%d,%d,%d][%d,%d,%d]" msg "\n", \
|
||||||
|
bx, \
|
||||||
|
by, \
|
||||||
|
bz, \
|
||||||
|
tx, \
|
||||||
|
ty, \
|
||||||
|
tz, \
|
||||||
|
##__VA_ARGS__); \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
#define PRINT_T0_L0
|
||||||
|
#define PRINT_TX_LX
|
||||||
|
#endif
|
||||||
|
|
||||||
struct __string_view {
|
struct __string_view {
|
||||||
char const* data;
|
char const* data;
|
||||||
std::size_t size;
|
std::size_t size;
|
||||||
};
|
};
|
||||||
|
#if __cplusplus >= 201402L
|
||||||
template <class T>
|
template <class T>
|
||||||
constexpr __string_view __get_type_name() {
|
constexpr __string_view __get_type_name() {
|
||||||
char const* p = __PRETTY_FUNCTION__;
|
char const* p = __PRETTY_FUNCTION__;
|
||||||
@ -83,7 +116,10 @@ constexpr __string_view __get_type_name() {
|
|||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
#define PRINT_T0_L0
|
template <class T>
|
||||||
|
constexpr __string_view __get_type_name() {
|
||||||
|
return {"unsupported", 11};
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Print a given array
|
// Print a given array
|
||||||
|
|||||||
@ -168,6 +168,9 @@ public:
|
|||||||
typename LayoutP::Stride::LongIndex *ldv;
|
typename LayoutP::Stride::LongIndex *ldv;
|
||||||
typename LayoutO::Stride::LongIndex *ldo;
|
typename LayoutO::Stride::LongIndex *ldo;
|
||||||
|
|
||||||
|
// Scale
|
||||||
|
ElementAccumulator scale;
|
||||||
|
|
||||||
// Whether causal masking is to be performed
|
// Whether causal masking is to be performed
|
||||||
bool causal;
|
bool causal;
|
||||||
|
|
||||||
@ -193,6 +196,7 @@ public:
|
|||||||
ldk(nullptr),
|
ldk(nullptr),
|
||||||
ldv(nullptr),
|
ldv(nullptr),
|
||||||
ldo(nullptr),
|
ldo(nullptr),
|
||||||
|
scale(0),
|
||||||
causal(false),
|
causal(false),
|
||||||
host_problem_sizes(nullptr)
|
host_problem_sizes(nullptr)
|
||||||
{
|
{
|
||||||
@ -218,6 +222,7 @@ public:
|
|||||||
typename LayoutV::Stride::LongIndex *ldv,
|
typename LayoutV::Stride::LongIndex *ldv,
|
||||||
typename LayoutO::Stride::LongIndex *ldo,
|
typename LayoutO::Stride::LongIndex *ldo,
|
||||||
bool causal,
|
bool causal,
|
||||||
|
ElementAccumulator scale,
|
||||||
GemmCoord *host_problem_sizes=nullptr
|
GemmCoord *host_problem_sizes=nullptr
|
||||||
):
|
):
|
||||||
problem_sizes0(problem_sizes0),
|
problem_sizes0(problem_sizes0),
|
||||||
@ -235,6 +240,7 @@ public:
|
|||||||
ldv(ldv),
|
ldv(ldv),
|
||||||
ldo(ldo),
|
ldo(ldo),
|
||||||
causal(causal),
|
causal(causal),
|
||||||
|
scale(scale),
|
||||||
host_problem_sizes(host_problem_sizes)
|
host_problem_sizes(host_problem_sizes)
|
||||||
{
|
{
|
||||||
|
|
||||||
@ -273,6 +279,7 @@ public:
|
|||||||
typename LayoutP::Stride::LongIndex *ldv;
|
typename LayoutP::Stride::LongIndex *ldv;
|
||||||
typename LayoutO::Stride::LongIndex *ldo;
|
typename LayoutO::Stride::LongIndex *ldo;
|
||||||
|
|
||||||
|
ElementAccumulator scale;
|
||||||
bool causal;
|
bool causal;
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -291,7 +298,8 @@ public:
|
|||||||
ldk(nullptr),
|
ldk(nullptr),
|
||||||
ldv(nullptr),
|
ldv(nullptr),
|
||||||
ldo(nullptr),
|
ldo(nullptr),
|
||||||
causal(false)
|
causal(false),
|
||||||
|
scale(0)
|
||||||
{ }
|
{ }
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
@ -310,8 +318,9 @@ public:
|
|||||||
ldk(args.ldk),
|
ldk(args.ldk),
|
||||||
ldv(args.ldv),
|
ldv(args.ldv),
|
||||||
ldo(args.ldo),
|
ldo(args.ldo),
|
||||||
causal(args.causal)
|
causal(args.causal),
|
||||||
{
|
scale(args.scale)
|
||||||
|
{
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -337,6 +346,7 @@ public:
|
|||||||
ldv = args.ldv;
|
ldv = args.ldv;
|
||||||
ldo = args.ldo;
|
ldo = args.ldo;
|
||||||
causal = args.causal;
|
causal = args.causal;
|
||||||
|
scale = args.scale;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -649,7 +659,7 @@ public:
|
|||||||
warp_id(),
|
warp_id(),
|
||||||
num_keys - iter_key_start,
|
num_keys - iter_key_start,
|
||||||
iteratorC_tile_offset,
|
iteratorC_tile_offset,
|
||||||
1.0f / cutlass::fast_sqrt(float(problem_size0.k())));
|
params.scale);
|
||||||
}));
|
}));
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
|||||||
@ -504,37 +504,51 @@ private:
|
|||||||
ldo_host.resize(problem_count());
|
ldo_host.resize(problem_count());
|
||||||
seqlen_host.resize(problem_count());
|
seqlen_host.resize(problem_count());
|
||||||
|
|
||||||
for (int32_t i = 0; i < problem_count(); ++i) {
|
// Create tensors in BMHK format, where
|
||||||
|
// B = batch_size
|
||||||
|
// M = sequence length
|
||||||
|
// H = num_heads
|
||||||
|
// K = embedding size per head
|
||||||
|
int64_t batch_offset_Q, batch_offset_K, batch_offset_V, batch_offset_O;
|
||||||
|
|
||||||
auto problem0 = options.problem_sizes0.at(i);
|
for (int32_t b = 0; b < options.batch_size; ++b) {
|
||||||
auto problem1 = options.problem_sizes1.at(i);
|
batch_offset_Q = total_elements_Q;
|
||||||
|
batch_offset_K = total_elements_K;
|
||||||
|
batch_offset_V = total_elements_V;
|
||||||
|
batch_offset_O = total_elements_O;
|
||||||
|
for (int32_t h = 0; h < options.head_number; ++h) {
|
||||||
|
int32_t i = h + b * options.head_number;
|
||||||
|
|
||||||
ldq_host.at(i) = LayoutQ::packed({problem0.m(), problem0.k()}).stride(0);
|
auto problem0 = options.problem_sizes0.at(i);
|
||||||
ldk_host.at(i) = LayoutK::packed({problem0.k(), problem0.n()}).stride(0);
|
auto problem1 = options.problem_sizes1.at(i);
|
||||||
ldp_host.at(i) = LayoutP::packed({problem0.m(), problem0.n()}).stride(0);
|
|
||||||
ldv_host.at(i) = LayoutV::packed({problem1.k(), problem1.n()}).stride(0);
|
|
||||||
ldo_host.at(i) = LayoutO::packed({problem1.m(), problem1.n()}).stride(0);
|
|
||||||
|
|
||||||
// m = n for attention problems.
|
ldq_host.at(i) = LayoutQ::packed({problem0.m(), options.head_number * problem0.k()}).stride(0);
|
||||||
seqlen_host.at(i) = problem0.m();
|
ldk_host.at(i) = LayoutK::packed({options.head_number * problem0.k(), problem0.n()}).stride(0);
|
||||||
|
ldp_host.at(i) = LayoutP::packed({problem0.m(), problem0.n()}).stride(0);
|
||||||
|
ldv_host.at(i) = LayoutV::packed({problem1.k(), options.head_number * problem1.n()}).stride(0);
|
||||||
|
ldo_host.at(i) = LayoutO::packed({problem1.m(), options.head_number * problem1.n()}).stride(0);
|
||||||
|
|
||||||
offset_Q.push_back(total_elements_Q);
|
// m = n for attention problems.
|
||||||
offset_K.push_back(total_elements_K);
|
seqlen_host.at(i) = problem0.m();
|
||||||
offset_P.push_back(total_elements_P);
|
|
||||||
offset_V.push_back(total_elements_V);
|
|
||||||
offset_O.push_back(total_elements_O);
|
|
||||||
|
|
||||||
int64_t elements_Q = problem0.m() * problem0.k();
|
offset_Q.push_back(batch_offset_Q + h * problem0.k());
|
||||||
int64_t elements_K = problem0.k() * problem0.n();
|
offset_K.push_back(batch_offset_K + h * problem0.k());
|
||||||
int64_t elements_P = problem0.m() * problem0.n();
|
offset_P.push_back(total_elements_P);
|
||||||
int64_t elements_V = problem1.k() * problem1.n();
|
offset_V.push_back(batch_offset_V + h * problem0.k());
|
||||||
int64_t elements_O = problem1.m() * problem1.n();
|
offset_O.push_back(batch_offset_O + h * problem1.n());
|
||||||
|
|
||||||
total_elements_Q += elements_Q;
|
int64_t elements_Q = problem0.m() * problem0.k();
|
||||||
total_elements_K += elements_K;
|
int64_t elements_K = problem0.k() * problem0.n();
|
||||||
total_elements_P += elements_P;
|
int64_t elements_P = problem0.m() * problem0.n();
|
||||||
total_elements_V += elements_V;
|
int64_t elements_V = problem1.k() * problem1.n();
|
||||||
total_elements_O += elements_O;
|
int64_t elements_O = problem1.m() * problem1.n();
|
||||||
|
|
||||||
|
total_elements_Q += elements_Q;
|
||||||
|
total_elements_K += elements_K;
|
||||||
|
total_elements_P += elements_P;
|
||||||
|
total_elements_V += elements_V;
|
||||||
|
total_elements_O += elements_O;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
problem_sizes_device0.reset(problem_count());
|
problem_sizes_device0.reset(problem_count());
|
||||||
@ -649,15 +663,11 @@ private:
|
|||||||
|
|
||||||
bool passed = true;
|
bool passed = true;
|
||||||
|
|
||||||
for (int32_t i = 0; i < problem_count(); ++i) {
|
for (int32_t b = 0; b < options.batch_size; ++b) {
|
||||||
cutlass::gemm::GemmCoord problem0 = options.problem_sizes0.at(i);
|
int32_t i = b * options.head_number;
|
||||||
cutlass::gemm::GemmCoord problem1 = options.problem_sizes1.at(i);
|
// Problem size is the same for all heads
|
||||||
|
cutlass::gemm::GemmCoord problem0 = options.problem_sizes0.at(b * options.head_number);
|
||||||
LayoutQ layout_Q(ldq_host.at(i));
|
cutlass::gemm::GemmCoord problem1 = options.problem_sizes1.at(b * options.head_number);
|
||||||
LayoutK layout_K(ldk_host.at(i));
|
|
||||||
LayoutP layout_P(ldp_host.at(i));
|
|
||||||
LayoutV layout_V(ldv_host.at(i));
|
|
||||||
LayoutO layout_O(ldo_host.at(i));
|
|
||||||
|
|
||||||
MatrixCoord extent_Q{problem0.m(), problem0.k()};
|
MatrixCoord extent_Q{problem0.m(), problem0.k()};
|
||||||
MatrixCoord extent_K{problem0.k(), problem0.n()};
|
MatrixCoord extent_K{problem0.k(), problem0.n()};
|
||||||
@ -665,114 +675,121 @@ private:
|
|||||||
MatrixCoord extent_V{problem1.k(), problem1.n()};
|
MatrixCoord extent_V{problem1.k(), problem1.n()};
|
||||||
MatrixCoord extent_O{problem1.m(), problem1.n()};
|
MatrixCoord extent_O{problem1.m(), problem1.n()};
|
||||||
|
|
||||||
cutlass::TensorView<ElementQ, LayoutQ> view_Q(block_Q.get() + offset_Q.at(i), layout_Q, extent_Q);
|
LayoutO layout_O(ldo_host.at(i));
|
||||||
cutlass::TensorView<ElementK, LayoutK> view_K(block_K.get() + offset_K.at(i), layout_K, extent_K);
|
std::vector<ElementO> matrix_O(layout_O.capacity(extent_O));
|
||||||
cutlass::TensorView<ElementP, LayoutP> view_P(block_P.get() + offset_P.at(i), layout_P, extent_P);
|
cutlass::device_memory::copy_to_host(matrix_O.data(), block_O.get() + offset_O.at(i), matrix_O.size());
|
||||||
cutlass::TensorView<ElementV, LayoutV> view_V(block_V.get() + offset_V.at(i), layout_V, extent_V);
|
|
||||||
|
|
||||||
cutlass::DeviceAllocation<ElementP> block_Ref(layout_P.capacity(extent_P));
|
|
||||||
cutlass::TensorView<ElementP, LayoutP> view_Ref_device(block_Ref.get(), layout_P, extent_P);
|
|
||||||
|
|
||||||
cutlass::DeviceAllocation<ElementO> block_Ref_O(layout_O.capacity(extent_O));
|
cutlass::DeviceAllocation<ElementO> block_Ref_O(layout_O.capacity(extent_O));
|
||||||
cutlass::TensorView<ElementO, LayoutO> view_Ref_O_device(block_Ref_O.get(), layout_O, extent_O);
|
|
||||||
|
|
||||||
// Reference GEMM
|
for (int32_t h = 0; h < options.head_number; ++h) {
|
||||||
cutlass::reference::device::GemmComplex<
|
i = h + b * options.head_number;
|
||||||
ElementQ, LayoutQ,
|
|
||||||
ElementK, LayoutK,
|
|
||||||
ElementP, LayoutP,
|
|
||||||
ElementCompute, ElementAccumulator
|
|
||||||
>(
|
|
||||||
problem0,
|
|
||||||
ElementAccumulator(options.alpha0),
|
|
||||||
view_Q,
|
|
||||||
Attention::MM0::Mma::kTransformA,
|
|
||||||
view_K,
|
|
||||||
Attention::MM0::Mma::kTransformB,
|
|
||||||
ElementAccumulator(options.beta),
|
|
||||||
view_P,
|
|
||||||
view_Ref_device,
|
|
||||||
ElementAccumulator(0)
|
|
||||||
);
|
|
||||||
|
|
||||||
// Compute softmax for P. We need to explicitly compute softmax
|
LayoutQ layout_Q(ldq_host.at(i));
|
||||||
// over P because softmax is fused to the second GEMM in the
|
LayoutK layout_K(ldk_host.at(i));
|
||||||
// profiled implementation.
|
LayoutP layout_P(ldp_host.at(i));
|
||||||
std::vector<ElementP> matrix_Ref(layout_P.capacity(extent_P));
|
LayoutV layout_V(ldv_host.at(i));
|
||||||
cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_Ref.size());
|
|
||||||
cutlass::TensorView<ElementP, LayoutP> view_Ref_host(matrix_Ref.data(), layout_P, extent_P);
|
|
||||||
std::vector<ElementNorm> vector_Norm_Ref(problem0.m());
|
|
||||||
std::vector<ElementSum> vector_Sum_Ref(problem0.m());
|
|
||||||
|
|
||||||
int n_dim = options.use_mask ? options.problem_sizes0_real.at(i).n() : problem0.n();
|
cutlass::TensorView<ElementQ, LayoutQ> view_Q(block_Q.get() + offset_Q.at(i), layout_Q, extent_Q);
|
||||||
|
cutlass::TensorView<ElementK, LayoutK> view_K(block_K.get() + offset_K.at(i), layout_K, extent_K);
|
||||||
|
cutlass::TensorView<ElementV, LayoutV> view_V(block_V.get() + offset_V.at(i), layout_V, extent_V);
|
||||||
|
cutlass::TensorView<ElementO, LayoutO> view_Ref_O_device(block_Ref_O.get() + offset_O.at(i) - offset_O.at(b * options.head_number), layout_O, extent_O);
|
||||||
|
|
||||||
// Compute softmax for referece matrix
|
cutlass::DeviceAllocation<ElementP> block_Ref_P(layout_P.capacity(extent_P));
|
||||||
for (int m = 0; m < problem0.m(); m++) {
|
cutlass::TensorView<ElementP, LayoutP> view_Ref_P_device(block_Ref_P.get(), layout_P, extent_P);
|
||||||
int n_dim_row = n_dim;
|
|
||||||
if (options.causal) {
|
|
||||||
n_dim_row = std::min(m + 1, n_dim);
|
|
||||||
}
|
|
||||||
ElementSoftmaxCompute max = ElementSoftmaxCompute(view_Ref_host.ref().at({m, 0}));
|
|
||||||
for (int n = 1; n < n_dim_row; n++) {
|
|
||||||
max = std::max(max, ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})));
|
|
||||||
}
|
|
||||||
|
|
||||||
vector_Norm_Ref.at(m) = ElementNorm(max);
|
// Reference GEMM
|
||||||
|
cutlass::reference::device::GemmComplex<
|
||||||
|
ElementQ, LayoutQ,
|
||||||
|
ElementK, LayoutK,
|
||||||
|
ElementP, LayoutP,
|
||||||
|
ElementCompute, ElementAccumulator
|
||||||
|
>(
|
||||||
|
problem0,
|
||||||
|
ElementAccumulator(options.alpha0),
|
||||||
|
view_Q,
|
||||||
|
Attention::MM0::Mma::kTransformA,
|
||||||
|
view_K,
|
||||||
|
Attention::MM0::Mma::kTransformB,
|
||||||
|
ElementAccumulator(options.beta),
|
||||||
|
view_Ref_P_device,
|
||||||
|
view_Ref_P_device,
|
||||||
|
ElementAccumulator(0)
|
||||||
|
);
|
||||||
|
|
||||||
ElementSoftmaxCompute sum = ElementSoftmaxCompute();
|
// Compute softmax for P. We need to explicitly compute softmax
|
||||||
for (int n = 0; n < n_dim_row; n++) {
|
// over P because softmax is fused to the second GEMM in the
|
||||||
sum += std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max );
|
// profiled implementation.
|
||||||
}
|
std::vector<ElementP> matrix_Ref(layout_P.capacity(extent_P));
|
||||||
ElementSoftmaxCompute inv_sum = ElementSoftmaxCompute(1.0f / sum);
|
cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref_P.get(), matrix_Ref.size());
|
||||||
|
cutlass::TensorView<ElementP, LayoutP> view_Ref_host(matrix_Ref.data(), layout_P, extent_P);
|
||||||
|
std::vector<ElementNorm> vector_Norm_Ref(problem0.m());
|
||||||
|
std::vector<ElementSum> vector_Sum_Ref(problem0.m());
|
||||||
|
|
||||||
vector_Sum_Ref.at(m) = ElementSum(inv_sum);
|
int n_dim = options.use_mask ? options.problem_sizes0_real.at(i).n() : problem0.n();
|
||||||
|
|
||||||
for (int n = 0; n < n_dim_row; n++) {
|
// Compute softmax for reference matrix
|
||||||
view_Ref_host.ref().at({m, n}) = ElementP(
|
|
||||||
std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max ) * inv_sum
|
|
||||||
);
|
|
||||||
}
|
|
||||||
// Mask out the rest of the attention matrix
|
|
||||||
for (int n = n_dim_row; n < n_dim; ++n) {
|
|
||||||
view_Ref_host.ref().at({m, n}) = ElementP(0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// when not using mask, problem_real and problem share the same sizes
|
|
||||||
if (options.use_mask) {
|
|
||||||
for (int m = 0; m < problem0.m(); m++) {
|
for (int m = 0; m < problem0.m(); m++) {
|
||||||
for (int n = n_dim; n < problem0.n(); n++) {
|
int n_dim_row = n_dim;
|
||||||
|
if (options.causal) {
|
||||||
|
n_dim_row = std::min(m + 1, n_dim);
|
||||||
|
}
|
||||||
|
ElementSoftmaxCompute max = ElementSoftmaxCompute(view_Ref_host.ref().at({m, 0}));
|
||||||
|
for (int n = 1; n < n_dim_row; n++) {
|
||||||
|
max = std::max(max, ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})));
|
||||||
|
}
|
||||||
|
|
||||||
|
vector_Norm_Ref.at(m) = ElementNorm(max);
|
||||||
|
|
||||||
|
ElementSoftmaxCompute sum = ElementSoftmaxCompute();
|
||||||
|
for (int n = 0; n < n_dim_row; n++) {
|
||||||
|
sum += std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max );
|
||||||
|
}
|
||||||
|
ElementSoftmaxCompute inv_sum = ElementSoftmaxCompute(1.0f / sum);
|
||||||
|
|
||||||
|
vector_Sum_Ref.at(m) = ElementSum(inv_sum);
|
||||||
|
|
||||||
|
for (int n = 0; n < n_dim_row; n++) {
|
||||||
|
view_Ref_host.ref().at({m, n}) = ElementP(
|
||||||
|
std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max ) * inv_sum
|
||||||
|
);
|
||||||
|
}
|
||||||
|
// Mask out the rest of the attention matrix
|
||||||
|
for (int n = n_dim_row; n < n_dim; ++n) {
|
||||||
view_Ref_host.ref().at({m, n}) = ElementP(0);
|
view_Ref_host.ref().at({m, n}) = ElementP(0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// when not using mask, problem_real and problem share the same sizes
|
||||||
|
if (options.use_mask) {
|
||||||
|
for (int m = 0; m < problem0.m(); m++) {
|
||||||
|
for (int n = n_dim; n < problem0.n(); n++) {
|
||||||
|
view_Ref_host.ref().at({m, n}) = ElementP(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cutlass::device_memory::copy_to_device(block_Ref_P.get(), matrix_Ref.data(), matrix_Ref.size());
|
||||||
|
|
||||||
|
// Reference GEMM
|
||||||
|
cutlass::reference::device::GemmComplex<
|
||||||
|
ElementP, LayoutP,
|
||||||
|
ElementV, LayoutV,
|
||||||
|
ElementO, LayoutO,
|
||||||
|
ElementCompute, ElementAccumulator
|
||||||
|
>(
|
||||||
|
problem1,
|
||||||
|
ElementAccumulator(options.alpha1),
|
||||||
|
view_Ref_P_device,
|
||||||
|
Attention::MM0::Mma::kTransformA,
|
||||||
|
view_V,
|
||||||
|
Attention::MM0::Mma::kTransformB,
|
||||||
|
ElementAccumulator(options.beta),
|
||||||
|
view_Ref_O_device,
|
||||||
|
view_Ref_O_device,
|
||||||
|
ElementAccumulator(0)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
cutlass::device_memory::copy_to_device(block_P.get() + offset_P.at(i), matrix_Ref.data(), matrix_Ref.size());
|
|
||||||
|
|
||||||
// Reference GEMM
|
|
||||||
cutlass::reference::device::GemmComplex<
|
|
||||||
ElementP, LayoutP,
|
|
||||||
ElementV, LayoutV,
|
|
||||||
ElementO, LayoutO,
|
|
||||||
ElementCompute, ElementAccumulator
|
|
||||||
>(
|
|
||||||
problem1,
|
|
||||||
ElementAccumulator(options.alpha1),
|
|
||||||
view_P,
|
|
||||||
Attention::MM0::Mma::kTransformA,
|
|
||||||
view_V,
|
|
||||||
Attention::MM0::Mma::kTransformB,
|
|
||||||
ElementAccumulator(options.beta),
|
|
||||||
view_Ref_O_device,
|
|
||||||
view_Ref_O_device,
|
|
||||||
ElementAccumulator(0)
|
|
||||||
);
|
|
||||||
|
|
||||||
// Copy to host memory
|
// Copy to host memory
|
||||||
cutlass::TensorView<ElementP, LayoutP> view_Ref(matrix_Ref.data(), layout_P, extent_P);
|
|
||||||
|
|
||||||
std::vector<ElementO> matrix_O(layout_O.capacity(extent_O));
|
|
||||||
cutlass::device_memory::copy_to_host(matrix_O.data(), block_O.get() + offset_O.at(i), matrix_O.size());
|
|
||||||
std::vector<ElementO> matrix_Ref_O(layout_O.capacity(extent_O));
|
std::vector<ElementO> matrix_Ref_O(layout_O.capacity(extent_O));
|
||||||
cutlass::device_memory::copy_to_host(matrix_Ref_O.data(), block_Ref_O.get(), matrix_Ref_O.size());
|
cutlass::device_memory::copy_to_host(matrix_Ref_O.data(), block_Ref_O.get(), matrix_Ref_O.size());
|
||||||
|
|
||||||
@ -788,7 +805,7 @@ private:
|
|||||||
passed = passed && verified_O;
|
passed = passed && verified_O;
|
||||||
|
|
||||||
if (!passed) {
|
if (!passed) {
|
||||||
std::cerr << "\n***\nError - problem " << i << " failed the QA check\n***\n" << std::endl;
|
std::cerr << "\n***\nError - problem " << i << " (batch " << b << ") failed the QA check\n***\n" << std::endl;
|
||||||
|
|
||||||
if (!verified_O) {
|
if (!verified_O) {
|
||||||
std::cout << "Final matrix output is incorrect" << std::endl;
|
std::cout << "Final matrix output is incorrect" << std::endl;
|
||||||
@ -831,6 +848,8 @@ public:
|
|||||||
// p.cu_seqlens_k_ptr = (int32_t*)cu_seqlens_k->data_ptr();
|
// p.cu_seqlens_k_ptr = (int32_t*)cu_seqlens_k->data_ptr();
|
||||||
// }
|
// }
|
||||||
|
|
||||||
|
p.scale = options.alpha0;
|
||||||
|
|
||||||
p.num_heads = options.head_number;
|
p.num_heads = options.head_number;
|
||||||
p.num_batches = options.batch_size;
|
p.num_batches = options.batch_size;
|
||||||
p.head_dim = options.head_size;
|
p.head_dim = options.head_size;
|
||||||
@ -839,18 +858,16 @@ public:
|
|||||||
p.num_keys = options.seq_length_kv;
|
p.num_keys = options.seq_length_kv;
|
||||||
p.causal = options.causal;
|
p.causal = options.causal;
|
||||||
|
|
||||||
// TODO: This might overflow for big tensors
|
// All tensors are in BMHK shapes
|
||||||
|
p.q_strideH = options.head_size;
|
||||||
|
p.k_strideH = options.head_size;
|
||||||
|
p.v_strideH = options.head_size_v;
|
||||||
p.q_strideM = int32_t(ldq_host[0]);
|
p.q_strideM = int32_t(ldq_host[0]);
|
||||||
p.k_strideM = int32_t(ldk_host[0]);
|
p.k_strideM = int32_t(ldk_host[0]);
|
||||||
p.v_strideM = int32_t(ldv_host[0]);
|
p.v_strideM = int32_t(ldv_host[0]);
|
||||||
p.q_strideH = p.q_strideM * options.seq_length;
|
p.q_strideB = p.q_strideM * options.seq_length;
|
||||||
p.k_strideH = p.k_strideM * options.seq_length_kv;
|
p.k_strideB = p.k_strideM * options.seq_length_kv;
|
||||||
p.v_strideH = p.v_strideM * options.seq_length_kv;
|
p.v_strideB = p.v_strideM * options.seq_length_kv;
|
||||||
p.o_strideH = options.head_size_v * options.seq_length;
|
|
||||||
p.q_strideB = p.q_strideH * options.head_number;
|
|
||||||
p.k_strideB = p.k_strideH * options.head_number;
|
|
||||||
p.v_strideB = p.v_strideH * options.head_number;
|
|
||||||
p.o_strideB = options.head_size_v * options.seq_length * options.head_number;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// launch kernel :)
|
// launch kernel :)
|
||||||
|
|||||||
@ -921,6 +921,7 @@ public:
|
|||||||
ldv.get(),
|
ldv.get(),
|
||||||
ldo.get(),
|
ldo.get(),
|
||||||
options.causal,
|
options.causal,
|
||||||
|
options.alpha0,
|
||||||
options.problem_sizes1.data()
|
options.problem_sizes1.data()
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@ -36,20 +36,20 @@
|
|||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
// Some helper functions
|
// Some helper functions
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
#define DISPATCH_TYPES(tensor, func) \
|
#define DISPATCH_TYPES(tensor, func) \
|
||||||
{ \
|
{ \
|
||||||
if (query.scalar_type() == at::ScalarType::Float) { \
|
if (query.scalar_type() == at::ScalarType::Float) { \
|
||||||
using scalar_t = float; \
|
using scalar_t = float; \
|
||||||
func(); \
|
func(); \
|
||||||
} else if (query.scalar_type() == at::ScalarType::Half) { \
|
} else if (query.scalar_type() == at::ScalarType::Half) { \
|
||||||
using scalar_t = cutlass::half_t; \
|
using scalar_t = cutlass::half_t; \
|
||||||
func(); \
|
func(); \
|
||||||
} else if (query.scalar_type() == at::ScalarType::BFloat16) { \
|
} else if (query.scalar_type() == at::ScalarType::BFloat16) { \
|
||||||
using scalar_t = cutlass::bfloat16_t; \
|
using scalar_t = cutlass::bfloat16_t; \
|
||||||
func(); \
|
func(); \
|
||||||
} else { \
|
} else { \
|
||||||
TORCH_CHECK(false, "Only fp32, half & bf16 supported at the moment"); \
|
XFORMERS_CHECK(false, "Only fp32, half & bf16 supported at the moment"); \
|
||||||
} \
|
} \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \
|
#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \
|
||||||
@ -77,26 +77,27 @@
|
|||||||
using ArchTag = cutlass::arch::Sm50; \
|
using ArchTag = cutlass::arch::Sm50; \
|
||||||
func(); \
|
func(); \
|
||||||
} else { \
|
} else { \
|
||||||
TORCH_CHECK( \
|
XFORMERS_CHECK( \
|
||||||
false, \
|
false, \
|
||||||
"Your device is too old. We require compute capability >= 50"); \
|
"Your device is too old. We require compute capability >= 50"); \
|
||||||
} \
|
} \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \
|
#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \
|
||||||
TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \
|
XFORMERS_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \
|
||||||
TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
|
XFORMERS_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
|
||||||
TORCH_CHECK(TENSOR.is_contiguous());
|
XFORMERS_CHECK(TENSOR.is_contiguous());
|
||||||
|
|
||||||
#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \
|
#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \
|
||||||
TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \
|
XFORMERS_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \
|
||||||
TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
|
XFORMERS_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
|
||||||
TORCH_CHECK( \
|
XFORMERS_CHECK( \
|
||||||
TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous");
|
TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous");
|
||||||
|
|
||||||
#ifdef HAS_PYTORCH
|
#ifdef TORCH_CHECK
|
||||||
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
|
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
|
||||||
TORCH_CHECK(uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned")
|
XFORMERS_CHECK( \
|
||||||
|
uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned")
|
||||||
#define XFORMERS_CHECK TORCH_CHECK
|
#define XFORMERS_CHECK TORCH_CHECK
|
||||||
#elif defined(__CUDACC_RTC__)
|
#elif defined(__CUDACC_RTC__)
|
||||||
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
|
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
|
||||||
@ -108,6 +109,7 @@
|
|||||||
return false; \
|
return false; \
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
|
#include <iostream>
|
||||||
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
|
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
|
||||||
if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \
|
if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \
|
||||||
std::cerr << #PTR " is not correctly aligned\n"; \
|
std::cerr << #PTR " is not correctly aligned\n"; \
|
||||||
@ -120,74 +122,25 @@
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#define ASSIGN_CHECK_OVERFLOW(A, B) \
|
#define ASSIGN_CHECK_OVERFLOW(A, B) \
|
||||||
{ \
|
{ \
|
||||||
A = B; \
|
A = B; \
|
||||||
TORCH_CHECK( \
|
XFORMERS_CHECK( \
|
||||||
B < cutlass::platform::numeric_limits<decltype(A)>::max(), \
|
B < std::numeric_limits<decltype(A)>::max(), #B " overflows"); \
|
||||||
#B " overflows"); \
|
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace gemm_kernel_utils {
|
namespace gemm_kernel_utils {
|
||||||
|
|
||||||
#ifdef HAS_PYTORCH
|
|
||||||
template <typename scalar_t>
|
|
||||||
struct TypeTraits;
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct TypeTraits<cutlass::half_t> {
|
|
||||||
using scalar_t = cutlass::half_t;
|
|
||||||
|
|
||||||
static constexpr __host__ at::ScalarType atScalarType() {
|
|
||||||
return at::ScalarType::Half;
|
|
||||||
}
|
|
||||||
template <int nDim>
|
|
||||||
static __host__ at::PackedTensorAccessor32<scalar_t, nDim> packed_accessor(
|
|
||||||
at::Tensor const& tensor) {
|
|
||||||
return at::PackedTensorAccessor32<scalar_t, nDim>(
|
|
||||||
(scalar_t*)(tensor.data_ptr()),
|
|
||||||
tensor.sizes().data(),
|
|
||||||
tensor.strides().data());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct TypeTraits<cutlass::bfloat16_t> {
|
|
||||||
using scalar_t = cutlass::bfloat16_t;
|
|
||||||
|
|
||||||
static constexpr __host__ at::ScalarType atScalarType() {
|
|
||||||
return at::ScalarType::BFloat16;
|
|
||||||
}
|
|
||||||
template <int nDim>
|
|
||||||
static __host__ at::PackedTensorAccessor32<scalar_t, nDim> packed_accessor(
|
|
||||||
at::Tensor const& tensor) {
|
|
||||||
return at::PackedTensorAccessor32<scalar_t, nDim>(
|
|
||||||
(scalar_t*)(tensor.data_ptr()),
|
|
||||||
tensor.sizes().data(),
|
|
||||||
tensor.strides().data());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct TypeTraits<float> {
|
|
||||||
using scalar_t = float;
|
|
||||||
|
|
||||||
static constexpr __host__ at::ScalarType atScalarType() {
|
|
||||||
return at::ScalarType::Float;
|
|
||||||
}
|
|
||||||
template <int nDim>
|
|
||||||
static __host__ at::PackedTensorAccessor32<scalar_t, nDim> packed_accessor(
|
|
||||||
at::Tensor const& tensor) {
|
|
||||||
return tensor.packed_accessor32<scalar_t, nDim>();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
#endif
|
|
||||||
|
|
||||||
template <typename integer>
|
template <typename integer>
|
||||||
constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) {
|
constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) {
|
||||||
return (n + m - 1) / m;
|
return (n + m - 1) / m;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename integer>
|
||||||
|
constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m) {
|
||||||
|
return ((n + m - 1) / m) * m;
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
// Determine the type of GEMM we do (TensorCores or not, Shapes ...)
|
// Determine the type of GEMM we do (TensorCores or not, Shapes ...)
|
||||||
// TODO: Maybe we could rely on Cutlass's DefaultGemm templates
|
// TODO: Maybe we could rely on Cutlass's DefaultGemm templates
|
||||||
|
|||||||
@ -311,9 +311,9 @@ class PredicatedTileIteratorPrefetch {
|
|||||||
// on windows using unsigned long here gives the error
|
// on windows using unsigned long here gives the error
|
||||||
// error: asm operand type size(4) does not match
|
// error: asm operand type size(4) does not match
|
||||||
// type/size implied by constraint 'l'
|
// type/size implied by constraint 'l'
|
||||||
uint64_t addr = (uint64_t)(
|
uint64_t addr = (uint64_t)((void*)&memory_pointer
|
||||||
(void*)&memory_pointer
|
[column * ThreadMap::Delta::kColumn /
|
||||||
[column * ThreadMap::Delta::kColumn / kElementsPerAccess]);
|
kElementsPerAccess]);
|
||||||
asm volatile("prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr));
|
asm volatile("prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -0,0 +1,53 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holdvr nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "warp_iterator_from_smem.h"
|
||||||
|
|
||||||
|
template <typename WarpIterator>
|
||||||
|
struct TransposeWarpIterator {
|
||||||
|
using Iterator = char;
|
||||||
|
static bool constexpr kSupportsTranspose = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <
|
||||||
|
/// Operand identity
|
||||||
|
cutlass::gemm::Operand Operand,
|
||||||
|
/// Data type of A elements
|
||||||
|
typename Element,
|
||||||
|
bool kTranspose>
|
||||||
|
struct TransposeWarpIterator<
|
||||||
|
cutlass::gemm::warp::WarpIteratorFromSmem<Operand, Element, kTranspose>> {
|
||||||
|
using Iterator =
|
||||||
|
cutlass::gemm::warp::WarpIteratorFromSmem<Operand, Element, !kTranspose>;
|
||||||
|
static bool constexpr kSupportsTranspose = true;
|
||||||
|
};
|
||||||
@ -0,0 +1,278 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||||
|
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice,
|
||||||
|
*this list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||||
|
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||||
|
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||||
|
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||||
|
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||||
|
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||||
|
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||||
|
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||||
|
*POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
/*! \file
|
||||||
|
\brief Inspired from
|
||||||
|
"cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" Loads tiles of GEMM
|
||||||
|
operands from a RowMajor shared-memory layout into registers to use by A100
|
||||||
|
TensorCores.
|
||||||
|
|
||||||
|
The difference with "mma_tensor_op_tile_access_iterator.h" is that:
|
||||||
|
(1) We use "ldmatrix" to load tiles, rather than manual loads (slightly
|
||||||
|
faster) (2) We support to transpose the operand (eg read `A.transpose()` when
|
||||||
|
the shared memory holds `A`)
|
||||||
|
|
||||||
|
This is only implemented for the specific shapes.
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cutlass/gemm/gemm.h>
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
namespace cutlass {
|
||||||
|
namespace gemm {
|
||||||
|
namespace warp {
|
||||||
|
|
||||||
|
template <
|
||||||
|
/// Operand identity
|
||||||
|
Operand Operand_,
|
||||||
|
/// Data type of A elements
|
||||||
|
typename Element_,
|
||||||
|
bool kTranspose = false>
|
||||||
|
class WarpIteratorFromSmem {
|
||||||
|
public:
|
||||||
|
/// Shape of tile to load (concept: MatrixShape)
|
||||||
|
using Shape = cutlass::MatrixShape<32, 32>;
|
||||||
|
|
||||||
|
/// Operand tag
|
||||||
|
static Operand const kOperand = Operand_;
|
||||||
|
|
||||||
|
/// Basic check
|
||||||
|
static_assert(
|
||||||
|
kOperand == Operand::kA || kOperand == Operand::kB,
|
||||||
|
"WarpIteratorFromSmem may only be instantiated for A or B operands to warp-level Mma.");
|
||||||
|
|
||||||
|
/// Element type
|
||||||
|
using Element = Element_;
|
||||||
|
static_assert(sizeof_bits<Element>::value == 16, "Only supported for half");
|
||||||
|
|
||||||
|
/// Layout of source tile
|
||||||
|
using Layout = cutlass::layout::RowMajor;
|
||||||
|
|
||||||
|
/// Shape of one matrix product operation (concept: MatrixShape)
|
||||||
|
using InstructionShape = cutlass::MatrixShape<16, 8>;
|
||||||
|
|
||||||
|
/// Delta between *MMA operations (in units of *MMA operations, concept:
|
||||||
|
/// MatrixShape)
|
||||||
|
static int const kOpDelta = 1;
|
||||||
|
|
||||||
|
/// Number of participating threads
|
||||||
|
static int const kThreads = 32;
|
||||||
|
|
||||||
|
/// TensorRef type for loading element from a tensor
|
||||||
|
using TensorRef = TensorRef<Element, Layout>;
|
||||||
|
|
||||||
|
/// Index type
|
||||||
|
using Index = typename TensorRef::Index;
|
||||||
|
|
||||||
|
/// Long Index type
|
||||||
|
using LongIndex = typename TensorRef::LongIndex;
|
||||||
|
|
||||||
|
/// Coordinate for an element in the tensor
|
||||||
|
using TensorCoord = typename TensorRef::TensorCoord;
|
||||||
|
|
||||||
|
/// Number of elements accessed per Shared Memory load
|
||||||
|
static int const kElementsPerAccess =
|
||||||
|
(sizeof_bits<Element>::value >= 32 ? 1
|
||||||
|
: 32 / sizeof_bits<Element>::value);
|
||||||
|
|
||||||
|
using InstructionCount = MatrixShape<
|
||||||
|
Shape::kRow / InstructionShape::kRow,
|
||||||
|
Shape::kColumn / InstructionShape::kColumn>;
|
||||||
|
|
||||||
|
static int const kIterations = (kOperand == Operand::kA)
|
||||||
|
? InstructionCount::kColumn
|
||||||
|
: InstructionCount::kRow;
|
||||||
|
|
||||||
|
public:
|
||||||
|
//
|
||||||
|
// Derived quantities
|
||||||
|
//
|
||||||
|
|
||||||
|
/// Fragment object holding a thread's part of a tile
|
||||||
|
using Fragment = Array<
|
||||||
|
Element,
|
||||||
|
(kOperand == Operand::kA)
|
||||||
|
? (Shape::kRow* InstructionShape::kColumn / kThreads)
|
||||||
|
: (Shape::kColumn* InstructionShape::kRow / kThreads)>;
|
||||||
|
|
||||||
|
/// Memory access type
|
||||||
|
// using AccessType = AlignedArray<Element, kElementsPerAccess>;
|
||||||
|
using AccessType = Array<unsigned, 4>;
|
||||||
|
|
||||||
|
static int constexpr kWarpShapeDivisibleInner =
|
||||||
|
(kOperand == Operand::kA ? InstructionShape::kColumn
|
||||||
|
: InstructionShape::kRow);
|
||||||
|
static int constexpr kAccessesInner =
|
||||||
|
(kWarpShapeDivisibleInner / kElementsPerAccess) / 4;
|
||||||
|
static int const kTilesPerInstruction = InstructionShape::kRow / 8;
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// Underlying tensor reference
|
||||||
|
TensorRef ref_;
|
||||||
|
|
||||||
|
/// Origin
|
||||||
|
MatrixCoord origin_;
|
||||||
|
|
||||||
|
/// Iterations in a tile
|
||||||
|
int iterations_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
/// Constructor from TensorRef
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
WarpIteratorFromSmem(TensorRef const& ref, int lane_id)
|
||||||
|
: WarpIteratorFromSmem(ref, {Shape::kRow, Shape::kColumn}, lane_id) {}
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id)
|
||||||
|
: ref_(ref), iterations_(0) {
|
||||||
|
int ldsm_vec_num = (lane_id >> 3);
|
||||||
|
if (kOperand == Operand::kA) {
|
||||||
|
origin_ = MatrixCoord(lane_id % 8, 0);
|
||||||
|
static_assert(
|
||||||
|
InstructionCount::kRow * kAccessesInner * kTilesPerInstruction == 4,
|
||||||
|
"");
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow;
|
||||||
|
++inst_m_idx) {
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int access_m_idx = 0; access_m_idx < kTilesPerInstruction;
|
||||||
|
++access_m_idx) {
|
||||||
|
int access_idx = access_m_idx +
|
||||||
|
kTilesPerInstruction *
|
||||||
|
(inner_idx + kAccessesInner * inst_m_idx);
|
||||||
|
|
||||||
|
MatrixCoord offset(
|
||||||
|
access_m_idx * 8 + inst_m_idx * InstructionShape::kRow,
|
||||||
|
inner_idx * 4 * kElementsPerAccess);
|
||||||
|
|
||||||
|
if (access_idx == ldsm_vec_num) {
|
||||||
|
if (kTranspose) {
|
||||||
|
offset = MatrixCoord(offset.column(), offset.row());
|
||||||
|
}
|
||||||
|
origin_ += offset;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
origin_ = MatrixCoord(0, lane_id % 8);
|
||||||
|
static_assert(InstructionCount::kColumn * kAccessesInner == 4, "");
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn;
|
||||||
|
++inst_n_idx) {
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
|
||||||
|
int access_idx = inner_idx + kAccessesInner * inst_n_idx;
|
||||||
|
|
||||||
|
MatrixCoord offset(
|
||||||
|
inner_idx * 4 * kElementsPerAccess, inst_n_idx * 8);
|
||||||
|
|
||||||
|
if (access_idx == ldsm_vec_num) {
|
||||||
|
if (kTranspose) {
|
||||||
|
offset = MatrixCoord(offset.column(), offset.row());
|
||||||
|
}
|
||||||
|
origin_ += offset;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ref_.add_coord_offset(origin_);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Advances an iterator along logical dimensions of matrix in units of whole
|
||||||
|
/// tiles
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
WarpIteratorFromSmem& add_tile_offset(TensorCoord const& tile_offset) {
|
||||||
|
TensorCoord coord_offset(
|
||||||
|
tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn);
|
||||||
|
if (kTranspose) {
|
||||||
|
coord_offset = TensorCoord{coord_offset.column(), coord_offset.row()};
|
||||||
|
}
|
||||||
|
origin_ += coord_offset;
|
||||||
|
|
||||||
|
ref_.add_coord_offset(coord_offset);
|
||||||
|
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Advances the iterator along the advance dimension
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
void advance() {
|
||||||
|
if (kOperand == Operand::kA) {
|
||||||
|
add_tile_offset({0, 1});
|
||||||
|
} else {
|
||||||
|
add_tile_offset({1, 0});
|
||||||
|
}
|
||||||
|
|
||||||
|
iterations_ = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// increase iterations in a tile
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
WarpIteratorFromSmem& operator++() {
|
||||||
|
iterations_++;
|
||||||
|
|
||||||
|
if (iterations_ >= kIterations)
|
||||||
|
advance();
|
||||||
|
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Loads a fragment from memory at the location pointed to by the iterator.
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
void load(Fragment& frag) const {
|
||||||
|
AccessType* access_ptr = reinterpret_cast<AccessType*>(&frag);
|
||||||
|
using LoadLayout = typename platform::
|
||||||
|
conditional<kTranspose, layout::ColumnMajor, layout::RowMajor>::type;
|
||||||
|
|
||||||
|
MatrixCoord offset;
|
||||||
|
if (kOperand == Operand::kA) {
|
||||||
|
offset = MatrixCoord(0, iterations_ * InstructionShape::kColumn);
|
||||||
|
} else {
|
||||||
|
offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0);
|
||||||
|
}
|
||||||
|
if (kTranspose) {
|
||||||
|
offset = MatrixCoord(offset.column(), offset.row());
|
||||||
|
}
|
||||||
|
cutlass::arch::ldsm<LoadLayout, 4>(
|
||||||
|
access_ptr[0], ref_.data() + ref_.offset(offset));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
} // namespace warp
|
||||||
|
} // namespace gemm
|
||||||
|
} // namespace cutlass
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
@ -29,15 +29,6 @@
|
|||||||
*
|
*
|
||||||
**************************************************************************************************/
|
**************************************************************************************************/
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#ifdef HAS_PYTORCH
|
|
||||||
#include <ATen/ATen.h>
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
|
||||||
#include <torch/library.h>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@ -137,6 +128,9 @@ struct AttentionKernel {
|
|||||||
output_accum_ptr; // [num_queries, num_heads, head_dim_value]
|
output_accum_ptr; // [num_queries, num_heads, head_dim_value]
|
||||||
lse_scalar_t* logsumexp_ptr; // [num_heads, num_queries] - can be null
|
lse_scalar_t* logsumexp_ptr; // [num_heads, num_queries] - can be null
|
||||||
|
|
||||||
|
// Scale
|
||||||
|
accum_t scale;
|
||||||
|
|
||||||
// Dimensions/strides
|
// Dimensions/strides
|
||||||
int32_t head_dim;
|
int32_t head_dim;
|
||||||
int32_t head_dim_value;
|
int32_t head_dim_value;
|
||||||
@ -154,18 +148,15 @@ struct AttentionKernel {
|
|||||||
int32_t q_strideH;
|
int32_t q_strideH;
|
||||||
int32_t k_strideH;
|
int32_t k_strideH;
|
||||||
int32_t v_strideH;
|
int32_t v_strideH;
|
||||||
int32_t o_strideH;
|
|
||||||
int64_t q_strideB;
|
int64_t q_strideB;
|
||||||
int64_t k_strideB;
|
int64_t k_strideB;
|
||||||
int64_t v_strideB;
|
int64_t v_strideB;
|
||||||
int64_t o_strideB;
|
|
||||||
int32_t num_batches;
|
int32_t num_batches;
|
||||||
int32_t num_heads;
|
int32_t num_heads;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE int32_t o_strideM() const {
|
CUTLASS_HOST_DEVICE int32_t o_strideM() const {
|
||||||
return head_dim_value;
|
return head_dim_value * num_heads;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Moves pointers to what we should process
|
// Moves pointers to what we should process
|
||||||
// Returns "false" if there is no work to do
|
// Returns "false" if there is no work to do
|
||||||
CUTLASS_DEVICE bool advance_to_block() {
|
CUTLASS_DEVICE bool advance_to_block() {
|
||||||
@ -195,9 +186,9 @@ struct AttentionKernel {
|
|||||||
query_ptr += batch_id * q_strideB;
|
query_ptr += batch_id * q_strideB;
|
||||||
key_ptr += batch_id * k_strideB;
|
key_ptr += batch_id * k_strideB;
|
||||||
value_ptr += batch_id * v_strideB;
|
value_ptr += batch_id * v_strideB;
|
||||||
output_ptr += batch_id * o_strideB;
|
output_ptr += int64_t(batch_id * num_queries) * o_strideM();
|
||||||
if (output_accum_ptr != nullptr) {
|
if (output_accum_ptr != nullptr) {
|
||||||
output_accum_ptr += batch_id * o_strideB;
|
output_accum_ptr += int64_t(batch_id * num_queries) * o_strideM();
|
||||||
}
|
}
|
||||||
q_start = 0;
|
q_start = 0;
|
||||||
k_start = 0;
|
k_start = 0;
|
||||||
@ -208,11 +199,11 @@ struct AttentionKernel {
|
|||||||
key_ptr += k_start * k_strideM + head_id * k_strideH;
|
key_ptr += k_start * k_strideM + head_id * k_strideH;
|
||||||
value_ptr += k_start * v_strideM + head_id * v_strideH;
|
value_ptr += k_start * v_strideM + head_id * v_strideH;
|
||||||
output_ptr += int64_t(q_start + query_start) * o_strideM() +
|
output_ptr += int64_t(q_start + query_start) * o_strideM() +
|
||||||
head_id * o_strideH;
|
head_id * head_dim_value;
|
||||||
|
|
||||||
if (output_accum_ptr != nullptr) {
|
if (output_accum_ptr != nullptr) {
|
||||||
output_accum_ptr += int64_t(q_start + query_start) * o_strideM() +
|
output_accum_ptr += int64_t(q_start + query_start) * o_strideM() +
|
||||||
head_id * o_strideH;
|
head_id * head_dim_value;
|
||||||
} else {
|
} else {
|
||||||
// Accumulate directly in the destination buffer (eg for f32)
|
// Accumulate directly in the destination buffer (eg for f32)
|
||||||
output_accum_ptr = (accum_t*)output_ptr;
|
output_accum_ptr = (accum_t*)output_ptr;
|
||||||
@ -652,7 +643,7 @@ struct AttentionKernel {
|
|||||||
warp_id(),
|
warp_id(),
|
||||||
p.num_keys - iter_key_start,
|
p.num_keys - iter_key_start,
|
||||||
iteratorC_tile_offset,
|
iteratorC_tile_offset,
|
||||||
1.0f / cutlass::fast_sqrt(float(p.head_dim)));
|
p.scale);
|
||||||
}));
|
}));
|
||||||
}));
|
}));
|
||||||
|
|
||||||
@ -858,93 +849,3 @@ __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
|
|||||||
}
|
}
|
||||||
AK::attention_kernel(p);
|
AK::attention_kernel(p);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename AK>
|
|
||||||
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
|
|
||||||
attention_kernel_batched(typename AK::Params params);
|
|
||||||
|
|
||||||
#define _ATTENTION_KERNEL_FORWARD_BEGIN(...) \
|
|
||||||
template <> \
|
|
||||||
__global__ void __launch_bounds__( \
|
|
||||||
__VA_ARGS__::kNumThreads, __VA_ARGS__::kMinBlocksPerSm) \
|
|
||||||
attention_kernel_batched<__VA_ARGS__>(typename __VA_ARGS__::Params p) { \
|
|
||||||
using Kernel = __VA_ARGS__;
|
|
||||||
#define _ATTENTION_KERNEL_FORWARD_END() }
|
|
||||||
|
|
||||||
#ifdef __CUDA_ARCH__
|
|
||||||
#define __CUDA_ARCH_OR_ZERO__ __CUDA_ARCH__
|
|
||||||
#else
|
|
||||||
#define __CUDA_ARCH_OR_ZERO__ 0
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD( \
|
|
||||||
ARCH, \
|
|
||||||
SCALAR_T, \
|
|
||||||
IS_ALIGNED, \
|
|
||||||
QUERIES_PER_BLOCK, \
|
|
||||||
KEYS_PER_BLOCK, \
|
|
||||||
SINGLE_VALUE_ITER) \
|
|
||||||
_ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel< \
|
|
||||||
SCALAR_T, \
|
|
||||||
cutlass::arch::Sm##ARCH, \
|
|
||||||
IS_ALIGNED, \
|
|
||||||
QUERIES_PER_BLOCK, \
|
|
||||||
KEYS_PER_BLOCK, \
|
|
||||||
SINGLE_VALUE_ITER>) \
|
|
||||||
if (!p.advance_to_block()) { \
|
|
||||||
return; \
|
|
||||||
} \
|
|
||||||
Kernel::attention_kernel(p); \
|
|
||||||
_ATTENTION_KERNEL_FORWARD_END();
|
|
||||||
|
|
||||||
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED( \
|
|
||||||
ARCH, \
|
|
||||||
SCALAR_T, \
|
|
||||||
IS_ALIGNED, \
|
|
||||||
QUERIES_PER_BLOCK, \
|
|
||||||
KEYS_PER_BLOCK, \
|
|
||||||
SINGLE_VALUE_ITER) \
|
|
||||||
_ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel< \
|
|
||||||
SCALAR_T, \
|
|
||||||
cutlass::arch::Sm##ARCH, \
|
|
||||||
IS_ALIGNED, \
|
|
||||||
QUERIES_PER_BLOCK, \
|
|
||||||
KEYS_PER_BLOCK, \
|
|
||||||
SINGLE_VALUE_ITER>) \
|
|
||||||
printf( \
|
|
||||||
"FATAL: this function is for sm%d, but was built for sm%d\n", \
|
|
||||||
int(ARCH), \
|
|
||||||
int(__CUDA_ARCH_OR_ZERO__)); \
|
|
||||||
_ATTENTION_KERNEL_FORWARD_END();
|
|
||||||
|
|
||||||
// All kernels are disabled by default
|
|
||||||
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \
|
|
||||||
INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(50, __VA_ARGS__)
|
|
||||||
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \
|
|
||||||
INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(70, __VA_ARGS__)
|
|
||||||
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \
|
|
||||||
INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(75, __VA_ARGS__)
|
|
||||||
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \
|
|
||||||
INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(80, __VA_ARGS__)
|
|
||||||
|
|
||||||
// Enable the right one based on __CUDA_ARCH__
|
|
||||||
#ifndef __CUDA_ARCH__
|
|
||||||
#elif __CUDA_ARCH__ < 500
|
|
||||||
#error "Need cuda arch at least 5.0"
|
|
||||||
#elif __CUDA_ARCH__ < 700
|
|
||||||
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50
|
|
||||||
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \
|
|
||||||
INSTANTIATE_ATTENTION_KERNEL_FORWARD(50, __VA_ARGS__)
|
|
||||||
#elif __CUDA_ARCH__ < 750
|
|
||||||
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70
|
|
||||||
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \
|
|
||||||
INSTANTIATE_ATTENTION_KERNEL_FORWARD(70, __VA_ARGS__)
|
|
||||||
#elif __CUDA_ARCH__ < 800
|
|
||||||
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75
|
|
||||||
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \
|
|
||||||
INSTANTIATE_ATTENTION_KERNEL_FORWARD(75, __VA_ARGS__)
|
|
||||||
#elif __CUDA_ARCH__ >= 800
|
|
||||||
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80
|
|
||||||
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \
|
|
||||||
INSTANTIATE_ATTENTION_KERNEL_FORWARD(80, __VA_ARGS__)
|
|
||||||
#endif
|
|
||||||
|
|||||||
@ -57,8 +57,8 @@
|
|||||||
#include "epilogue_thread_apply_logsumexp.h"
|
#include "epilogue_thread_apply_logsumexp.h"
|
||||||
#include "gemm_kernel_utils.h"
|
#include "gemm_kernel_utils.h"
|
||||||
#include "iterators/make_residual_last.h"
|
#include "iterators/make_residual_last.h"
|
||||||
|
#include "iterators/transpose_warp_iterator.h"
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
#include "iterators/warp_iterator_from_smem.h"
|
||||||
|
|
||||||
namespace cutlass {
|
namespace cutlass {
|
||||||
namespace gemm {
|
namespace gemm {
|
||||||
@ -560,20 +560,14 @@ template <
     typename Policy1_,
     /// Number of stages,
     int Stages_,
+    int kMaxK_,
     /// Used for partial specialization
     typename Enable = bool>
-class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory<
-                                          Shape1_,
-                                          AccumulatorSharedStorage::Shape::kN,
-                                          Policy1_,
-                                          Stages_> {
+class MmaMultistageFromSharedMemory
+    : public MmaBaseFromSharedMemory<Shape1_, kMaxK_, Policy1_, Stages_> {
  public:
   ///< Base class
-  using Base = MmaBaseFromSharedMemory<
-      Shape1_,
-      AccumulatorSharedStorage::Shape::kN,
-      Policy1_,
-      Stages_>;
+  using Base = MmaBaseFromSharedMemory<Shape1_, kMaxK_, Policy1_, Stages_>;
 
   ///< Size of the Gemm problem - concept: gemm::GemmShape<>
   using Shape1 = Shape1_;
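The edit above replaces the AccumulatorSharedStorage::Shape::kN base-class argument with an explicit kMaxK_ template parameter. A minimal sketch of the idea is below, with invented types that only mimic the shape of the real classes: once the K bound is an ordinary template parameter, the caller can pass either the accumulator's kM or kN extent, which the multistage DefaultMmaFromSharedMemory hunk later in this diff exploits when operand A is read transposed from shared memory.

// Toy sketch of the design change (illustrative types, not the CUTLASS ones).
#include <cstdio>

template <int kMaxK>
struct MmaBaseSketch {
  static constexpr int kMaxKDim = kMaxK;  // K extent now comes in explicitly
};

template <int kMaxK>
struct MmaMultistageSketch : MmaBaseSketch<kMaxK> {};

struct AccumShapeSketch {
  static constexpr int kM = 64, kN = 128;  // hypothetical accumulator tile
};

int main() {
  // Non-transposed path: the K bound is the accumulator's N extent.
  std::printf("%d\n", MmaMultistageSketch<AccumShapeSketch::kN>::kMaxKDim);  // 128
  // Transposed path: the K bound is the accumulator's M extent.
  std::printf("%d\n", MmaMultistageSketch<AccumShapeSketch::kM>::kMaxKDim);  // 64
  return 0;
}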
@ -1035,16 +1029,39 @@ template <
     typename WarpShape,
     typename InstructionShape,
     typename RegularWarpIterator,
-    typename Policy>
+    typename Policy,
+    typename Enable = void>
 struct DefaultWarpIteratorAFromSharedMemory {};
 
-// TensorOp - Ampere
+// TensorOp - Ampere half
+template <typename RegularWarpIterator, typename Policy>
+struct DefaultWarpIteratorAFromSharedMemory<
+    cutlass::gemm::GemmShape<32, 32, 32>,
+    cutlass::gemm::GemmShape<16, 8, 8>,
+    RegularWarpIterator,
+    Policy,
+    typename platform::enable_if<(
+        sizeof_bits<typename RegularWarpIterator::Element>::value == 16 &&
+        Policy::Operator::Policy::OpDelta::kRow == 1)>::type> {
+  static constexpr auto kWarpSize = 32;
+  using OpDelta = typename Policy::Operator::Policy::OpDelta;
+  using WarpShape = cutlass::MatrixShape<32, 32>;
+
+  using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem<
+      cutlass::gemm::Operand::kA,
+      typename RegularWarpIterator::Element>;
+};
+
+// TensorOp - Ampere f32
 template <typename WarpShape, typename RegularWarpIterator, typename Policy>
 struct DefaultWarpIteratorAFromSharedMemory<
     WarpShape,
     cutlass::gemm::GemmShape<16, 8, 8>,
     RegularWarpIterator,
-    Policy> {
+    Policy,
+    typename platform::enable_if<(
+        sizeof_bits<typename RegularWarpIterator::Element>::value != 16 ||
+        Policy::Operator::Policy::OpDelta::kRow != 1)>::type> {
   using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
   static constexpr auto kWarpSize = 32;
   using OpDelta = typename Policy::Operator::Policy::OpDelta;
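The two partial specializations above select a warp iterator via SFINAE: the first fires for 16-bit elements with OpDelta::kRow == 1, the second takes the exact complement. The standalone sketch below (using std::enable_if instead of cutlass::platform::enable_if, and dropping the GemmShape<32, 32, 32> warp-shape constraint of the real half specialization) shows why the two guards never overlap and never both fail, so exactly one specialization is viable for any (element size, kRow) pair.

// Simplified, self-contained illustration of the enable_if partition above.
#include <cstdio>
#include <type_traits>

template <int SizeBits, int OpDeltaRow, typename Enable = void>
struct WarpIteratorSelectorSketch;  // primary template, mirrors the empty default

// Counterpart of the "TensorOp - Ampere half" specialization.
template <int SizeBits, int OpDeltaRow>
struct WarpIteratorSelectorSketch<
    SizeBits,
    OpDeltaRow,
    typename std::enable_if<(SizeBits == 16 && OpDeltaRow == 1)>::type> {
  static const char* name() { return "half path (WarpIteratorFromSmem)"; }
};

// Counterpart of the "TensorOp - Ampere f32" fallback.
template <int SizeBits, int OpDeltaRow>
struct WarpIteratorSelectorSketch<
    SizeBits,
    OpDeltaRow,
    typename std::enable_if<(SizeBits != 16 || OpDeltaRow != 1)>::type> {
  static const char* name() { return "f32 / fallback path"; }
};

int main() {
  std::printf("%s\n", WarpIteratorSelectorSketch<16, 1>::name());  // half path
  std::printf("%s\n", WarpIteratorSelectorSketch<32, 1>::name());  // fallback path
  return 0;
}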
@ -1099,7 +1116,10 @@ struct DefaultWarpIteratorAFromSharedMemory<
 };
 
 // Converts a "regular" Mma into their counterpart from shared memory
-template <typename Mma_, typename AccumulatorSharedStorage>
+template <
+    typename Mma_,
+    typename AccumulatorSharedStorage,
+    bool kTransposeA = false>
 struct DefaultMmaFromSharedMemory;
 
 // Mma pipelined
@ -1130,7 +1150,8 @@ template <
     typename TransformA_,
     /// Transformation applied to B operand
     typename TransformB_,
-    typename AccumulatorSharedStorage_>
+    typename AccumulatorSharedStorage_,
+    bool kTransposeA>
 struct DefaultMmaFromSharedMemory<
     MmaPipelined<
         Shape_,
@ -1143,7 +1164,8 @@ struct DefaultMmaFromSharedMemory<
         Policy_,
         TransformA_,
         TransformB_>,
-    AccumulatorSharedStorage_> {
+    AccumulatorSharedStorage_,
+    kTransposeA> {
   static constexpr int kWarpSize = 32;
   using SmemAccumulatorLayout = cutlass::layout::RowMajor;
 
@ -1163,6 +1185,7 @@ struct DefaultMmaFromSharedMemory<
   using InstructionShape = typename Policy_::Operator::InstructionShape;
   using ArchMmaOperator = typename Policy_::Operator;
 
+  static constexpr bool kIsTransposedA = false;
   using WarpIteratorA = typename DefaultWarpIteratorAFromSharedMemory<
       WarpShape,
       InstructionShape,
@ -1214,7 +1237,8 @@ template <
     int Stages,
     /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear,
-    typename AccumulatorSharedStorage_>
+    typename AccumulatorSharedStorage_,
+    bool kTransposeA>
 struct DefaultMmaFromSharedMemory<
     MmaMultistage<
         Shape_,
@ -1229,7 +1253,8 @@ struct DefaultMmaFromSharedMemory<
         Policy_,
         Stages,
         SharedMemoryClear>,
-    AccumulatorSharedStorage_> {
+    AccumulatorSharedStorage_,
+    kTransposeA> {
   static constexpr int kWarpSize = 32;
 
   using RegularMma = MmaMultistage<
@ -1248,13 +1273,22 @@ struct DefaultMmaFromSharedMemory<
 
   using WarpShape = typename Policy_::Operator::Shape;
   using InstructionShape = typename Policy_::Operator::InstructionShape;
-  using WarpIteratorA = typename DefaultWarpIteratorAFromSharedMemory<
+  using WarpIteratorA_ = typename DefaultWarpIteratorAFromSharedMemory<
       WarpShape,
       InstructionShape,
       typename RegularMma::Operator::IteratorA,
       Policy_>::WarpIterator;
+  using WarpIteratorTranspose = TransposeWarpIterator<WarpIteratorA_>;
+  static constexpr bool kIsTransposedA =
+      WarpIteratorTranspose::kSupportsTranspose && kTransposeA;
+  using WarpIteratorA = typename platform::conditional<
+      kIsTransposedA,
+      typename WarpIteratorTranspose::Iterator,
+      WarpIteratorA_>::type;
 
-  static int constexpr kMaxK = AccumulatorSharedStorage_::Shape::kN;
+  static int constexpr kMaxK = kIsTransposedA
+      ? AccumulatorSharedStorage_::Shape::kM
+      : AccumulatorSharedStorage_::Shape::kN;
   // Reduce the number of stages if we don't need that many
   static int constexpr kStagesMax =
       (kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK);
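As a quick worked instance of the two constants above, with purely illustrative tile sizes rather than any particular configuration:

// Assume (hypothetically) AccumulatorSharedStorage_::Shape has kM = 64, kN = 128,
// and the threadblock K-tile is Shape_::kK = 32.
static_assert((128 + 32 - 1) / 32 == 4,
              "non-transposed path: kMaxK = kN = 128 -> kStagesMax = ceil(128/32) = 4");
static_assert((64 + 32 - 1) / 32 == 2,
              "transposed path: kMaxK = kM = 64 -> kStagesMax = ceil(64/32) = 2");
// Per the "Reduce the number of stages" comment above, the stage count actually used
// is presumably capped at kStagesMax, so a requested Stages = 5 would drop to 4 or 2
// in these two cases, since there are only that many K-tiles worth of data to pipeline.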
@ -1274,7 +1308,8 @@ struct DefaultMmaFromSharedMemory<
       ElementC_,
       LayoutC_,
       Policy_,
-      kStages>;
+      kStages,
+      kMaxK>;
 };
 
 /////////////////////////////////////////////////////////////////////////////////////////////////