flash-attention/csrc/xentropy
interface.cpp Add fused cross entropy loss 2022-11-12 21:58:41 -08:00
README.md Mention that some CUDA extensions have only been tested on A100s 2022-11-15 07:10:25 -08:00
setup.py Add fused cross entropy loss 2022-11-12 21:58:41 -08:00
xentropy_kernel.cu Add fused cross entropy loss 2022-11-12 21:58:41 -08:00

This CUDA extension implements an optimized cross-entropy loss, adapted from Apex's Xentropy. We extend it to support bfloat16 and an in-place backward pass, which saves memory.
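To illustrate what an in-place backward pass means, here is a minimal pure-Python sketch. It is not the extension's API: the function names `xentropy_forward` and `xentropy_backward_inplace` are hypothetical, and saving the per-row log-sum-exp in the forward pass is an assumption about how such a kernel can be organized. The point is only that the gradient, softmax(logits) minus the one-hot label, has the same shape as the logits and can therefore overwrite them instead of occupying a second buffer.

```python
import math

def xentropy_forward(logits, labels):
    """Return per-row losses and per-row log-sum-exp (kept for the backward pass)."""
    losses, lse = [], []
    for row, label in zip(logits, labels):
        m = max(row)  # subtract the row max for numerical stability
        z = m + math.log(sum(math.exp(x - m) for x in row))
        lse.append(z)
        losses.append(z - row[label])  # -log softmax(row)[label]
    return losses, lse

def xentropy_backward_inplace(grad_losses, logits, lse, labels):
    """Overwrite `logits` with d(loss)/d(logits); no gradient buffer is allocated."""
    for row, label, z, g in zip(logits, labels, lse, grad_losses):
        for j in range(len(row)):
            p = math.exp(row[j] - z)  # softmax probability of class j
            row[j] = g * (p - (1.0 if j == label else 0.0))
    return logits  # same object that was passed in

logits = [[2.0, 1.0, 0.1], [0.5, 0.5, 0.5]]
labels = [0, 2]
losses, lse = xentropy_forward(logits, labels)
grads = xentropy_backward_inplace([1.0, 1.0], logits, lse, labels)
```

After the backward call, `logits` no longer holds the original values, so this trade-off is only appropriate when the logits are not needed again after the loss.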

It has only been tested on A100s.

To install, run from the repository root:

    cd csrc/xentropy && pip install .