cutlass/examples/42_fused_multi_head_attention/gemm/custom_mma.h
dan_the_3rd 4db6a6140e
ex42: Fused MHA imported from xFormers (#662)
* ex42: Fused MHA imported from xFormers

* Remove std:: references

* Support K>128 in the example

* Support causal option

* Support different head size for V, and different seqlength for KV

* Update FLOPS counter

* Remove bit_cast

* fix build: Replace M_LOG2E

* Add doc

* Revert "Remove bit_cast"

This reverts commit 9662fa86bb7c57c1a015ac0bf52cb52940fbbf80.

* Explicit casts to int32_t for Windows build

Co-authored-by: danthe3rd <danthe3rd>
2022-10-17 10:49:33 -04:00

94 lines
2.1 KiB
C++

#pragma once
#include "custom_mma_multistage.h"
#include "custom_mma_pipelined.h"
#include "cutlass/gemm/threadblock/mma_multistage.h"
#include "cutlass/gemm/threadblock/mma_pipelined.h"
/// Type-level adapter: `MakeCustomMma<DefaultMma, kMaxK>::Mma` names the custom
/// counterpart of a stock CUTLASS threadblock-level Mma type.
///
/// The primary template is intentionally declared but never defined. Only the
/// partial specializations in this header provide the nested `Mma` alias, so
/// asking for an unsupported Mma type is a compile-time error (incomplete type).
template <typename DefaultMma, int kMaxK>
struct MakeCustomMma;
template <
typename Shape,
typename IteratorA,
typename SmemIteratorA,
cutlass::arch::CacheOperation::Kind CacheOpA,
typename IteratorB,
typename SmemIteratorB,
cutlass::arch::CacheOperation::Kind CacheOpB,
typename ElementC,
typename LayoutC,
typename Policy,
int Stages,
cutlass::gemm::SharedMemoryClearOption SharedMemoryClear,
int kMaxK>
struct MakeCustomMma<
cutlass::gemm::threadblock::MmaMultistage<
Shape,
IteratorA,
SmemIteratorA,
CacheOpA,
IteratorB,
SmemIteratorB,
CacheOpB,
ElementC,
LayoutC,
Policy,
Stages,
SharedMemoryClear>,
kMaxK> {
// Reduce the number of stages if we don't need that many
static int constexpr kStages =
kMaxK == cutlass::platform::numeric_limits<int>::max()
? Stages
: cutlass::const_min(
Stages,
(kMaxK + int(Shape::kK) - 1) / int(Shape::kK));
using Mma = cutlass::gemm::threadblock::CustomMmaMultistage<
Shape,
IteratorA,
SmemIteratorA,
CacheOpA,
IteratorB,
SmemIteratorB,
CacheOpB,
ElementC,
LayoutC,
Policy,
kStages,
SharedMemoryClear,
kMaxK>;
};
/// Specialization for CUTLASS `MmaPipelined` mainloops.
///
/// Exposes `Mma` as a `CustomMmaPipelined` with exactly the template arguments
/// of the incoming `MmaPipelined`. `kMaxK_` is part of the primary template's
/// signature, but it is not forwarded: CustomMmaPipelined is instantiated
/// here without it.
template <
    typename Shape_,
    typename IteratorA_,
    typename SmemIteratorA_,
    typename IteratorB_,
    typename SmemIteratorB_,
    typename ElementC_,
    typename LayoutC_,
    typename Policy_,
    int kMaxK_>
struct MakeCustomMma<
    cutlass::gemm::threadblock::MmaPipelined<
        Shape_, IteratorA_, SmemIteratorA_, IteratorB_, SmemIteratorB_,
        ElementC_, LayoutC_, Policy_>,
    kMaxK_> {
  using Mma = cutlass::gemm::threadblock::CustomMmaPipelined<
      Shape_, IteratorA_, SmemIteratorA_, IteratorB_, SmemIteratorB_,
      ElementC_, LayoutC_, Policy_>;
};