torch.bfloat16 support in cutlass python (#1037)

* torch.bfloat16 support in cutlass python

* Update datatypes.py
This commit is contained in:
Sophia Wisdom 2023-08-16 08:38:53 -07:00 committed by GitHub
parent 4575443d44
commit 2d9a557427
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -118,6 +118,7 @@ try:
_torch_to_library_dict = { _torch_to_library_dict = {
torch.half: cutlass.DataType.f16, torch.half: cutlass.DataType.f16,
torch.float16: cutlass.DataType.f16, torch.float16: cutlass.DataType.f16,
torch.bfloat16: cutlass.DataType.bf16,
torch.float: cutlass.DataType.f32, torch.float: cutlass.DataType.f32,
torch.float32: cutlass.DataType.f32, torch.float32: cutlass.DataType.f32,
torch.double: cutlass.DataType.f64, torch.double: cutlass.DataType.f64,
@ -127,6 +128,7 @@ try:
_library_to_torch_dict = { _library_to_torch_dict = {
cutlass.DataType.f16: torch.half, cutlass.DataType.f16: torch.half,
cutlass.DataType.f16: torch.float16, cutlass.DataType.f16: torch.float16,
cutlass.DataType.bf16: torch.bfloat16,
cutlass.DataType.f32: torch.float, cutlass.DataType.f32: torch.float,
cutlass.DataType.f32: torch.float32, cutlass.DataType.f32: torch.float32,
cutlass.DataType.f64: torch.double, cutlass.DataType.f64: torch.double,