#pragma once

#include <cute/tensor.hpp>

#include <cutlass/cutlass.h>
#include "cutlass/layout/layout.h"
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>

#include "kernel_traits.h"
#include "utils.h"
namespace flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////
template <class Element, class SmemShape, class SmemShapeMaxSplits>
struct SharedStorageLSE {
    cute::array_aligned<Element, cute::size_v<SmemShape>> smem_lse;
    cute::array_aligned<bool, cute::size_v<SmemShapeMaxSplits>> smem_valid_splits;
};

// Don't use Kernel_traits here, to avoid redundant compilation.
// template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K, typename Params>
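// Overview (added commentary): this kernel combines the results of a split-KV forward pass.
// Each of params.num_splits splits has produced a partial output Oaccum_s and a partial
// log-sum-exp LSEaccum_s over its slice of the keys. The exact result is recovered as
//     LSE = log(sum_s exp(LSEaccum_s))   and   O = sum_s exp(LSEaccum_s - LSE) * Oaccum_s,
// which is what the body below does: first reduce the LSEs, then accumulate the scaled Oaccum tiles.
// Each block handles kBlockM rows of the flattened (b * h * seqlen_q) dimension with 128 threads.
//
// A minimal launch sketch (an assumption based on the indexing below, not part of this file):
//
//     int rows = params.b * params.h * params.seqlen_q;
//     dim3 grid((rows + kBlockM - 1) / kBlockM);
//     size_t smem_bytes = sizeof(SharedStorageLSE<ElementAccum,
//                                                 Shape<Int<(1 << Log_max_splits)>, Int<kBlockM + 1>>,
//                                                 Shape<Int<(1 << Log_max_splits)>>>);
//     combine_attn_seqk_parallel<Element, ElementAccum, kHeadDim, kBlockM, Log_max_splits, Is_even_K>
//         <<<grid, 128, smem_bytes, stream>>>(params);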
template<typename Element, typename ElementAccum, int kHeadDim, int kBlockM, int Log_max_splits, bool Is_even_K, typename Params>
__global__ void combine_attn_seqk_parallel(Params const params) {
    // using Element = typename Kernel_traits::OutputType;
    // using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = int64_t;  // Kernel_traits::index_t
    constexpr int kMaxSplits = 1 << Log_max_splits;
    // constexpr int kHeadDim = Kernel_traits::kHeadDim;
    constexpr int kNThreads = 128;  // Kernel_traits::kNThreads

    static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128");
    static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32");
    static_assert(kNThreads == 128, "We assume that each block has 128 threads");

    // Shared memory.
    // kBlockM + 1 instead of kBlockM to reduce bank conflicts.
    // __shared__ __align__(16) ElementAccum sLSE[kMaxSplits][kBlockM + 1];
    extern __shared__ char smem_[];
    using SharedStorage = SharedStorageLSE<ElementAccum, Shape<Int<kMaxSplits>, Int<kBlockM + 1>>, Shape<Int<kMaxSplits>>>;
    SharedStorage &shared_storage = *reinterpret_cast<SharedStorage *>(smem_);
    Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse.data()), Shape<Int<kMaxSplits>, Int<kBlockM + 1>>{});
    Tensor sValidSplits = make_tensor(make_smem_ptr(shared_storage.smem_valid_splits.data()), Shape<Int<kMaxSplits>>{});

    // The thread and block index.
    const int tidx = threadIdx.x;
    const int bidx = blockIdx.x;

    const index_t lse_size = params.b * params.h * params.seqlen_q;
    // if (cute::thread0()) print("final %d %d %d %d\n", params.b, params.h, params.seqlen_q, params.b * params.h * params.seqlen_q);

    const index_t row_offset_lse = bidx * kBlockM;
    Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lse),
                                   Shape<Int<kMaxSplits>, Int<kBlockM>>{},
                                   make_stride(lse_size, _1{}));

    // LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped, see comment in get_lse_tile.
    // This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}.
    Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
                              Shape<Int<kBlockM>>{}, Stride<_1>{});

    // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}.
    Layout flat_layout = make_layout(lse_size);
    Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b));
    auto transposed_stride = params.seqlenq_ngroups_swapped
        ? make_stride(params.b, params.seqlen_q * params.b, 1)
        : make_stride(1, params.seqlen_q * params.b, params.seqlen_q);
    Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride);
    Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout));

    Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr)), final_layout);

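    // Added commentary on final_layout: composition(orig_layout, flat_layout) reinterprets a flat row
    // index i in [0, lse_size) as coordinates (q_offset, bidh, bidb) of the padded (seqlen_q, h, b)
    // LSE, and remapped_layout then applies transposed_stride to place those coordinates in the
    // unpadded buffer: q_offset fastest within a (bidh, bidb) slice normally, or bidb fastest when
    // params.seqlenq_ngroups_swapped is set. So gLSE_unpadded(i) addresses the same logical row as
    // row i of the padded LSE, just at its unpadded location.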
    constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads;

    // Read the LSE values from gmem and store them in shared memory, then transpose them.
    constexpr int kRowsPerLoadLSE = kNThreads / kBlockM;
    #pragma unroll
    for (int l = 0; l < kNLsePerThread; ++l) {
        const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
        const int col = tidx % kBlockM;
        ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
        if (row < kMaxSplits) { sLSE(row, col) = lse; }
        // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); }
    }
    __syncthreads();

    // Reduce along the kBlockM dimension to determine valid splits (store in SMEM).
    // One thread per split. We know kNThreads = 128 >= kMaxSplits.
    if (tidx < kMaxSplits) {
        bool is_valid_split = false;
        #pragma unroll
        for (int col = 0; col < kBlockM; ++col) {
            if (sLSE(tidx, col) != -INFINITY) {
                is_valid_split = true;
            }
        }
        sValidSplits(tidx) = is_valid_split;
    }
    __syncthreads();

    // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); }

    Tensor lse_accum = make_tensor<ElementAccum>(Shape<Int<kNLsePerThread>>{});
    constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits);
    // To make sure that kMaxSplits stays within one warp, we decide how many of the kMaxSplits values
    // each thread should hold. For example, if kMaxSplits = 16 and kBlockM = 16, then each thread
    // holds 2 values (128 threads, kBlockM rows, so each pass covers 128 / kBlockM rows).
    // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose;
    // static_assert(kThreadsPerSplit <= 32);
    static_assert(kRowsPerLoadTranspose <= 32);
    static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits);
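    // Added commentary: the first pass wrote sLSE with (row = split, col = row-in-block), matching the
    // gmem layout. The loop below re-reads it "transposed": threads that differ only in
    // (tidx % kRowsPerLoadTranspose) hold different splits of the same output row, so the max/sum
    // reductions over splits can use the in-warp Allreduce<kRowsPerLoadTranspose> below instead of
    // another shared-memory pass.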
    #pragma unroll
    for (int l = 0; l < kNLsePerThread; ++l) {
        const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
        const int col = tidx / kRowsPerLoadTranspose;
        // if (bidx == 0 && tidx < 128) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); }
        lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE(row, col) : -INFINITY;
    }
    // return;

    // Compute the logsumexp of the LSE along the split dimension.
    ElementAccum lse_max = lse_accum(0);
    #pragma unroll
    for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); }
    MaxOp<float> max_op;
    lse_max = Allreduce<kRowsPerLoadTranspose>::run(lse_max, max_op);
    lse_max = lse_max == -INFINITY ? 0.0f : lse_max;  // In case all local LSEs are -inf
    float lse_sum = expf(lse_accum(0) - lse_max);
    #pragma unroll
    for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); }
    SumOp<float> sum_op;
    lse_sum = Allreduce<kRowsPerLoadTranspose>::run(lse_sum, sum_op);
    // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise
    // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum.
    ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max;
    // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
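    // Added commentary: this is the standard max-subtraction trick, so
    //     lse_logsum = lse_max + log(sum_l exp(lse_accum(l) - lse_max)) = log(sum_l exp(lse_accum(l)))
    // without overflowing expf; the sum itself is reduced across the lanes holding this row's splits.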
    if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) {
        if (params.unpadded_lse) {
            const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose;
            if (lse_offset < lse_size) {
                gLSE_unpadded(lse_offset) = lse_logsum;
            }
        } else {
            gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum;
        }
    }
    // if (cute::thread0()) printf("lse_logsum = %f\n", lse_logsum);

    // Store the scales exp(lse - lse_logsum) in shared memory.
    #pragma unroll
    for (int l = 0; l < kNLsePerThread; ++l) {
        const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
        const int col = tidx / kRowsPerLoadTranspose;
        if (row < params.num_splits && col < kBlockM) { sLSE(row, col) = expf(lse_accum(l) - lse_logsum); }
    }
    __syncthreads();

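    // Added commentary: sLSE is reused here. For each output row, sLSE(split, row) now holds
    // exp(LSE_split - LSE), i.e. the weight of that split's partial output. When at least one split is
    // valid these weights sum to 1; splits whose LSE was -inf get weight 0 and are skipped below.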
    const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded;
    Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
                                 Shape<Int<kBlockM>, Int<kHeadDim>>{},
                                 Stride<Int<kHeadDim>, _1>{});
    constexpr int kBlockN = kNThreads / kBlockM;
    using GmemLayoutAtomOaccum = Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;
    using GmemTiledCopyOaccum = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomOaccum{},
                        Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
    GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
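    // Added commentary: this tiled copy arranges the 128 threads as a kBlockM x kBlockN grid
    // (row-major, so consecutive threads step along the head dimension) with 4 contiguous
    // ElementAccum values per thread, so each pass covers a kBlockM x (4 * kBlockN) slab of the
    // kBlockM x kHeadDim Oaccum tile.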
    auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
    Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
    Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
    Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
    clear(tOrO);

    // Predicates
    Tensor cOaccum = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
    // if (cute::thread0()) { print_tensor(cOaccum); }
    // Repeat the partitioning with identity layouts
    Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum);
    Tensor tOpOaccum = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
    if (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; }
    }
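    // Added commentary: judging by the pointer advance at the end of the loop below, Oaccum is laid
    // out as (num_splits, b * h * seqlen_q, d_rounded), row-major. gOaccum starts at this block's
    // kBlockM rows of split 0, and tOgOaccum is advanced by b * h * seqlen_q * d_rounded elements per
    // iteration to reach the same rows of the next split. Splits flagged invalid in sValidSplits are
    // skipped entirely, so they cost no gmem traffic.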
    // Load Oaccum in, then scale and accumulate to O.
    for (int split = 0; split < params.num_splits; ++split) {
        // Don't copy in Oaccum if lse(split) == -inf for all kBlockM rows.
        if (sValidSplits(split)) {
            flash::copy</*Is_even_MN=*/false, Is_even_K>(
                gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM
            );
            #pragma unroll
            for (int m = 0; m < size<1>(tOrOaccum); ++m) {
                int row = get<0>(tOcOaccum(0, m, 0));
                ElementAccum lse_scale = sLSE(split, row);
                if (lse_scale != 0.f) {
                    #pragma unroll
                    for (int k = 0; k < size<2>(tOrOaccum); ++k) {
                        #pragma unroll
                        for (int i = 0; i < size<0>(tOrOaccum); ++i) {
                            tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k);
                            // tOrO(i, m, k) += tOrOaccum(i, m, k);
                        }
                    }
                }
                // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE(split, 0), sLSE(split, 1)); print_tensor(tOrOaccum); }
            }
        }
        tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded;
    }
    // if (cute::thread0()) { print_tensor(tOrO); }

    Tensor rO = flash::convert_type<Element>(tOrO);
    // Write to gO.
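    // Added commentary: each combined row is scattered back to O via its flat index
    // idx = bidx * kBlockM + row_in_block, decomposed as idx = (batch_idx * h + head_idx) * seqlen_q + row,
    // then addressed with O's own strides (o_batch_stride / o_head_stride / o_row_stride). The store is
    // done fragment by fragment since rO keeps the Oaccum copy partitioning while O may have arbitrary
    // row strides.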
    #pragma unroll
    for (int m = 0; m < size<1>(rO); ++m) {
        const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0));
        // if (cute::thread0()) print("final %d %d %d %d %d\n", idx, params.b, params.h, params.seqlen_q, params.b * params.h * params.seqlen_q);
        if (idx < params.b * params.h * params.seqlen_q) {
            // print("final2\n");
            const int batch_idx = idx / (params.h * params.seqlen_q);
            const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q;
            // The index to the rows of Q.
            const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q;
            auto o_ptr = reinterpret_cast<Element *>(params.o_ptr) + batch_idx * params.o_batch_stride
                + head_idx * params.o_head_stride + row * params.o_row_stride;
            #pragma unroll
            for (int k = 0; k < size<2>(rO); ++k) {
                if (Is_even_K || tOpOaccum(k)) {
                    const int col = get<1>(tOcOaccum(0, m, k));
                    Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col),
                                            Shape<Int<decltype(size<0>(rO))::value>>{}, Stride<_1>{});
                    // TODO: Should check if this is using vectorized store, but it seems pretty fast.
                    copy(rO(_, m, k), gO);
                    // if (cute::thread0()) { print("final\n"); print_tensor(gO); }
                    // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); }
                    // reinterpret_cast<uint64_t *>(o_ptr)[col / 4] = recast<uint64_t>(rO)(0, m, k);
                }
            }
        }
    }
}

}  // namespace flash