* Fix MHA kernel
Summary:
ATT
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
* Extend DualGemm to support batched mode (#5)
Following the GemmUniversalMode::kBatched implementation, batched mode is added to the DualGemm (under examples/45_dual_gemm). DualGemmMode::kBatched and SplitKSerial are not compatible: Status::kErrorInvalidProblem is returned if both are set.
* Decouple LayoutB0 and LayoutB1 in DualGemm
The DualGemm template assumed the same layout, LayoutB, for both right operand matrices B0 and B1. This is problematic if the layout of the two matrices is different. In particular, this may be the case when one of the matrices is row-major, while the other is a (column) vector that has to be broadcasted in column-major with zero stride (e.g., as {B1.device_data(), 0}) for the DualGemm implementation to be able to process B0 and B1 simultaneously.
In this commit, LayoutB0 and LayoutB1 are decoupled throughout the DualGemm code (device, kernel, and mma). Additionally, the batch strides of B0 and B1 are also decoupled to accommodate the column vector B1 case described above.
* Remove comment as no longer relevant
* Revert Fix MHA kernel
---------
Co-authored-by: mikeiovine <mikeiovine@fb.com>