flash-attention/csrc/flash_attn
Tri Dao 5b838a8bef Apply dropout scaling to dQ and dK instead of to V (in bwd)
Theoretically this might have lower numerical error, since the scaling is then
done in fp32 instead of fp16 (not sure; I haven't thought too carefully about it).
In practice, however, the numerical errors seem about the same.
2022-07-03 17:53:37 -07:00
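To illustrate the point in the commit message, here is a minimal sketch (not the actual FMHA kernel code; the names acc_dq, acc_dk, v_half, and rp_dropout are hypothetical) of the two places the dropout rescale factor 1/(1 - p_dropout) can be applied in the backward pass. Folding it into V performs the multiply in fp16, whereas applying it to the dQ/dK accumulators performs it in fp32.

```cuda
#include <cuda_fp16.h>

// Illustrative sketch only: where to apply the dropout rescale 1/(1 - p).
__device__ void apply_dropout_rescale(float* acc_dq, float* acc_dk,
                                      __half* v_half, int n,
                                      float rp_dropout /* = 1 / (1 - p_dropout) */) {
    // Previous approach: rescale V in fp16 before it enters the backward matmuls.
    // for (int i = 0; i < n; ++i)
    //     v_half[i] = __float2half(__half2float(v_half[i]) * rp_dropout);

    // Approach in this commit: leave V untouched and rescale the fp32
    // accumulators for dQ and dK instead, so the multiply happens in fp32.
    for (int i = 0; i < n; ++i) {
        acc_dq[i] *= rp_dropout;
        acc_dk[i] *= rp_dropout;
    }
}
```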
cutlass @ 319a389f42    Add Cutlass as submodule                                      2022-06-02 09:54:16 -07:00
src                     Apply dropout scaling to dQ and dK instead of to V (in bwd)   2022-07-03 17:53:37 -07:00
fmha_api.cpp            Apply dropout scaling to dQ and dK instead of to V (in bwd)   2022-07-03 17:53:37 -07:00