[Kernel] Add GPU architecture guards to the CUTLASS w8a8 kernels to reduce binary size (#5157)
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
parent
02cc3b51a7
commit
ccd4f129e8
@ -48,9 +48,44 @@ using namespace cute;
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
template <typename Arch, typename ElementAB_, typename ElementD_,
|
// Wrappers for the GEMM kernel that is used to guard against compilation on
|
||||||
typename TileShape, typename WarpShape, typename InstructionShape,
|
// architectures that will never use the kernel. The purpose of this is to
|
||||||
int32_t MainLoopStages>
|
// reduce the size of the compiled binary.
|
||||||
|
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
|
||||||
|
// into code that will be executed on the device where it is defined.
|
||||||
|
template <typename Kernel>
|
||||||
|
struct enable_sm75_to_sm80 : Kernel {
|
||||||
|
template <typename... Args>
|
||||||
|
CUTLASS_DEVICE static void invoke(Args&&... args) {
|
||||||
|
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
|
||||||
|
Kernel::invoke(std::forward<Args>(args)...);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Kernel>
|
||||||
|
struct enable_sm80_to_sm89 : Kernel {
|
||||||
|
template <typename... Args>
|
||||||
|
CUTLASS_DEVICE static void invoke(Args&&... args) {
|
||||||
|
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
|
||||||
|
Kernel::invoke(std::forward<Args>(args)...);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Kernel>
|
||||||
|
struct enable_sm89_to_sm90 : Kernel {
|
||||||
|
template <typename... Args>
|
||||||
|
CUTLASS_DEVICE static void invoke(Args&&... args) {
|
||||||
|
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
|
||||||
|
Kernel::invoke(std::forward<Args>(args)...);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Arch, template <typename> typename ArchGuard,
|
||||||
|
typename ElementAB_, typename ElementD_, typename TileShape,
|
||||||
|
typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
|
||||||
struct cutlass_2x_gemm {
|
struct cutlass_2x_gemm {
|
||||||
using ElementAB = ElementAB_;
|
using ElementAB = ElementAB_;
|
||||||
using ElementD = ElementD_;
|
using ElementD = ElementD_;
|
||||||
@ -101,7 +136,7 @@ struct cutlass_2x_gemm {
|
|||||||
using RowMajor = typename cutlass::layout::RowMajor;
|
using RowMajor = typename cutlass::layout::RowMajor;
|
||||||
using ColumnMajor = typename cutlass::layout::ColumnMajor;
|
using ColumnMajor = typename cutlass::layout::ColumnMajor;
|
||||||
using KernelType =
|
using KernelType =
|
||||||
typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
|
ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
|
||||||
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
|
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
|
||||||
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
|
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
|
||||||
float, cutlass::layout::RowMajor, 4,
|
float, cutlass::layout::RowMajor, 4,
|
||||||
@ -112,7 +147,7 @@ struct cutlass_2x_gemm {
|
|||||||
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
|
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
|
||||||
MainLoopStages, Operator,
|
MainLoopStages, Operator,
|
||||||
1 /* epilogue stages */
|
1 /* epilogue stages */
|
||||||
>::GemmKernel;
|
>::GemmKernel>;
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
|
using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
|
||||||
@ -208,16 +243,16 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
|
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
|
||||||
|
|
||||||
if (out.dtype() == torch::kBFloat16) {
|
if (out.dtype() == torch::kBFloat16) {
|
||||||
return cutlass_scaled_mm_dq_dispatcher<
|
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
|
||||||
cutlass_2x_gemm<cutlass::arch::Sm75, int8_t, cutlass::bfloat16_t,
|
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
|
||||||
TileShape, WarpShape, InstructionShape, 2>>(
|
TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
|
||||||
out, a, b, a_scales, b_scales);
|
b_scales);
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||||
return cutlass_scaled_mm_dq_dispatcher<
|
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
|
||||||
cutlass_2x_gemm<cutlass::arch::Sm75, int8_t, cutlass::half_t, TileShape,
|
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
|
||||||
WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
|
TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
|
||||||
b_scales);
|
b_scales);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -235,16 +270,16 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||||
|
|
||||||
if (out.dtype() == torch::kBFloat16) {
|
if (out.dtype() == torch::kBFloat16) {
|
||||||
return cutlass_scaled_mm_dq_dispatcher<
|
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
|
||||||
cutlass_2x_gemm<cutlass::arch::Sm80, int8_t, cutlass::bfloat16_t,
|
cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::bfloat16_t,
|
||||||
TileShape, WarpShape, InstructionShape, 5>>(
|
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
|
||||||
out, a, b, a_scales, b_scales);
|
b_scales);
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||||
return cutlass_scaled_mm_dq_dispatcher<
|
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
|
||||||
cutlass_2x_gemm<cutlass::arch::Sm80, int8_t, cutlass::half_t, TileShape,
|
cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::half_t,
|
||||||
WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
|
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
|
||||||
b_scales);
|
b_scales);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -263,16 +298,16 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||||
|
|
||||||
if (out.dtype() == torch::kBFloat16) {
|
if (out.dtype() == torch::kBFloat16) {
|
||||||
return cutlass_scaled_mm_dq_dispatcher<
|
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
|
||||||
cutlass_2x_gemm<cutlass::arch::Sm89, int8_t, cutlass::bfloat16_t,
|
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t,
|
||||||
TileShape, WarpShape, InstructionShape, 5>>(
|
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
|
||||||
out, a, b, a_scales, b_scales);
|
b_scales);
|
||||||
} else {
|
} else {
|
||||||
assert(out.dtype() == torch::kFloat16);
|
assert(out.dtype() == torch::kFloat16);
|
||||||
return cutlass_scaled_mm_dq_dispatcher<
|
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
|
||||||
cutlass_2x_gemm<cutlass::arch::Sm89, int8_t, cutlass::half_t,
|
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t,
|
||||||
TileShape, WarpShape, InstructionShape, 5>>(
|
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
|
||||||
out, a, b, a_scales, b_scales);
|
b_scales);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||||
@ -280,15 +315,15 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
|
|
||||||
if (out.dtype() == torch::kBFloat16) {
|
if (out.dtype() == torch::kBFloat16) {
|
||||||
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
|
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
|
||||||
cutlass::arch::Sm89, cutlass::float_e4m3_t, cutlass::bfloat16_t,
|
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
|
||||||
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
|
cutlass::bfloat16_t, TileShape, WarpShape, InstructionShape, 5>>(
|
||||||
b_scales);
|
out, a, b, a_scales, b_scales);
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||||
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
|
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
|
||||||
cutlass::arch::Sm89, cutlass::float_e4m3_t, cutlass::half_t,
|
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
|
||||||
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
|
cutlass::half_t, TileShape, WarpShape, InstructionShape, 5>>(
|
||||||
b_scales);
|
out, a, b, a_scales, b_scales);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -56,6 +56,21 @@ uint32_t next_pow_2(uint32_t const num) {
|
|||||||
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
|
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A wrapper for the GEMM kernel that is used to guard against compilation on
|
||||||
|
// architectures that will never use the kernel. The purpose of this is to
|
||||||
|
// reduce the size of the compiled binary.
|
||||||
|
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
|
||||||
|
// into code that will be executed on the device where it is defined.
|
||||||
|
template <typename Kernel>
|
||||||
|
struct enable_sm90_or_later : Kernel {
|
||||||
|
template <typename... Args>
|
||||||
|
CUTLASS_DEVICE void operator()(Args&&... args) {
|
||||||
|
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
|
||||||
|
Kernel::operator()(std::forward<Args>(args)...);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename ElementAB_, typename ElementD_, typename TileShape,
|
template <typename ElementAB_, typename ElementD_, typename TileShape,
|
||||||
typename ClusterShape, typename KernelSchedule,
|
typename ClusterShape, typename KernelSchedule,
|
||||||
typename EpilogueSchedule>
|
typename EpilogueSchedule>
|
||||||
@ -126,9 +141,9 @@ struct cutlass_3x_gemm {
|
|||||||
KernelSchedule>::CollectiveOp;
|
KernelSchedule>::CollectiveOp;
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
using KernelType = cutlass::gemm::kernel::GemmUniversal<
|
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
|
||||||
cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
|
cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
|
||||||
cutlass::gemm::PersistentScheduler>;
|
cutlass::gemm::PersistentScheduler>>;
|
||||||
|
|
||||||
struct GemmKernel : public KernelType {};
|
struct GemmKernel : public KernelType {};
|
||||||
};
|
};
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user