flash-attention/csrc/xentropy

This CUDA extension implements an optimized cross-entropy loss, adapted from Apex's Xentropy. We extended it to work with bfloat16 and to support an in-place backward pass, which saves memory.
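
To illustrate why an in-place backward pass saves memory, here is a minimal pure-Python sketch for a single row of logits. It is not the CUDA kernel's code: the function names are hypothetical, and it assumes "in-place" means the gradient overwrites the logits buffer rather than being written to a newly allocated one.

```python
import math


def cross_entropy_forward(logits, target):
    """Return (loss, logsumexp) for one row of logits and an integer class index."""
    m = max(logits)  # subtract the max before exponentiating, for numerical stability
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return lse - logits[target], lse


def cross_entropy_backward_inplace(logits, target, lse, grad_loss=1.0):
    """Overwrite `logits` with d(loss)/d(logits) = softmax(logits) - one_hot(target).

    Assumption: reusing the logits buffer for the gradient is what makes the
    backward pass "in-place"; no second buffer of the same size is allocated.
    """
    for i, x in enumerate(logits):
        p = math.exp(x - lse)  # softmax probability, recovered from the saved logsumexp
        logits[i] = grad_loss * (p - (1.0 if i == target else 0.0))
    return logits


logits = [2.0, 1.0, 0.1]
loss, lse = cross_entropy_forward(logits, target=0)
grad = cross_entropy_backward_inplace(logits, 0, lse)
```

After the backward call, `grad` and `logits` refer to the same list, so the original logit values are gone. That trade-off is only acceptable when nothing else needs the logits after the loss has been computed.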

It has only been tested on A100s.

```sh
cd csrc/xentropy && pip install .
```