diff --git a/python/cutlass/utils/datatypes.py b/python/cutlass/utils/datatypes.py index 2b9eba76..cc84aa0b 100644 --- a/python/cutlass/utils/datatypes.py +++ b/python/cutlass/utils/datatypes.py @@ -118,6 +118,7 @@ try: _torch_to_library_dict = { torch.half: cutlass.DataType.f16, torch.float16: cutlass.DataType.f16, + torch.bfloat16: cutlass.DataType.bf16, torch.float: cutlass.DataType.f32, torch.float32: cutlass.DataType.f32, torch.double: cutlass.DataType.f64, @@ -127,6 +128,7 @@ try: _library_to_torch_dict = { cutlass.DataType.f16: torch.half, cutlass.DataType.f16: torch.float16, + cutlass.DataType.bf16: torch.bfloat16, cutlass.DataType.f32: torch.float, cutlass.DataType.f32: torch.float32, cutlass.DataType.f64: torch.double,