[Bugfix] Fix the CUDA version check for FP8 support in the CUTLASS kernels (#5715)

Tyler Michael Smith, 2024-06-20 14:36:10 -04:00 (committed by GitHub)
parent a7dcc62086
commit 3f3b6b2150
5 changed files with 29 additions and 12 deletions

csrc/ops.h

@@ -92,6 +92,8 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                  int64_t size_k, int64_t size_n,
                                  int64_t num_bits);
 
+bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
+
 void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                        torch::Tensor const& b, torch::Tensor const& a_scales,
                        torch::Tensor const& b_scales);

csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu

@@ -25,6 +25,22 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b_scales);
 #endif
 
+bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
+  // CUTLASS FP8 kernels need at least
+  //   CUDA 12.0 on SM90 systems (Hopper)
+  //   CUDA 12.4 on SM89 systems (Lovelace)
+
+#if defined CUDA_VERSION
+  if (cuda_device_capability >= 90) {
+    return CUDA_VERSION >= 12000;
+  } else if (cuda_device_capability >= 89) {
+    return CUDA_VERSION >= 12040;
+  }
+#endif
+
+  return false;
+}
+
 void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                        torch::Tensor const& b, torch::Tensor const& a_scales,
                        torch::Tensor const& b_scales) {
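
For reference, the CUDA_VERSION macro from cuda.h encodes the toolkit version as major * 1000 + minor * 10, so 12000 is CUDA 12.0 and 12040 is CUDA 12.4; the check above compares against this compile-time value rather than any runtime-detected version. Below is a minimal Python sketch of the same decision table; the helper name and the integer cuda_version parameter are illustrative assumptions, not part of the patch.

# Mirrors the C++ logic above; cuda_version uses the CUDA_VERSION
# encoding from cuda.h (major * 1000 + minor * 10).
def supports_fp8(device_capability: int, cuda_version: int) -> bool:
    if device_capability >= 90:    # SM90 (Hopper) needs CUDA >= 12.0
        return cuda_version >= 12000
    if device_capability >= 89:    # SM89 (Lovelace) needs CUDA >= 12.4
        return cuda_version >= 12040
    return False                   # older GPUs: no CUTLASS FP8 support

assert supports_fp8(90, 12000)       # Hopper on CUDA 12.0: supported
assert not supports_fp8(89, 12010)   # Lovelace on CUDA 12.1: not supported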

csrc/torch_bindings.cpp

@@ -144,6 +144,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "  Tensor b, Tensor a_scales,"
       "  Tensor b_scales) -> ()");
   ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
 
+  // Check if cutlass scaled_mm is supported for CUDA devices of the given
+  // capability
+  ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
+  ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA,
+           &cutlass_scaled_mm_supports_fp8);
+
 #endif
 
   // Quantized GEMM for GPTQ.
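
Once registered, the op is reachable from Python through the torch.ops namespace (the wrapper in vllm/_custom_ops.py below calls it as torch.ops._C.cutlass_scaled_mm_supports_fp8). A quick smoke test, assuming the compiled extension is importable as vllm._C:

import torch
import vllm._C  # noqa: F401 -- loads the extension and registers the ops

# 89 is the packed capability for SM89 (Lovelace); this returns True only
# when the extension was compiled against CUDA >= 12.4.
print(torch.ops._C.cutlass_scaled_mm_supports_fp8(89))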

vllm/_custom_ops.py

@@ -216,6 +216,10 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 
 
 # cutlass
+def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
+    return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
+
+
 def cutlass_scaled_mm(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: Type[torch.dtype]) -> torch.Tensor:
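
A caller combines this wrapper with the device's compute capability, packed as major * 10 + minor, to gate FP8 code paths; this is essentially what the reworked cutlass_fp8_supported below reduces to. A small usage sketch (the fp8_gemm_supported name is illustrative):

import torch
from vllm import _custom_ops as ops

def fp8_gemm_supported() -> bool:
    # Pack (major, minor) into two digits, e.g. (8, 9) -> 89, (9, 0) -> 90.
    major, minor = torch.cuda.get_device_capability()
    return ops.cutlass_scaled_mm_supports_fp8(major * 10 + minor)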

vllm/model_executor/layers/quantization/fp8.py

@@ -20,19 +20,8 @@ logger = init_logger(__name__)
 def cutlass_fp8_supported() -> bool:
     capability = torch.cuda.get_device_capability()
     capability = capability[0] * 10 + capability[1]
-    major, minor = torch.version.cuda.split(".")
-    version = int(major) * 10 + int(minor)
 
-    # CUTLASS FP8 kernels need at least
-    #   CUDA 12.0 on SM90 systems (Hopper)
-    #   CUDA 12.4 on SM89 systems (Lovelace)
-    gpu_is_supported = False
-    if capability >= 90:
-        gpu_is_supported = version > 120
-    elif capability >= 89:
-        gpu_is_supported = version > 124
-
-    return gpu_is_supported
+    return ops.cutlass_scaled_mm_supports_fp8(capability)
 
 
 class Fp8Config(QuantizationConfig):
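
For context, the deleted Python check had two flaws that this commit addresses: it read torch.version.cuda, which reports the toolkit PyTorch was built with rather than the CUDA_VERSION the CUTLASS kernels were compiled against, and its strict > comparisons rejected the exact minimum versions. A short illustration of the off-by-one, using the old major * 10 + minor packing:

# Old encoding: "12.0" -> 12 * 10 + 0 = 120, "12.4" -> 124.
major, minor = "12.0".split(".")
version = int(major) * 10 + int(minor)
assert version == 120
assert not version > 120  # SM90 + CUDA 12.0 was wrongly rejected ('>' not '>=')
# The replacement defers to the compile-time CUDA_VERSION >= 12000 check,
# which accepts CUDA 12.0 on SM90 (and CUDA 12.4 on SM89 via >= 12040).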