diff --git a/examples/41_fused_multi_head_attention/debug_utils.h b/examples/41_fused_multi_head_attention/debug_utils.h index 73a258ef..7e91a723 100644 --- a/examples/41_fused_multi_head_attention/debug_utils.h +++ b/examples/41_fused_multi_head_attention/debug_utils.h @@ -47,19 +47,52 @@ } // Print on the first thread of the first block -#if 0 +#if 1 #define PRINT_WARP_ID 0 #define PRINT_LANE_ID 0 #define PRINT_T0_L0(msg, ...) \ if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && \ threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \ threadIdx.z == 0) { \ - printf(msg "\n", __VA_ARGS__); \ + printf(msg "\n", ##__VA_ARGS__); \ } +#define PRINT_TX_LX(msg, ...) \ + for (int bx = 0; bx < gridDim.x; ++bx) { \ + for (int by = 0; by < gridDim.y; ++by) { \ + for (int bz = 0; bz < gridDim.z; ++bz) { \ + for (int tx = 0; tx < blockDim.x; ++tx) { \ + for (int ty = 0; ty < blockDim.y; ++ty) { \ + for (int tz = 0; tz < blockDim.z; ++tz) { \ + __syncthreads(); \ + if (blockIdx.x == bx && blockIdx.y == by && blockIdx.z == bz && \ + threadIdx.x == tx && threadIdx.y == ty && \ + threadIdx.z == tz) { \ + printf( \ + "[%d,%d,%d][%d,%d,%d]" msg "\n", \ + bx, \ + by, \ + bz, \ + tx, \ + ty, \ + tz, \ + ##__VA_ARGS__); \ + } \ + } \ + } \ + } \ + } \ + } \ + } +#else +#define PRINT_T0_L0 +#define PRINT_TX_LX +#endif + struct __string_view { char const* data; std::size_t size; }; +#if __cplusplus >= 201402L template constexpr __string_view __get_type_name() { char const* p = __PRETTY_FUNCTION__; @@ -83,7 +116,10 @@ constexpr __string_view __get_type_name() { return {}; } #else -#define PRINT_T0_L0 +template +constexpr __string_view __get_type_name() { + return {"unsupported", 11}; +} #endif // Print a given array diff --git a/examples/41_fused_multi_head_attention/fmha_grouped.h b/examples/41_fused_multi_head_attention/fmha_grouped.h index 72015996..58f47d74 100644 --- a/examples/41_fused_multi_head_attention/fmha_grouped.h +++ b/examples/41_fused_multi_head_attention/fmha_grouped.h @@ -168,6 +168,9 @@ public: typename LayoutP::Stride::LongIndex *ldv; typename LayoutO::Stride::LongIndex *ldo; + // Scale + ElementAccumulator scale; + // Whether causal masking is to be performed bool causal; @@ -193,6 +196,7 @@ public: ldk(nullptr), ldv(nullptr), ldo(nullptr), + scale(0), causal(false), host_problem_sizes(nullptr) { @@ -218,6 +222,7 @@ public: typename LayoutV::Stride::LongIndex *ldv, typename LayoutO::Stride::LongIndex *ldo, bool causal, + ElementAccumulator scale, GemmCoord *host_problem_sizes=nullptr ): problem_sizes0(problem_sizes0), @@ -235,6 +240,7 @@ public: ldv(ldv), ldo(ldo), causal(causal), + scale(scale), host_problem_sizes(host_problem_sizes) { @@ -273,6 +279,7 @@ public: typename LayoutP::Stride::LongIndex *ldv; typename LayoutO::Stride::LongIndex *ldo; + ElementAccumulator scale; bool causal; // @@ -291,7 +298,8 @@ public: ldk(nullptr), ldv(nullptr), ldo(nullptr), - causal(false) + causal(false), + scale(0) { } CUTLASS_HOST_DEVICE @@ -310,8 +318,9 @@ public: ldk(args.ldk), ldv(args.ldv), ldo(args.ldo), - causal(args.causal) - { + causal(args.causal), + scale(args.scale) + { } @@ -337,6 +346,7 @@ public: ldv = args.ldv; ldo = args.ldo; causal = args.causal; + scale = args.scale; } }; @@ -649,7 +659,7 @@ public: warp_id(), num_keys - iter_key_start, iteratorC_tile_offset, - 1.0f / cutlass::fast_sqrt(float(problem_size0.k()))); + params.scale); })); })); diff --git a/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu 
b/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu index 53af4ac1..45d6813a 100644 --- a/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu +++ b/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu @@ -504,37 +504,51 @@ private: ldo_host.resize(problem_count()); seqlen_host.resize(problem_count()); - for (int32_t i = 0; i < problem_count(); ++i) { + // Create tensors in BMHK format, where + // B = batch_size + // M = sequence length + // H = num_heads + // K = embedding size per head + int64_t batch_offset_Q, batch_offset_K, batch_offset_V, batch_offset_O; - auto problem0 = options.problem_sizes0.at(i); - auto problem1 = options.problem_sizes1.at(i); + for (int32_t b = 0; b < options.batch_size; ++b) { + batch_offset_Q = total_elements_Q; + batch_offset_K = total_elements_K; + batch_offset_V = total_elements_V; + batch_offset_O = total_elements_O; + for (int32_t h = 0; h < options.head_number; ++h) { + int32_t i = h + b * options.head_number; - ldq_host.at(i) = LayoutQ::packed({problem0.m(), problem0.k()}).stride(0); - ldk_host.at(i) = LayoutK::packed({problem0.k(), problem0.n()}).stride(0); - ldp_host.at(i) = LayoutP::packed({problem0.m(), problem0.n()}).stride(0); - ldv_host.at(i) = LayoutV::packed({problem1.k(), problem1.n()}).stride(0); - ldo_host.at(i) = LayoutO::packed({problem1.m(), problem1.n()}).stride(0); + auto problem0 = options.problem_sizes0.at(i); + auto problem1 = options.problem_sizes1.at(i); - // m = n for attention problems. - seqlen_host.at(i) = problem0.m(); + ldq_host.at(i) = LayoutQ::packed({problem0.m(), options.head_number * problem0.k()}).stride(0); + ldk_host.at(i) = LayoutK::packed({options.head_number * problem0.k(), problem0.n()}).stride(0); + ldp_host.at(i) = LayoutP::packed({problem0.m(), problem0.n()}).stride(0); + ldv_host.at(i) = LayoutV::packed({problem1.k(), options.head_number * problem1.n()}).stride(0); + ldo_host.at(i) = LayoutO::packed({problem1.m(), options.head_number * problem1.n()}).stride(0); - offset_Q.push_back(total_elements_Q); - offset_K.push_back(total_elements_K); - offset_P.push_back(total_elements_P); - offset_V.push_back(total_elements_V); - offset_O.push_back(total_elements_O); + // m = n for attention problems. 
+ seqlen_host.at(i) = problem0.m(); - int64_t elements_Q = problem0.m() * problem0.k(); - int64_t elements_K = problem0.k() * problem0.n(); - int64_t elements_P = problem0.m() * problem0.n(); - int64_t elements_V = problem1.k() * problem1.n(); - int64_t elements_O = problem1.m() * problem1.n(); + offset_Q.push_back(batch_offset_Q + h * problem0.k()); + offset_K.push_back(batch_offset_K + h * problem0.k()); + offset_P.push_back(total_elements_P); + offset_V.push_back(batch_offset_V + h * problem0.k()); + offset_O.push_back(batch_offset_O + h * problem1.n()); - total_elements_Q += elements_Q; - total_elements_K += elements_K; - total_elements_P += elements_P; - total_elements_V += elements_V; - total_elements_O += elements_O; + int64_t elements_Q = problem0.m() * problem0.k(); + int64_t elements_K = problem0.k() * problem0.n(); + int64_t elements_P = problem0.m() * problem0.n(); + int64_t elements_V = problem1.k() * problem1.n(); + int64_t elements_O = problem1.m() * problem1.n(); + + total_elements_Q += elements_Q; + total_elements_K += elements_K; + total_elements_P += elements_P; + total_elements_V += elements_V; + total_elements_O += elements_O; + } } problem_sizes_device0.reset(problem_count()); @@ -649,15 +663,11 @@ private: bool passed = true; - for (int32_t i = 0; i < problem_count(); ++i) { - cutlass::gemm::GemmCoord problem0 = options.problem_sizes0.at(i); - cutlass::gemm::GemmCoord problem1 = options.problem_sizes1.at(i); - - LayoutQ layout_Q(ldq_host.at(i)); - LayoutK layout_K(ldk_host.at(i)); - LayoutP layout_P(ldp_host.at(i)); - LayoutV layout_V(ldv_host.at(i)); - LayoutO layout_O(ldo_host.at(i)); + for (int32_t b = 0; b < options.batch_size; ++b) { + int32_t i = b * options.head_number; + // Problem size is the same for all heads + cutlass::gemm::GemmCoord problem0 = options.problem_sizes0.at(b * options.head_number); + cutlass::gemm::GemmCoord problem1 = options.problem_sizes1.at(b * options.head_number); MatrixCoord extent_Q{problem0.m(), problem0.k()}; MatrixCoord extent_K{problem0.k(), problem0.n()}; @@ -665,114 +675,121 @@ private: MatrixCoord extent_V{problem1.k(), problem1.n()}; MatrixCoord extent_O{problem1.m(), problem1.n()}; - cutlass::TensorView view_Q(block_Q.get() + offset_Q.at(i), layout_Q, extent_Q); - cutlass::TensorView view_K(block_K.get() + offset_K.at(i), layout_K, extent_K); - cutlass::TensorView view_P(block_P.get() + offset_P.at(i), layout_P, extent_P); - cutlass::TensorView view_V(block_V.get() + offset_V.at(i), layout_V, extent_V); - - cutlass::DeviceAllocation block_Ref(layout_P.capacity(extent_P)); - cutlass::TensorView view_Ref_device(block_Ref.get(), layout_P, extent_P); - + LayoutO layout_O(ldo_host.at(i)); + std::vector matrix_O(layout_O.capacity(extent_O)); + cutlass::device_memory::copy_to_host(matrix_O.data(), block_O.get() + offset_O.at(i), matrix_O.size()); cutlass::DeviceAllocation block_Ref_O(layout_O.capacity(extent_O)); - cutlass::TensorView view_Ref_O_device(block_Ref_O.get(), layout_O, extent_O); - // Reference GEMM - cutlass::reference::device::GemmComplex< - ElementQ, LayoutQ, - ElementK, LayoutK, - ElementP, LayoutP, - ElementCompute, ElementAccumulator - >( - problem0, - ElementAccumulator(options.alpha0), - view_Q, - Attention::MM0::Mma::kTransformA, - view_K, - Attention::MM0::Mma::kTransformB, - ElementAccumulator(options.beta), - view_P, - view_Ref_device, - ElementAccumulator(0) - ); + for (int32_t h = 0; h < options.head_number; ++h) { + i = h + b * options.head_number; - // Compute softmax for P. 
We need to explicitly compute softmax - // over P because softmax is fused to the second GEMM in the - // profiled implementation. - std::vector matrix_Ref(layout_P.capacity(extent_P)); - cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_Ref.size()); - cutlass::TensorView view_Ref_host(matrix_Ref.data(), layout_P, extent_P); - std::vector vector_Norm_Ref(problem0.m()); - std::vector vector_Sum_Ref(problem0.m()); + LayoutQ layout_Q(ldq_host.at(i)); + LayoutK layout_K(ldk_host.at(i)); + LayoutP layout_P(ldp_host.at(i)); + LayoutV layout_V(ldv_host.at(i)); - int n_dim = options.use_mask ? options.problem_sizes0_real.at(i).n() : problem0.n(); + cutlass::TensorView view_Q(block_Q.get() + offset_Q.at(i), layout_Q, extent_Q); + cutlass::TensorView view_K(block_K.get() + offset_K.at(i), layout_K, extent_K); + cutlass::TensorView view_V(block_V.get() + offset_V.at(i), layout_V, extent_V); + cutlass::TensorView view_Ref_O_device(block_Ref_O.get() + offset_O.at(i) - offset_O.at(b * options.head_number), layout_O, extent_O); - // Compute softmax for referece matrix - for (int m = 0; m < problem0.m(); m++) { - int n_dim_row = n_dim; - if (options.causal) { - n_dim_row = std::min(m + 1, n_dim); - } - ElementSoftmaxCompute max = ElementSoftmaxCompute(view_Ref_host.ref().at({m, 0})); - for (int n = 1; n < n_dim_row; n++) { - max = std::max(max, ElementSoftmaxCompute(view_Ref_host.ref().at({m, n}))); - } + cutlass::DeviceAllocation block_Ref_P(layout_P.capacity(extent_P)); + cutlass::TensorView view_Ref_P_device(block_Ref_P.get(), layout_P, extent_P); - vector_Norm_Ref.at(m) = ElementNorm(max); + // Reference GEMM + cutlass::reference::device::GemmComplex< + ElementQ, LayoutQ, + ElementK, LayoutK, + ElementP, LayoutP, + ElementCompute, ElementAccumulator + >( + problem0, + ElementAccumulator(options.alpha0), + view_Q, + Attention::MM0::Mma::kTransformA, + view_K, + Attention::MM0::Mma::kTransformB, + ElementAccumulator(options.beta), + view_Ref_P_device, + view_Ref_P_device, + ElementAccumulator(0) + ); - ElementSoftmaxCompute sum = ElementSoftmaxCompute(); - for (int n = 0; n < n_dim_row; n++) { - sum += std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max ); - } - ElementSoftmaxCompute inv_sum = ElementSoftmaxCompute(1.0f / sum); + // Compute softmax for P. We need to explicitly compute softmax + // over P because softmax is fused to the second GEMM in the + // profiled implementation. + std::vector matrix_Ref(layout_P.capacity(extent_P)); + cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref_P.get(), matrix_Ref.size()); + cutlass::TensorView view_Ref_host(matrix_Ref.data(), layout_P, extent_P); + std::vector vector_Norm_Ref(problem0.m()); + std::vector vector_Sum_Ref(problem0.m()); - vector_Sum_Ref.at(m) = ElementSum(inv_sum); + int n_dim = options.use_mask ? 
options.problem_sizes0_real.at(i).n() : problem0.n(); - for (int n = 0; n < n_dim_row; n++) { - view_Ref_host.ref().at({m, n}) = ElementP( - std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max ) * inv_sum - ); - } - // Mask out the rest of the attention matrix - for (int n = n_dim_row; n < n_dim; ++n) { - view_Ref_host.ref().at({m, n}) = ElementP(0); - } - } - - // when not using mask, problem_real and problem share the same sizes - if (options.use_mask) { + // Compute softmax for reference matrix for (int m = 0; m < problem0.m(); m++) { - for (int n = n_dim; n < problem0.n(); n++) { + int n_dim_row = n_dim; + if (options.causal) { + n_dim_row = std::min(m + 1, n_dim); + } + ElementSoftmaxCompute max = ElementSoftmaxCompute(view_Ref_host.ref().at({m, 0})); + for (int n = 1; n < n_dim_row; n++) { + max = std::max(max, ElementSoftmaxCompute(view_Ref_host.ref().at({m, n}))); + } + + vector_Norm_Ref.at(m) = ElementNorm(max); + + ElementSoftmaxCompute sum = ElementSoftmaxCompute(); + for (int n = 0; n < n_dim_row; n++) { + sum += std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max ); + } + ElementSoftmaxCompute inv_sum = ElementSoftmaxCompute(1.0f / sum); + + vector_Sum_Ref.at(m) = ElementSum(inv_sum); + + for (int n = 0; n < n_dim_row; n++) { + view_Ref_host.ref().at({m, n}) = ElementP( + std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max ) * inv_sum + ); + } + // Mask out the rest of the attention matrix + for (int n = n_dim_row; n < n_dim; ++n) { view_Ref_host.ref().at({m, n}) = ElementP(0); } } + + // when not using mask, problem_real and problem share the same sizes + if (options.use_mask) { + for (int m = 0; m < problem0.m(); m++) { + for (int n = n_dim; n < problem0.n(); n++) { + view_Ref_host.ref().at({m, n}) = ElementP(0); + } + } + } + + cutlass::device_memory::copy_to_device(block_Ref_P.get(), matrix_Ref.data(), matrix_Ref.size()); + + // Reference GEMM + cutlass::reference::device::GemmComplex< + ElementP, LayoutP, + ElementV, LayoutV, + ElementO, LayoutO, + ElementCompute, ElementAccumulator + >( + problem1, + ElementAccumulator(options.alpha1), + view_Ref_P_device, + Attention::MM0::Mma::kTransformA, + view_V, + Attention::MM0::Mma::kTransformB, + ElementAccumulator(options.beta), + view_Ref_O_device, + view_Ref_O_device, + ElementAccumulator(0) + ); } - cutlass::device_memory::copy_to_device(block_P.get() + offset_P.at(i), matrix_Ref.data(), matrix_Ref.size()); - - // Reference GEMM - cutlass::reference::device::GemmComplex< - ElementP, LayoutP, - ElementV, LayoutV, - ElementO, LayoutO, - ElementCompute, ElementAccumulator - >( - problem1, - ElementAccumulator(options.alpha1), - view_P, - Attention::MM0::Mma::kTransformA, - view_V, - Attention::MM0::Mma::kTransformB, - ElementAccumulator(options.beta), - view_Ref_O_device, - view_Ref_O_device, - ElementAccumulator(0) - ); - // Copy to host memory - cutlass::TensorView view_Ref(matrix_Ref.data(), layout_P, extent_P); - - std::vector matrix_O(layout_O.capacity(extent_O)); - cutlass::device_memory::copy_to_host(matrix_O.data(), block_O.get() + offset_O.at(i), matrix_O.size()); std::vector matrix_Ref_O(layout_O.capacity(extent_O)); cutlass::device_memory::copy_to_host(matrix_Ref_O.data(), block_Ref_O.get(), matrix_Ref_O.size()); @@ -788,7 +805,7 @@ private: passed = passed && verified_O; if (!passed) { - std::cerr << "\n***\nError - problem " << i << " failed the QA check\n***\n" << std::endl; + std::cerr << "\n***\nError - problem " << i << " (batch " << b << ") failed the QA 
check\n***\n" << std::endl; if (!verified_O) { std::cout << "Final matrix output is incorrect" << std::endl; @@ -831,6 +848,8 @@ public: // p.cu_seqlens_k_ptr = (int32_t*)cu_seqlens_k->data_ptr(); // } + p.scale = options.alpha0; + p.num_heads = options.head_number; p.num_batches = options.batch_size; p.head_dim = options.head_size; @@ -839,18 +858,16 @@ public: p.num_keys = options.seq_length_kv; p.causal = options.causal; - // TODO: This might overflow for big tensors + // All tensors are in BMHK shapes + p.q_strideH = options.head_size; + p.k_strideH = options.head_size; + p.v_strideH = options.head_size_v; p.q_strideM = int32_t(ldq_host[0]); p.k_strideM = int32_t(ldk_host[0]); p.v_strideM = int32_t(ldv_host[0]); - p.q_strideH = p.q_strideM * options.seq_length; - p.k_strideH = p.k_strideM * options.seq_length_kv; - p.v_strideH = p.v_strideM * options.seq_length_kv; - p.o_strideH = options.head_size_v * options.seq_length; - p.q_strideB = p.q_strideH * options.head_number; - p.k_strideB = p.k_strideH * options.head_number; - p.v_strideB = p.v_strideH * options.head_number; - p.o_strideB = options.head_size_v * options.seq_length * options.head_number; + p.q_strideB = p.q_strideM * options.seq_length; + p.k_strideB = p.k_strideM * options.seq_length_kv; + p.v_strideB = p.v_strideM * options.seq_length_kv; } // launch kernel :) diff --git a/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu b/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu index 35b5c320..f3e2879f 100644 --- a/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu +++ b/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu @@ -921,6 +921,7 @@ public: ldv.get(), ldo.get(), options.causal, + options.alpha0, options.problem_sizes1.data() ); diff --git a/examples/41_fused_multi_head_attention/gemm_kernel_utils.h b/examples/41_fused_multi_head_attention/gemm_kernel_utils.h index c2f2caf7..3fe57f00 100644 --- a/examples/41_fused_multi_head_attention/gemm_kernel_utils.h +++ b/examples/41_fused_multi_head_attention/gemm_kernel_utils.h @@ -36,20 +36,20 @@ //////////////////////////////////////////////////////////////////////////////// // Some helper functions //////////////////////////////////////////////////////////////////////////////// -#define DISPATCH_TYPES(tensor, func) \ - { \ - if (query.scalar_type() == at::ScalarType::Float) { \ - using scalar_t = float; \ - func(); \ - } else if (query.scalar_type() == at::ScalarType::Half) { \ - using scalar_t = cutlass::half_t; \ - func(); \ - } else if (query.scalar_type() == at::ScalarType::BFloat16) { \ - using scalar_t = cutlass::bfloat16_t; \ - func(); \ - } else { \ - TORCH_CHECK(false, "Only fp32, half & bf16 supported at the moment"); \ - } \ +#define DISPATCH_TYPES(tensor, func) \ + { \ + if (query.scalar_type() == at::ScalarType::Float) { \ + using scalar_t = float; \ + func(); \ + } else if (query.scalar_type() == at::ScalarType::Half) { \ + using scalar_t = cutlass::half_t; \ + func(); \ + } else if (query.scalar_type() == at::ScalarType::BFloat16) { \ + using scalar_t = cutlass::bfloat16_t; \ + func(); \ + } else { \ + XFORMERS_CHECK(false, "Only fp32, half & bf16 supported at the moment"); \ + } \ } #define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \ @@ -77,26 +77,27 @@ using ArchTag = cutlass::arch::Sm50; \ func(); \ } else { \ - TORCH_CHECK( \ + XFORMERS_CHECK( \ false, \ "Your device is too old. 
We require compute capability >= 50"); \ } \ } -#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \ - TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ - TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ - TORCH_CHECK(TENSOR.is_contiguous()); +#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \ + XFORMERS_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + XFORMERS_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + XFORMERS_CHECK(TENSOR.is_contiguous()); -#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \ - TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ - TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ - TORCH_CHECK( \ +#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \ + XFORMERS_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + XFORMERS_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + XFORMERS_CHECK( \ TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous"); -#ifdef HAS_PYTORCH +#ifdef TORCH_CHECK #define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ - TORCH_CHECK(uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned") + XFORMERS_CHECK( \ + uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned") #define XFORMERS_CHECK TORCH_CHECK #elif defined(__CUDACC_RTC__) #define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ @@ -108,6 +109,7 @@ return false; \ } #else +#include #define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \ std::cerr << #PTR " is not correctly aligned\n"; \ @@ -120,74 +122,25 @@ } #endif -#define ASSIGN_CHECK_OVERFLOW(A, B) \ - { \ - A = B; \ - TORCH_CHECK( \ - B < cutlass::platform::numeric_limits::max(), \ - #B " overflows"); \ +#define ASSIGN_CHECK_OVERFLOW(A, B) \ + { \ + A = B; \ + XFORMERS_CHECK( \ + B < std::numeric_limits::max(), #B " overflows"); \ } namespace gemm_kernel_utils { -#ifdef HAS_PYTORCH -template -struct TypeTraits; - -template <> -struct TypeTraits { - using scalar_t = cutlass::half_t; - - static constexpr __host__ at::ScalarType atScalarType() { - return at::ScalarType::Half; - } - template - static __host__ at::PackedTensorAccessor32 packed_accessor( - at::Tensor const& tensor) { - return at::PackedTensorAccessor32( - (scalar_t*)(tensor.data_ptr()), - tensor.sizes().data(), - tensor.strides().data()); - } -}; - -template <> -struct TypeTraits { - using scalar_t = cutlass::bfloat16_t; - - static constexpr __host__ at::ScalarType atScalarType() { - return at::ScalarType::BFloat16; - } - template - static __host__ at::PackedTensorAccessor32 packed_accessor( - at::Tensor const& tensor) { - return at::PackedTensorAccessor32( - (scalar_t*)(tensor.data_ptr()), - tensor.sizes().data(), - tensor.strides().data()); - } -}; - -template <> -struct TypeTraits { - using scalar_t = float; - - static constexpr __host__ at::ScalarType atScalarType() { - return at::ScalarType::Float; - } - template - static __host__ at::PackedTensorAccessor32 packed_accessor( - at::Tensor const& tensor) { - return tensor.packed_accessor32(); - } -}; -#endif - template constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) { return (n + m - 1) / m; } +template +constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m) { + return ((n + m - 1) / m) * m; +} + //////////////////////////////////////////////////////////////////////////////// // Determine the type of GEMM we do (TensorCores or not, Shapes ...) 
// TODO: Maybe we could rely on Cutlass's DefaultGemm templates diff --git a/examples/41_fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h b/examples/41_fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h index 298876ea..44f38dbc 100644 --- a/examples/41_fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h +++ b/examples/41_fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h @@ -311,9 +311,9 @@ class PredicatedTileIteratorPrefetch { // on windows using unsigned long here gives the error // error: asm operand type size(4) does not match // type/size implied by constraint 'l' - uint64_t addr = (uint64_t)( - (void*)&memory_pointer - [column * ThreadMap::Delta::kColumn / kElementsPerAccess]); + uint64_t addr = (uint64_t)((void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / + kElementsPerAccess]); asm volatile("prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr)); } diff --git a/examples/41_fused_multi_head_attention/iterators/transpose_warp_iterator.h b/examples/41_fused_multi_head_attention/iterators/transpose_warp_iterator.h new file mode 100644 index 00000000..37c42ea2 --- /dev/null +++ b/examples/41_fused_multi_head_attention/iterators/transpose_warp_iterator.h @@ -0,0 +1,53 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "warp_iterator_from_smem.h" + +template +struct TransposeWarpIterator { + using Iterator = char; + static bool constexpr kSupportsTranspose = false; +}; + +template < + /// Operand identity + cutlass::gemm::Operand Operand, + /// Data type of A elements + typename Element, + bool kTranspose> +struct TransposeWarpIterator< + cutlass::gemm::warp::WarpIteratorFromSmem> { + using Iterator = + cutlass::gemm::warp::WarpIteratorFromSmem; + static bool constexpr kSupportsTranspose = true; +}; diff --git a/examples/41_fused_multi_head_attention/iterators/warp_iterator_from_smem.h b/examples/41_fused_multi_head_attention/iterators/warp_iterator_from_smem.h new file mode 100644 index 00000000..37f41699 --- /dev/null +++ b/examples/41_fused_multi_head_attention/iterators/warp_iterator_from_smem.h @@ -0,0 +1,278 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Inspired from + "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" Loads tiles of GEMM + operands from a RowMajor shared-memory layout into registers to use by A100 + TensorCores. + + The difference with "mma_tensor_op_tile_access_iterator.h" is that: + (1) We use "ldmatrix" to load tiles, rather than manual loads (slightly + faster) (2) We support to transpose the operand (eg read `A.transpose()` when + the shared memory holds `A`) + + This is only implemented for the specific shapes. 
+*/ +#pragma once + +#include + +//////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace gemm { +namespace warp { + +template < + /// Operand identity + Operand Operand_, + /// Data type of A elements + typename Element_, + bool kTranspose = false> +class WarpIteratorFromSmem { + public: + /// Shape of tile to load (concept: MatrixShape) + using Shape = cutlass::MatrixShape<32, 32>; + + /// Operand tag + static Operand const kOperand = Operand_; + + /// Basic check + static_assert( + kOperand == Operand::kA || kOperand == Operand::kB, + "WarpIteratorFromSmem may only be instantiated for A or B operands to warp-level Mma."); + + /// Element type + using Element = Element_; + static_assert(sizeof_bits::value == 16, "Only supported for half"); + + /// Layout of source tile + using Layout = cutlass::layout::RowMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = cutlass::MatrixShape<16, 8>; + + /// Delta between *MMA operations (in units of *MMA operations, concept: + /// MatrixShape) + static int const kOpDelta = 1; + + /// Number of participating threads + static int const kThreads = 32; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Number of elements accessed per Shared Memory load + static int const kElementsPerAccess = + (sizeof_bits::value >= 32 ? 1 + : 32 / sizeof_bits::value); + + using InstructionCount = MatrixShape< + Shape::kRow / InstructionShape::kRow, + Shape::kColumn / InstructionShape::kColumn>; + + static int const kIterations = (kOperand == Operand::kA) + ? InstructionCount::kColumn + : InstructionCount::kRow; + + public: + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = Array< + Element, + (kOperand == Operand::kA) + ? (Shape::kRow* InstructionShape::kColumn / kThreads) + : (Shape::kColumn* InstructionShape::kRow / kThreads)>; + + /// Memory access type + // using AccessType = AlignedArray; + using AccessType = Array; + + static int constexpr kWarpShapeDivisibleInner = + (kOperand == Operand::kA ? 
InstructionShape::kColumn + : InstructionShape::kRow); + static int constexpr kAccessesInner = + (kWarpShapeDivisibleInner / kElementsPerAccess) / 4; + static int const kTilesPerInstruction = InstructionShape::kRow / 8; + + private: + /// Underlying tensor reference + TensorRef ref_; + + /// Origin + MatrixCoord origin_; + + /// Iterations in a tile + int iterations_; + + public: + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem(TensorRef const& ref, int lane_id) + : WarpIteratorFromSmem(ref, {Shape::kRow, Shape::kColumn}, lane_id) {} + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id) + : ref_(ref), iterations_(0) { + int ldsm_vec_num = (lane_id >> 3); + if (kOperand == Operand::kA) { + origin_ = MatrixCoord(lane_id % 8, 0); + static_assert( + InstructionCount::kRow * kAccessesInner * kTilesPerInstruction == 4, + ""); + CUTLASS_PRAGMA_UNROLL + for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow; + ++inst_m_idx) { + CUTLASS_PRAGMA_UNROLL + for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { + CUTLASS_PRAGMA_UNROLL + for (int access_m_idx = 0; access_m_idx < kTilesPerInstruction; + ++access_m_idx) { + int access_idx = access_m_idx + + kTilesPerInstruction * + (inner_idx + kAccessesInner * inst_m_idx); + + MatrixCoord offset( + access_m_idx * 8 + inst_m_idx * InstructionShape::kRow, + inner_idx * 4 * kElementsPerAccess); + + if (access_idx == ldsm_vec_num) { + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + origin_ += offset; + } + } + } + } + } else { + origin_ = MatrixCoord(0, lane_id % 8); + static_assert(InstructionCount::kColumn * kAccessesInner == 4, ""); + CUTLASS_PRAGMA_UNROLL + for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn; + ++inst_n_idx) { + CUTLASS_PRAGMA_UNROLL + for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { + int access_idx = inner_idx + kAccessesInner * inst_n_idx; + + MatrixCoord offset( + inner_idx * 4 * kElementsPerAccess, inst_n_idx * 8); + + if (access_idx == ldsm_vec_num) { + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + origin_ += offset; + } + } + } + } + + ref_.add_coord_offset(origin_); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem& add_tile_offset(TensorCoord const& tile_offset) { + TensorCoord coord_offset( + tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); + if (kTranspose) { + coord_offset = TensorCoord{coord_offset.column(), coord_offset.row()}; + } + origin_ += coord_offset; + + ref_.add_coord_offset(coord_offset); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_DEVICE + void advance() { + if (kOperand == Operand::kA) { + add_tile_offset({0, 1}); + } else { + add_tile_offset({1, 0}); + } + + iterations_ = 0; + } + + /// increase iterations in a tile + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem& operator++() { + iterations_++; + + if (iterations_ >= kIterations) + advance(); + + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. 
+ CUTLASS_DEVICE + void load(Fragment& frag) const { + AccessType* access_ptr = reinterpret_cast(&frag); + using LoadLayout = typename platform:: + conditional::type; + + MatrixCoord offset; + if (kOperand == Operand::kA) { + offset = MatrixCoord(0, iterations_ * InstructionShape::kColumn); + } else { + offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0); + } + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + cutlass::arch::ldsm( + access_ptr[0], ref_.data() + ref_.offset(offset)); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/41_fused_multi_head_attention/kernel_forward.h b/examples/41_fused_multi_head_attention/kernel_forward.h index 6cb292c0..48c47edb 100644 --- a/examples/41_fused_multi_head_attention/kernel_forward.h +++ b/examples/41_fused_multi_head_attention/kernel_forward.h @@ -29,15 +29,6 @@ * **************************************************************************************************/ -#pragma once - -#ifdef HAS_PYTORCH -#include -#include -#include -#include -#endif - #include #include @@ -137,6 +128,9 @@ struct AttentionKernel { output_accum_ptr; // [num_queries, num_heads, head_dim_value] lse_scalar_t* logsumexp_ptr; // [num_heads, num_queries] - can be null + // Scale + accum_t scale; + // Dimensions/strides int32_t head_dim; int32_t head_dim_value; @@ -154,18 +148,15 @@ struct AttentionKernel { int32_t q_strideH; int32_t k_strideH; int32_t v_strideH; - int32_t o_strideH; int64_t q_strideB; int64_t k_strideB; int64_t v_strideB; - int64_t o_strideB; int32_t num_batches; int32_t num_heads; CUTLASS_HOST_DEVICE int32_t o_strideM() const { - return head_dim_value; + return head_dim_value * num_heads; } - // Moves pointers to what we should process // Returns "false" if there is no work to do CUTLASS_DEVICE bool advance_to_block() { @@ -195,9 +186,9 @@ struct AttentionKernel { query_ptr += batch_id * q_strideB; key_ptr += batch_id * k_strideB; value_ptr += batch_id * v_strideB; - output_ptr += batch_id * o_strideB; + output_ptr += int64_t(batch_id * num_queries) * o_strideM(); if (output_accum_ptr != nullptr) { - output_accum_ptr += batch_id * o_strideB; + output_accum_ptr += int64_t(batch_id * num_queries) * o_strideM(); } q_start = 0; k_start = 0; @@ -208,11 +199,11 @@ struct AttentionKernel { key_ptr += k_start * k_strideM + head_id * k_strideH; value_ptr += k_start * v_strideM + head_id * v_strideH; output_ptr += int64_t(q_start + query_start) * o_strideM() + - head_id * o_strideH; + head_id * head_dim_value; if (output_accum_ptr != nullptr) { output_accum_ptr += int64_t(q_start + query_start) * o_strideM() + - head_id * o_strideH; + head_id * head_dim_value; } else { // Accumulate directly in the destination buffer (eg for f32) output_accum_ptr = (accum_t*)output_ptr; @@ -652,7 +643,7 @@ struct AttentionKernel { warp_id(), p.num_keys - iter_key_start, iteratorC_tile_offset, - 1.0f / cutlass::fast_sqrt(float(p.head_dim))); + p.scale); })); })); @@ -858,93 +849,3 @@ __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) } AK::attention_kernel(p); } - -template -__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) - attention_kernel_batched(typename AK::Params params); - -#define _ATTENTION_KERNEL_FORWARD_BEGIN(...) 
\ - template <> \ - __global__ void __launch_bounds__( \ - __VA_ARGS__::kNumThreads, __VA_ARGS__::kMinBlocksPerSm) \ - attention_kernel_batched<__VA_ARGS__>(typename __VA_ARGS__::Params p) { \ - using Kernel = __VA_ARGS__; -#define _ATTENTION_KERNEL_FORWARD_END() } - -#ifdef __CUDA_ARCH__ -#define __CUDA_ARCH_OR_ZERO__ __CUDA_ARCH__ -#else -#define __CUDA_ARCH_OR_ZERO__ 0 -#endif - -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD( \ - ARCH, \ - SCALAR_T, \ - IS_ALIGNED, \ - QUERIES_PER_BLOCK, \ - KEYS_PER_BLOCK, \ - SINGLE_VALUE_ITER) \ - _ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel< \ - SCALAR_T, \ - cutlass::arch::Sm##ARCH, \ - IS_ALIGNED, \ - QUERIES_PER_BLOCK, \ - KEYS_PER_BLOCK, \ - SINGLE_VALUE_ITER>) \ - if (!p.advance_to_block()) { \ - return; \ - } \ - Kernel::attention_kernel(p); \ - _ATTENTION_KERNEL_FORWARD_END(); - -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED( \ - ARCH, \ - SCALAR_T, \ - IS_ALIGNED, \ - QUERIES_PER_BLOCK, \ - KEYS_PER_BLOCK, \ - SINGLE_VALUE_ITER) \ - _ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel< \ - SCALAR_T, \ - cutlass::arch::Sm##ARCH, \ - IS_ALIGNED, \ - QUERIES_PER_BLOCK, \ - KEYS_PER_BLOCK, \ - SINGLE_VALUE_ITER>) \ - printf( \ - "FATAL: this function is for sm%d, but was built for sm%d\n", \ - int(ARCH), \ - int(__CUDA_ARCH_OR_ZERO__)); \ - _ATTENTION_KERNEL_FORWARD_END(); - -// All kernels are disabled by default -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \ - INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(50, __VA_ARGS__) -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \ - INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(70, __VA_ARGS__) -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \ - INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(75, __VA_ARGS__) -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \ - INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(80, __VA_ARGS__) - -// Enable the right one based on __CUDA_ARCH__ -#ifndef __CUDA_ARCH__ -#elif __CUDA_ARCH__ < 500 -#error "Need cuda arch at least 5.0" -#elif __CUDA_ARCH__ < 700 -#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50 -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \ - INSTANTIATE_ATTENTION_KERNEL_FORWARD(50, __VA_ARGS__) -#elif __CUDA_ARCH__ < 750 -#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70 -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \ - INSTANTIATE_ATTENTION_KERNEL_FORWARD(70, __VA_ARGS__) -#elif __CUDA_ARCH__ < 800 -#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75 -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \ - INSTANTIATE_ATTENTION_KERNEL_FORWARD(75, __VA_ARGS__) -#elif __CUDA_ARCH__ >= 800 -#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80 -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) 
\ - INSTANTIATE_ATTENTION_KERNEL_FORWARD(80, __VA_ARGS__) -#endif diff --git a/examples/41_fused_multi_head_attention/mma_from_smem.h b/examples/41_fused_multi_head_attention/mma_from_smem.h index 21ac4d10..d2ceaf02 100644 --- a/examples/41_fused_multi_head_attention/mma_from_smem.h +++ b/examples/41_fused_multi_head_attention/mma_from_smem.h @@ -57,8 +57,8 @@ #include "epilogue_thread_apply_logsumexp.h" #include "gemm_kernel_utils.h" #include "iterators/make_residual_last.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// +#include "iterators/transpose_warp_iterator.h" +#include "iterators/warp_iterator_from_smem.h" namespace cutlass { namespace gemm { @@ -560,20 +560,14 @@ template < typename Policy1_, /// Number of stages, int Stages_, + int kMaxK_, /// Used for partial specialization typename Enable = bool> -class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory< - Shape1_, - AccumulatorSharedStorage::Shape::kN, - Policy1_, - Stages_> { +class MmaMultistageFromSharedMemory + : public MmaBaseFromSharedMemory { public: ///< Base class - using Base = MmaBaseFromSharedMemory< - Shape1_, - AccumulatorSharedStorage::Shape::kN, - Policy1_, - Stages_>; + using Base = MmaBaseFromSharedMemory; ///< Size of the Gemm problem - concept: gemm::GemmShape<> using Shape1 = Shape1_; @@ -1035,16 +1029,39 @@ template < typename WarpShape, typename InstructionShape, typename RegularWarpIterator, - typename Policy> + typename Policy, + typename Enable = void> struct DefaultWarpIteratorAFromSharedMemory {}; -// TensorOp - Ampere +// TensorOp - Ampere half +template +struct DefaultWarpIteratorAFromSharedMemory< + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + RegularWarpIterator, + Policy, + typename platform::enable_if<( + sizeof_bits::value == 16 && + Policy::Operator::Policy::OpDelta::kRow == 1)>::type> { + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + using WarpShape = cutlass::MatrixShape<32, 32>; + + using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem< + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element>; +}; + +// TensorOp - Ampere f32 template struct DefaultWarpIteratorAFromSharedMemory< WarpShape, cutlass::gemm::GemmShape<16, 8, 8>, RegularWarpIterator, - Policy> { + Policy, + typename platform::enable_if<( + sizeof_bits::value != 16 || + Policy::Operator::Policy::OpDelta::kRow != 1)>::type> { using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; static constexpr auto kWarpSize = 32; using OpDelta = typename Policy::Operator::Policy::OpDelta; @@ -1099,7 +1116,10 @@ struct DefaultWarpIteratorAFromSharedMemory< }; // Converts a "regular" Mma into their counterpart from shared memory -template +template < + typename Mma_, + typename AccumulatorSharedStorage, + bool kTransposeA = false> struct DefaultMmaFromSharedMemory; // Mma pipelined @@ -1130,7 +1150,8 @@ template < typename TransformA_, /// Transformation applied to B operand typename TransformB_, - typename AccumulatorSharedStorage_> + typename AccumulatorSharedStorage_, + bool kTransposeA> struct DefaultMmaFromSharedMemory< MmaPipelined< Shape_, @@ -1143,7 +1164,8 @@ struct DefaultMmaFromSharedMemory< Policy_, TransformA_, TransformB_>, - AccumulatorSharedStorage_> { + AccumulatorSharedStorage_, + kTransposeA> { static constexpr int kWarpSize = 32; using SmemAccumulatorLayout = cutlass::layout::RowMajor; @@ -1163,6 +1185,7 @@ struct DefaultMmaFromSharedMemory< 
using InstructionShape = typename Policy_::Operator::InstructionShape; using ArchMmaOperator = typename Policy_::Operator; + static constexpr bool kIsTransposedA = false; using WarpIteratorA = typename DefaultWarpIteratorAFromSharedMemory< WarpShape, InstructionShape, @@ -1214,7 +1237,8 @@ template < int Stages, /// Use zfill or predicate for out-of-bound cp.async SharedMemoryClearOption SharedMemoryClear, - typename AccumulatorSharedStorage_> + typename AccumulatorSharedStorage_, + bool kTransposeA> struct DefaultMmaFromSharedMemory< MmaMultistage< Shape_, @@ -1229,7 +1253,8 @@ struct DefaultMmaFromSharedMemory< Policy_, Stages, SharedMemoryClear>, - AccumulatorSharedStorage_> { + AccumulatorSharedStorage_, + kTransposeA> { static constexpr int kWarpSize = 32; using RegularMma = MmaMultistage< @@ -1248,13 +1273,22 @@ struct DefaultMmaFromSharedMemory< using WarpShape = typename Policy_::Operator::Shape; using InstructionShape = typename Policy_::Operator::InstructionShape; - using WarpIteratorA = typename DefaultWarpIteratorAFromSharedMemory< + using WarpIteratorA_ = typename DefaultWarpIteratorAFromSharedMemory< WarpShape, InstructionShape, typename RegularMma::Operator::IteratorA, Policy_>::WarpIterator; + using WarpIteratorTranspose = TransposeWarpIterator; + static constexpr bool kIsTransposedA = + WarpIteratorTranspose::kSupportsTranspose && kTransposeA; + using WarpIteratorA = typename platform::conditional< + kIsTransposedA, + typename WarpIteratorTranspose::Iterator, + WarpIteratorA_>::type; - static int constexpr kMaxK = AccumulatorSharedStorage_::Shape::kN; + static int constexpr kMaxK = kIsTransposedA + ? AccumulatorSharedStorage_::Shape::kM + : AccumulatorSharedStorage_::Shape::kN; // Reduce the number of stages if we don't need that many static int constexpr kStagesMax = (kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK); @@ -1274,7 +1308,8 @@ struct DefaultMmaFromSharedMemory< ElementC_, LayoutC_, Policy_, - kStages>; + kStages, + kMaxK>; }; /////////////////////////////////////////////////////////////////////////////////////////////////
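Note added for clarity (not part of the patch): this diff threads an explicit softmax scale through both kernels — Params::scale replaces the hard-coded 1.0f / sqrt(head_dim), and the examples forward options.alpha0 — and it switches the fixed-seqlen example to BMHK addressing, where the per-head stride is the head size and the per-row stride spans all heads. The C++ sketch below shows how those strides fit together on the host side; the names num_heads, head_size, head_size_v, and seq_q are illustrative placeholders, not identifiers from the patch.

#include <cmath>
#include <cstdint>

// Hedged sketch of the BMHK stride setup introduced above
// (mirrors fused_multihead_attention_fixed_seqlen.cu and kernel_forward.h).
struct StridesBMHK {
  int32_t strideH;  // step to the next head at fixed (batch, row)
  int32_t strideM;  // step to the next row, spanning every head
  int64_t strideB;  // step to the next batch
};

inline StridesBMHK make_query_strides(int32_t num_heads, int32_t head_size, int32_t seq_q) {
  StridesBMHK s;
  s.strideH = head_size;                   // p.q_strideH = options.head_size
  s.strideM = num_heads * head_size;       // p.q_strideM = ldq of the packed BMHK row
  s.strideB = int64_t(s.strideM) * seq_q;  // p.q_strideB = q_strideM * seq_length
  return s;
}

// The output row stride becomes head_dim_value * num_heads (Params::o_strideM()),
// so no separate per-batch O stride is kept: the kernel advances the output pointer
// by batch_id * num_queries * o_strideM(), plus head_id * head_dim_value per head.
inline float default_softmax_scale(int32_t head_size) {
  // The value the examples typically pass as options.alpha0 / Params::scale,
  // matching the 1/sqrt(head_dim) factor the kernels previously computed themselves.
  return 1.0f / std::sqrt(float(head_size));
}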