cutlass/examples
Adnan Akhundov 3c995c7606
Extend DualGemm: support batched mode + decouple B0/B1 layouts (#790)
* Fix MHA kernel

Summary:

ATT

* Extend DualGemm to support batched mode (#5)

Following the GemmUniversalMode::kBatched implementation, batched mode is added to DualGemm (under examples/45_dual_gemm). DualGemmMode::kBatched and SplitKSerial are incompatible: Status::kErrorInvalidProblem is returned if both are set.
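
The compatibility rule above can be sketched as follows. This is a minimal illustration, not the code from examples/45_dual_gemm: the `DualGemmMode` and `Status` enums here are local stand-ins that only mirror the names used in this commit message, and `check_mode` and its `split_k_serial` flag are hypothetical.

```cpp
#include <cassert>

// Local stand-ins that mirror the names in the commit message.
// They are NOT the CUTLASS definitions.
enum class DualGemmMode { kGemm, kBatched };
enum class Status { kSuccess, kErrorInvalidProblem };

// Hypothetical helper: reject the one combination the commit describes as
// unsupported (batched mode together with serial split-K).
inline Status check_mode(DualGemmMode mode, bool split_k_serial) {
  if (mode == DualGemmMode::kBatched && split_k_serial) {
    return Status::kErrorInvalidProblem;
  }
  return Status::kSuccess;
}
```

Under these assumptions, `check_mode(DualGemmMode::kBatched, true)` returns `Status::kErrorInvalidProblem`, while every other combination returns `Status::kSuccess`.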

* Decouple LayoutB0 and LayoutB1 in DualGemm

The DualGemm template assumed a single layout, LayoutB, for both right-operand matrices, B0 and B1. This is a problem when the two layouts differ. In particular, this can happen when one of the matrices is row-major while the other is a (column) vector that has to be broadcast in column-major with zero stride (e.g., as {B1.device_data(), 0}) so that the DualGemm implementation can process B0 and B1 simultaneously.

In this commit, LayoutB0 and LayoutB1 are decoupled throughout the DualGemm code (device, kernel, and mma). The batch strides of B0 and B1 are decoupled as well, to accommodate the column-vector B1 case described above.
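
To illustrate why a zero stride broadcasts a column vector, here is a minimal host-side sketch. `ColumnMajorRef` and `batch_ref` are hypothetical stand-ins, not the CUTLASS TensorRef type or any DualGemm API. The sketch also assumes that each batch carries its own length-K vector, so the B1 batch stride is K rather than K * N; that is one reason the batch strides of B0 and B1 may need to differ.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical column-major reference: element (row, col) is read from
// data[row + col * stride]. With stride == 0 every column aliases column 0,
// so a length-K vector behaves like a K x N matrix with identical columns.
struct ColumnMajorRef {
  const float* data;
  std::int64_t stride;  // leading dimension; 0 means "broadcast the vector"

  float at(std::int64_t row, std::int64_t col) const {
    return data[row + col * stride];
  }
};

// Hypothetical helper: reference to the operand of a given batch.
// batch_stride is the element distance between consecutive batches.
inline ColumnMajorRef batch_ref(const float* base, std::int64_t stride,
                                std::int64_t batch_stride, std::int64_t batch) {
  return ColumnMajorRef{base + batch * batch_stride, stride};
}
```

For example, with two batches of K = 3 stored back to back as {1, 2, 3, 10, 20, 30}, `batch_ref(ptr, 0, 3, 1).at(2, n)` reads 30 for any column index n, because the column index is multiplied by the zero stride.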

* Remove comment as no longer relevant

* Revert "Fix MHA kernel"

---------

Co-authored-by: mikeiovine <mikeiovine@fb.com>
2023-02-13 15:27:13 -05:00
..
00_basic_gemm New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
01_cutlass_utilities New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
02_dump_reg_shmem New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
03_visualize_layout New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
04_tile_iterator New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
05_batched_gemm New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
06_splitK_gemm New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
07_volta_tensorop_gemm New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
08_turing_tensorop_gemm New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
09_turing_tensorop_conv2dfprop New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
10_planar_complex CUTLASS 3.0.0 (#786) 2023-01-23 20:55:28 -05:00
11_planar_complex_array CUTLASS 3.0.0 (#786) 2023-01-23 20:55:28 -05:00
12_gemm_bias_relu New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
13_two_tensor_op_fusion New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
14_ampere_tf32_tensorop_gemm New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
15_ampere_sparse_tensorop_gemm New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
16_ampere_tensorop_conv2dfprop New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
17_fprop_per_channel_bias New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
18_ampere_fp64_tensorop_affine2_gemm New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
19_tensorop_canonical New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
20_simt_canonical New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
21_quaternion_gemm New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
22_quaternion_conv New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
23_ampere_gemm_operand_reduction_fusion New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
24_gemm_grouped New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
25_ampere_fprop_mainloop_fusion New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
26_ampere_wgrad_mainloop_fusion New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
27_ampere_3xtf32_fast_accurate_tensorop_gemm New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
28_ampere_3xtf32_fast_accurate_tensorop_fprop New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
30_wgrad_split_k New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
31_basic_syrk New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
32_basic_trmm New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
33_ampere_3xtf32_tensorop_symm New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
34_transposed_conv2d New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
35_gemm_softmax New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
36_gather_scatter_fusion New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
37_gemm_layernorm_gemm_fusion CUTLASS 3.0.0 (#786) 2023-01-23 20:55:28 -05:00
38_syr2k_grouped New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
39_gemm_permute New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
40_cutlass_py New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
41_fused_multi_head_attention xFormer updates to fMHA FW (#773) 2023-02-08 23:00:10 -05:00
42_ampere_tensorop_group_conv New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
43_ell_block_sparse_gemm New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
44_multi_gemm_ir_and_codegen New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
45_dual_gemm Extend DualGemm: support batched mode + decouple B0/B1 layouts (#790) 2023-02-13 15:27:13 -05:00
46_depthwise_simt_conv2dfprop New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
47_ampere_gemm_universal_streamk CUTLASS 3.0.0 (#786) 2023-01-23 20:55:28 -05:00
48_hopper_warp_specialized_gemm CUTLASS 3.0.0 (#786) 2023-01-23 20:55:28 -05:00
49_hopper_gemm_schedules_with_collective_builder CUTLASS 3.0.0 (#786) 2023-01-23 20:55:28 -05:00
50_hopper_gemm_with_epilogue_swizzle CUTLASS 3.0.0 (#786) 2023-01-23 20:55:28 -05:00
60_cutlass_import New updates for 2.11 (#775) 2023-01-20 16:32:57 -05:00
common streamk example and performance tuning (#760) 2023-01-10 16:10:02 -05:00
cute CUTLASS 3.0.0 (#786) 2023-01-23 20:55:28 -05:00
CMakeLists.txt CUTLASS 3.0.0 (#786) 2023-01-23 20:55:28 -05:00