diff --git a/python/cutlass/emit/pytorch.py b/python/cutlass/emit/pytorch.py index f7bdfd74..8c10f87a 100644 --- a/python/cutlass/emit/pytorch.py +++ b/python/cutlass/emit/pytorch.py @@ -255,6 +255,7 @@ _CUTLASS_TYPE_TO_TORCH_TYPE = { DataType.f64: "torch::kF64", DataType.s8: "torch::kI8", DataType.s32: "torch::kI32", + DataType.bf16: "torch::kBFloat16", } _PYTORCH_GEMM_IMPL_TEMPLATE_2x = (