* Update fMHA kernels
Upstream recent changes to fMHA that we did in xFormers.
Previous version in CUTLASS: facebookresearch/xformers@b6be33a
Updating to: facebookresearch/xformers@55a4798
* minor changes
* make var work
---------
Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
* [WIP] GEMM StreamK w/ Fused Epilogue
* Adds Gemm Streamk with Fused Epilogue kernel level struct.
* Mostly based on Gemm with Fused Epilogue,
* Requires a new epilogue
* Work in progress
* [WIP] StreamK support for GemmUniversalWithBroadcast
* Just based off of how StreamK is allowed in GemmUniversal
* Untested and a work in progress
* Minor fixes
* [WIP] It compiles!
It is almost certainly incorrect, but we're past getting the templates
to match, so checkpointing.
* Correction to reference kernel
* Fix typo
* Added MSE measurement
* Switch back to reference kernel + host for loop
Still WIP. Now we're getting even a larger MSE, but it's both on
basic Split-K and Stream-K.
* Fix typos
* Fix broadcast vector + requested changes
* Comment typo
* Small int option and more
* Fix incorrect condition on source needed
* Requested changes
* I think I got it?
* Bias vector should be stride 0
* Two source added!
* Typos
* Merge examples
* Bring back vector row offset
Just to ensure consistency with universal gemm with fused epilogue
* Base arguments and params structs for StreamK
* StreamK epilogue with broadcast now inherits the original
* undo params_streamk_base.h
---------
Co-authored-by: Ali Hassani <ahassanijr@gmail.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
* added support of b2b bmm
* fixed arguments and params structures
* added batch_count argument
* removed SplitKSerial and added new test case with b2b bmm
* fixed support of Kbatched and added new test case with batch stride
* added batch support for bias and scale
* make test
* small changes
---------
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
* Fix MHA kernel
Summary:
ATT
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
* Extend DualGemm to support batched mode (#5)
Following the GemmUniversalMode::kBatched implementation, batched mode is added to the DualGemm (under examples/45_dual_gemm). DualGemmMode::kBatched and SplitKSerial are not compatible: Status::kErrorInvalidProblem is returned if both are set.
* Decouple LayoutB0 and LayoutB1 in DualGemm
The DualGemm template assumed the same layout, LayoutB, for both right operand matrices B0 and B1. This is problematic if the layout of the two matrices is different. In particular, this may be the case when one of the matrices is row-major, while the other is a (column) vector that has to be broadcasted in column-major with zero stride (e.g., as {B1.device_data(), 0}) for the DualGemm implementation to be able to process B0 and B1 simultaneously.
In this commit, LayoutB0 and LayoutB1 are decoupled throughout the DualGemm code (device, kernel, and mma). Additionally, the batch strides of B0 and B1 are also decoupled to accommodate the column vector B1 case described above.
* Remove comment as no longer relevant
* Revert Fix MHA kernel
---------
Co-authored-by: mikeiovine <mikeiovine@fb.com>
* xFormer updates to fMHA FW
* Convert format to BMHK for '41_fused_multi_head_attention_fixed_seqlen'
* Add missing files
* Remove xFormers specific code
* Update fused_multihead_attention_fixed_seqlen.cu
* rebase and solve conflicts
* remove white space
---------
Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
* add two missing files
* fix bunch of bugs of gemm-reducek fusion and add a device interface
* small changes
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
* ex42: Fused MHA imported from xFormers
* Remove std:: references
* Support K>128 in the example
* Support causal option
* Support different head size for V, and different seqlength for KV
* Update FLOPS counter
* Remove bit_cast
* fix build: Replace M_LOG2E
* Add doc
* Revert "Remove bit_cast"
This reverts commit 9662fa86bb7c57c1a015ac0bf52cb52940fbbf80.
* Explicit casts to int32_t for windows build
Co-authored-by: danthe3rd <danthe3rd>
* Remove redundant <fstream> includes
* Fix fstream in examples/
* Fix <fstream> in test/
* Use consistent order for <fstream> (always after <iostream>)
* Remove an unneeded include in a file where std::ofstream usage is commented out
Co-authored-by: Ivan Komarov <dfyz@yandex-team.ru>