torch.bfloat16 support in cutlass python (#1037)
* torch.bfloat16 support in cutlass python * Update datatypes.py
This commit is contained in:
parent
4575443d44
commit
2d9a557427
@ -118,6 +118,7 @@ try:
|
|||||||
_torch_to_library_dict = {
|
_torch_to_library_dict = {
|
||||||
torch.half: cutlass.DataType.f16,
|
torch.half: cutlass.DataType.f16,
|
||||||
torch.float16: cutlass.DataType.f16,
|
torch.float16: cutlass.DataType.f16,
|
||||||
|
torch.bfloat16: cutlass.DataType.bf16,
|
||||||
torch.float: cutlass.DataType.f32,
|
torch.float: cutlass.DataType.f32,
|
||||||
torch.float32: cutlass.DataType.f32,
|
torch.float32: cutlass.DataType.f32,
|
||||||
torch.double: cutlass.DataType.f64,
|
torch.double: cutlass.DataType.f64,
|
||||||
@ -127,6 +128,7 @@ try:
|
|||||||
_library_to_torch_dict = {
|
_library_to_torch_dict = {
|
||||||
cutlass.DataType.f16: torch.half,
|
cutlass.DataType.f16: torch.half,
|
||||||
cutlass.DataType.f16: torch.float16,
|
cutlass.DataType.f16: torch.float16,
|
||||||
|
cutlass.DataType.bf16: torch.bfloat16,
|
||||||
cutlass.DataType.f32: torch.float,
|
cutlass.DataType.f32: torch.float,
|
||||||
cutlass.DataType.f32: torch.float32,
|
cutlass.DataType.f32: torch.float32,
|
||||||
cutlass.DataType.f64: torch.double,
|
cutlass.DataType.f64: torch.double,
|
||||||
|
Loading…
Reference in New Issue
Block a user