* base version
* restructure pipelines, add special fp8 epilogue
* add variants
* add fp8 causal and modify dynamic tile scheduler
* better causal schedule
* maintain two schedules for non causal and causal
* removing macros
* fix regression
* clean up unneeded methods and variants
* fix mistake with NumProducerThreads
* use seqlen traits
* add fp8 .cu files and benchmark script
* fix merge issue
* fix merge issue
* fix merge issue
* remove duplicate code
* fix regression with varseqlen
* move varseqlen init in constexpr
* fix test script
* more constexpr on varseqlen and add max offset
* add back test cases
* adding files for fp8 changes.
* removed contiguous check.
* enable all tests except odd-seq-lengths, where it crashes now.
* undid clang formatting.
* change to correct tile size for headdim=128.
* fixed odd-seq-len-k.
* minor formatting.
* minor reformatting.
---------
Co-authored-by: Tri Dao <tridao@users.noreply.github.com>
* Support ck in fmha
* Add ck submodule
* Do not return lse if return_softmax == false
* Use receipt to speed up ck compile time
* Integrate new version of ck_tile
* Support dropout for mha_fwd()
* Add dropout to mha_varlen_fwd()
* Update ck to develop
* Extract padding function for dropout randval
* Extract randval transformation function
* Sync the code structure and coding style with FA
* Remove this line, c++ api will handle this.
Sync with test_flash_attn.py
* fix compile error
* Add mha_bwd
* Generate dropout seed and offset from user generator
* update CK
* Add mha_varlen_bwd
* Use the same Python that builds flash-attn to generate the ck kernels
* Fix bug in group-mode fwd when returning softmax lse
* Increase the test tolerance
* Add test_flash_attn_output() and test_flash_attn_varlen_output()
* Always fill softmax_lse
* Remove duplicate benchmark script, since we already implement mha_bwd
* Refine get value from tuple
* Use default parameter for stream_config
* Unblock all platforms
* Add comment
* refine the test code
* Refine naming
* Add unpack to namespace
* Do not hardcode the warp size 64
* Add more targets
* Add README
* Optimize mha_fwd if seqlen_q == 1
* Support get_wheel_url for rocm
* Detect rocm environment by pytorch's IS_HIP_EXTENSION
* Update to latest ck
* Add necessary compile flag
* Sync the api with upstream FA
---------
Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: Yichen Yan <wenji.yyc@alibaba-inc.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: Yichen Yan <oraluben@outlook.com>
* check in the two ways of approaching backwards for softcapping, both functional
* prepare the softcap switch for backwards
* temporary
* cleanup to the way Tri prefers
* calculate dtanh when copying from scores -> dtanh Tensor
* no ternary operators allowed for constexpr, so just use some hack found online
* fix maybe_dtanh, restore some files
* restore another file
* move calculate_dtanh to utils and colocate with apply_softcap
* cleanup
* maybe last cleanup
* save for another pr
* remove a stray line
* fix spacing
* fix an issue, and make test_flash_attn.py ready to test softcapping backwards
Update state_dict.pop() calls to handle non-existing keys.
Changed state_dict.pop(f"h.{d}.attn.bias") to state_dict.pop(f"h.{d}.attn.bias", None) so that a missing key no longer raises KeyError.
The following code reproduces the issue:
```python
from transformers import GPT2Model, GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel, GPTModel
# >>> transformers.__version__
# '4.38.2'
model_path = 'gpt2'
output_model_path = 'gpt2_model'
config = GPT2Config.from_pretrained(model_path, output_hidden_states=True)
model = GPT2Model.from_pretrained(model_path, from_tf=False, config=config)
# ... fine-tune the model here ...
# dump the fine-tuned model
model.save_pretrained(output_model_path)
# load the fine-tuned model
config = GPT2Config.from_pretrained(output_model_path, output_hidden_states=True)
model = GPTModel.from_pretrained(output_model_path, config=config, strict=True) # failed due to KeyError: 'h.0.attn.bias'
model = GPTLMHeadModel.from_pretrained(output_model_path, config=config, strict=True) # failed due to KeyError: 'h.0.attn.bias'
```
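For reference, a minimal sketch of the defensive pop described above; the helper name and surrounding loop are illustrative, only the pop-with-default comes from the change itself (the actual remapping lives in `flash_attn.models.gpt`):
```python
def strip_attn_bias(state_dict, n_layer):
    # Illustrative helper, not the real flash_attn function.
    for d in range(n_layer):
        # With a default of None, pop() is a no-op when the key is absent,
        # so checkpoints saved without "h.{d}.attn.bias" no longer raise KeyError.
        state_dict.pop(f"h.{d}.attn.bias", None)
    return state_dict
```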
When the user passes `out` as a parameter and the inputs trigger
`seqlenq_ngroups_swapped`, the CHECK_SHAPE on `out` is incorrect
(since the q shape has been modified).
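A rough Python pseudocode sketch of the problem; the real check is in the C++ `mha_fwd`, and the variable names and the exact swap condition here are assumptions, not the flash-attn source:
```python
def fwd_sketch(q, k, out=None):
    # Pseudocode mirroring the swapped path; names and the trigger
    # condition are assumptions for illustration only.
    batch, seqlen_q, num_heads, head_dim = q.shape
    num_heads_k = k.shape[2]
    if seqlen_q == 1 and num_heads > num_heads_k:  # seqlenq_ngroups_swapped
        ngroups = num_heads // num_heads_k
        # q is reshaped so the query groups take the place of seqlen_q.
        q = q.reshape(batch, num_heads_k, ngroups, head_dim).transpose(1, 2)
        seqlen_q, num_heads = ngroups, num_heads_k
    if out is not None:
        # Bug: comparing the user-supplied `out` against the swapped sizes
        # rejects a tensor shaped for the caller's original q layout.
        assert out.shape == (batch, seqlen_q, num_heads, head_dim)
```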