108 lines
4.1 KiB
Plaintext
108 lines
4.1 KiB
Plaintext
#include <cudaTypedefs.h>
|
|
|
|
#include <c10/cuda/CUDAGuard.h>
|
|
#include <torch/all.h>
|
|
|
|
// Per-architecture CUTLASS scaled-mm entry points. Only declarations appear in
// this file; the definitions live elsewhere (presumably one translation unit
// per architecture). Each one takes the same arguments as cutlass_scaled_mm and
// is chosen by that function's compute-capability dispatch.

// Turing: used when the device capability is in [7.5, 8.0).
void cutlass_scaled_mm_sm75(torch::Tensor& c,
                            const torch::Tensor& a,
                            const torch::Tensor& b,
                            const torch::Tensor& a_scales,
                            const torch::Tensor& b_scales,
                            const c10::optional<torch::Tensor>& bias);

// Ampere: used for capability 8.x other than 8.9. It is also the Hopper
// fallback when the toolkit is older than CUDA 12.0.
void cutlass_scaled_mm_sm80(torch::Tensor& c,
                            const torch::Tensor& a,
                            const torch::Tensor& b,
                            const torch::Tensor& a_scales,
                            const torch::Tensor& b_scales,
                            const c10::optional<torch::Tensor>& bias);

// Ada Lovelace: used only for capability exactly 8.9.
void cutlass_scaled_mm_sm89(torch::Tensor& c,
                            const torch::Tensor& a,
                            const torch::Tensor& b,
                            const torch::Tensor& a_scales,
                            const torch::Tensor& b_scales,
                            const c10::optional<torch::Tensor>& bias);

// Hopper: only declared (and therefore only callable) when building with
// CUDA 12.0 or newer.
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
void cutlass_scaled_mm_sm90(torch::Tensor& c,
                            const torch::Tensor& a,
                            const torch::Tensor& b,
                            const torch::Tensor& a_scales,
                            const torch::Tensor& b_scales,
                            const c10::optional<torch::Tensor>& bias);
#endif
|
|
|
|
// Reports whether the CUTLASS FP8 scaled-mm kernels should be used on a device
// with the given compute capability. The capability is encoded as
// major * 10 + minor, e.g. 89 for SM 8.9 and 90 for SM 9.0; this is the same
// encoding cutlass_scaled_mm builds.
//
// Minimum toolkit requirements for the CUTLASS FP8 kernels:
//   - SM90 (Hopper):       CUDA 12.0
//   - SM89 (Ada Lovelace): CUDA 12.4
// If CUDA_VERSION is not defined at compile time, this always returns false.
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
#if defined CUDA_VERSION
  bool const is_hopper_or_newer = cuda_device_capability >= 90;
  if (is_hopper_or_newer) {
    return CUDA_VERSION >= 12000;
  }
  // SM89 deliberately falls through to `false`, as does everything older.
  // The CUTLASS kernels have not been tuned for Ada Lovelace and are slower
  // than torch.mm there. Once they are optimized, SM89 should instead report
  // support when CUDA_VERSION >= 12040.
#endif
  return false;
}
|
|
|
|
// Returns the compute capability of CUDA device `device`, encoded as
// major * 10 + minor (e.g. 89 for SM 8.9, 90 for SM 9.0). Fails via TORCH_CHECK
// if either attribute query does not succeed, rather than continuing with
// uninitialized values.
static int32_t get_device_sm_version_num(int device) {
  int major_capability = 0;
  int minor_capability = 0;

  cudaError_t err = cudaDeviceGetAttribute(
      &major_capability, cudaDevAttrComputeCapabilityMajor, device);
  TORCH_CHECK(err == cudaSuccess,
              "cudaDeviceGetAttribute(ComputeCapabilityMajor) failed for "
              "device ",
              device, ": ", cudaGetErrorString(err));

  err = cudaDeviceGetAttribute(&minor_capability,
                               cudaDevAttrComputeCapabilityMinor, device);
  TORCH_CHECK(err == cudaSuccess,
              "cudaDeviceGetAttribute(ComputeCapabilityMinor) failed for "
              "device ",
              device, ": ", cudaGetErrorString(err));

  return major_capability * 10 + minor_capability;
}

