When using `checkpoint_lvl=2`, we call `all_gather_raw(x)` without `async_op=True`, so there is no handle to wait on. Just skip the wait.
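The reasoning in this commit message can be illustrated with a small pure-Python sketch. It assumes the `torch.distributed` convention that a collective returns a work handle only when `async_op=True` and `None` otherwise. `FakeWork`, this `all_gather_raw`, and `gather_for_backward` are hypothetical stand-ins, not the flash_attn implementations. The "gather" simply repeats the local list `world_size` times, and the asynchronous branch exists only for contrast.

```python
from typing import List, Optional, Tuple


class FakeWork:
    """Hypothetical stand-in for a torch.distributed async work handle."""

    def __init__(self) -> None:
        self.waited = False

    def wait(self) -> bool:
        self.waited = True
        return True


def all_gather_raw(
    x: List[float], world_size: int, async_op: bool = False
) -> Tuple[List[float], Optional[FakeWork]]:
    # Simulated all-gather: every rank is assumed to contribute a copy of x.
    output = x * world_size
    # Mirrors the torch.distributed convention: a handle only when async_op=True.
    handle = FakeWork() if async_op else None
    return output, handle


def gather_for_backward(
    x: List[float], world_size: int, checkpoint_lvl: int
) -> Tuple[List[float], Optional[FakeWork]]:
    if checkpoint_lvl == 2:
        # Synchronous gather: the result is ready on return, so there is no handle.
        total_x, handle = all_gather_raw(x, world_size, async_op=False)
    else:
        # Asynchronous gather (for contrast): other work could overlap here.
        total_x, handle = all_gather_raw(x, world_size, async_op=True)
    # Only wait when a handle actually exists; calling .wait() on None would fail.
    if handle is not None:
        handle.wait()
    return total_x, handle
```

Guarding on `handle is not None` keeps a single code path for both cases instead of assuming a handle always exists.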
| Name |
|---|
| layers |
| losses |
| models |
| modules |
| ops |
| utils |
| __init__.py |
| bert_padding.py |
| flash_attn_interface.py |
| flash_attn_triton_og.py |
| flash_attn_triton.py |
| flash_blocksparse_attention.py |
| flash_blocksparse_attn_interface.py |
| fused_softmax.py |
| pyproject.toml |