From 2d9a55742703311061558f005683aa36c5c3be84 Mon Sep 17 00:00:00 2001 From: Sophia Wisdom Date: Wed, 16 Aug 2023 08:38:53 -0700 Subject: [PATCH] torch.bfloat16 support in cutlass python (#1037) * torch.bfloat16 support in cutlass python * Update datatypes.py --- python/cutlass/utils/datatypes.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/cutlass/utils/datatypes.py b/python/cutlass/utils/datatypes.py index 2b9eba76..cc84aa0b 100644 --- a/python/cutlass/utils/datatypes.py +++ b/python/cutlass/utils/datatypes.py @@ -118,6 +118,7 @@ try: _torch_to_library_dict = { torch.half: cutlass.DataType.f16, torch.float16: cutlass.DataType.f16, + torch.bfloat16: cutlass.DataType.bf16, torch.float: cutlass.DataType.f32, torch.float32: cutlass.DataType.f32, torch.double: cutlass.DataType.f64, @@ -127,6 +128,7 @@ try: _library_to_torch_dict = { cutlass.DataType.f16: torch.half, cutlass.DataType.f16: torch.float16, + cutlass.DataType.bf16: torch.bfloat16, cutlass.DataType.f32: torch.float, cutlass.DataType.f32: torch.float32, cutlass.DataType.f64: torch.double,