added mapping for bf16 to torch::kBFloat16 (#1843)

Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
This commit is contained in:
Bogumil Sapinski Mobica 2024-10-23 18:48:31 +02:00 committed by GitHub
parent b0c09ed077
commit 83ae20c740
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -255,6 +255,7 @@ _CUTLASS_TYPE_TO_TORCH_TYPE = {
DataType.f64: "torch::kF64",
DataType.s8: "torch::kI8",
DataType.s32: "torch::kI32",
DataType.bf16: "torch::kBFloat16",
}
_PYTORCH_GEMM_IMPL_TEMPLATE_2x = (