torch.bfloat16 support in cutlass python (#1037)
* torch.bfloat16 support in cutlass python * Update datatypes.py
This commit is contained in:
parent
4575443d44
commit
2d9a557427
@ -118,6 +118,7 @@ try:
|
||||
_torch_to_library_dict = {
|
||||
torch.half: cutlass.DataType.f16,
|
||||
torch.float16: cutlass.DataType.f16,
|
||||
torch.bfloat16: cutlass.DataType.bf16,
|
||||
torch.float: cutlass.DataType.f32,
|
||||
torch.float32: cutlass.DataType.f32,
|
||||
torch.double: cutlass.DataType.f64,
|
||||
@ -127,6 +128,7 @@ try:
|
||||
_library_to_torch_dict = {
|
||||
cutlass.DataType.f16: torch.half,
|
||||
cutlass.DataType.f16: torch.float16,
|
||||
cutlass.DataType.bf16: torch.bfloat16,
|
||||
cutlass.DataType.f32: torch.float,
|
||||
cutlass.DataType.f32: torch.float32,
|
||||
cutlass.DataType.f64: torch.double,
|
||||
|
Loading…
Reference in New Issue
Block a user