flash-attention/csrc/xentropy

This CUDA extension implements an optimized cross-entropy loss, adapted from Apex's Xentropy. Compared with the Apex version, it also works with bfloat16 and supports an in-place backward pass to save memory.
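
To illustrate why an in-place backward pass saves memory, here is a minimal pure-Python sketch for a single row of logits. It is not the CUDA kernel; the function names `cross_entropy_forward` and `cross_entropy_backward_inplace` are hypothetical, and the sketch assumes that "in-place" means the gradient is written over the logits buffer rather than into a new one.

```python
import math

def cross_entropy_forward(logits, label):
    """Return (loss, logsumexp) for one row of logits and an integer label."""
    m = max(logits)  # subtract the max for numerical stability
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return lse - logits[label], lse

def cross_entropy_backward_inplace(logits, label, lse, grad_loss=1.0):
    """Overwrite `logits` with d(loss)/d(logits); no second buffer is allocated.

    Assumption for this sketch: the logits are not needed after backward.
    """
    for j in range(len(logits)):
        p = math.exp(logits[j] - lse)  # softmax probability of class j
        logits[j] = grad_loss * (p - (1.0 if j == label else 0.0))
    return logits

row = [2.0, 1.0, 0.1]
loss, lse = cross_entropy_forward(row, label=0)
grad = cross_entropy_backward_inplace(row, 0, lse)  # `row` now holds the gradient
```

Because the gradient has the same shape as the logits, reusing the buffer avoids one extra allocation of size batch × vocabulary, which can matter when the vocabulary is large. The trade-off is that the original logits are no longer available once the backward pass has run.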

cd csrc/xentropy && pip install .