// Validates the operands of a CUTLASS scaled matrix multiply and dispatches to
// the kernel built for the architecture of the device that holds `a`.
//
// Operand requirements enforced below:
//   a        : 2-D [M, K], row-major (a.stride(1) == 1)
//   b        : 2-D [K, N], column-major (b.stride(0) == 1)
//   c        : 2-D [M, N] output, row-major (c.stride(1) == 1)
//   a_scales : contiguous, 1 element or M elements
//   b_scales : contiguous, 1 element or N elements
//   bias     : optional, contiguous 1-D with N elements
// The exact epilogue is defined by the per-arch kernels. It is presumably
// c = (a_scales * b_scales) * (a @ b) + bias; confirm against those kernels.
//
// Dispatch, by compute capability (major * 10 + minor):
//   >= 90 -> sm90 when built with CUDA >= 12.0, otherwise sm80
//   == 89 -> sm89
//   >= 80 -> sm80
//   >= 75 -> sm75; anything lower is rejected.
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
                       torch::Tensor const& b_scales,
                       c10::optional<torch::Tensor> const& bias) {
  // Checks for conformality: c[M, N] = a[M, K] @ b[K, N].
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2,
              "cutlass_scaled_mm: a, b and c must be 2-D, got a.dim()=",
              a.dim(), ", b.dim()=", b.dim(), ", c.dim()=", c.dim());
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
                  b.size(1) == c.size(1),
              "cutlass_scaled_mm: shape mismatch, a ", a.sizes(), ", b ",
              b.sizes(), ", c ", c.sizes());
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0),
              "cutlass_scaled_mm: a_scales must have 1 or a.size(0) elements, "
              "got ",
              a_scales.numel());
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1),
              "cutlass_scaled_mm: b_scales must have 1 or b.size(1) elements, "
              "got ",
              b_scales.numel());

  // Check for strides and alignment.
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1,
              "cutlass_scaled_mm: a and c must be row-major");
  TORCH_CHECK(b.stride(0) == 1, "cutlass_scaled_mm: b must be column-major");
  // Intended as 16-byte alignment of the leading dimensions. stride() counts
  // elements, not bytes, so this equals 16 bytes only for 1-byte element types.
  // NOTE(review): for wider element types (e.g. a 16-bit c) this check is
  // stricter than 16 bytes. Kept unchanged; confirm intent.
  TORCH_CHECK(c.stride(0) % 16 == 0 && b.stride(1) % 16 == 0,
              "cutlass_scaled_mm: c.stride(0) and b.stride(1) must be "
              "multiples of 16, got ",
              c.stride(0), " and ", b.stride(1));
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous(),
              "cutlass_scaled_mm: a_scales and b_scales must be contiguous");

  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
                    bias->dim() == 1,
                "cutlass_scaled_mm: bias must be a contiguous 1-D tensor with "
                "b.size(1) elements");
  }

  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));

  // Query the architecture of the device that actually holds `a`, not device
  // 0, so dispatch is correct on multi-GPU hosts with mixed architectures.
  int32_t const version_num =
      get_device_sm_version_num(static_cast<int>(a.get_device()));

  if (version_num >= 90) {
    // Hopper

    // Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
    cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
#else
    cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
#endif
  } else if (version_num == 89) {
    // Ada Lovelace
    cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
  } else if (version_num >= 80) {
    // Ampere
    cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
  } else {
    // Turing
    TORCH_CHECK(version_num >= 75,
                "cutlass_scaled_mm requires compute capability >= 7.5, got ",
                version_num);
    cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
  }
}
|