This CUDA extension implements an optimized cross-entropy loss, adapted from Apex's Xentropy. We extend it to support bfloat16 and add an in-place backward pass to save memory.
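
As a rough reference for the math such a kernel computes, here is a pure-Python sketch. It is not the extension's code or API. It assumes a plain per-row cross-entropy with no label smoothing and no ignored indices, and the names `xentropy_forward` and `xentropy_backward_inplace` are hypothetical. The forward pass keeps each row's log-sum-exp. The backward pass recomputes softmax from that value and overwrites the logits buffer with the gradient `grad * (softmax - onehot)`, which illustrates how an in-place backward can avoid allocating a separate gradient tensor.

```python
import math
from typing import List, Tuple

# Hypothetical reference sketch (not the extension's API): plain cross-entropy,
# no label smoothing, no ignored indices.


def xentropy_forward(
    logits: List[List[float]], labels: List[int]
) -> Tuple[List[float], List[float]]:
    """Return per-row losses and per-row log-sum-exp (kept for the backward pass)."""
    losses, lse = [], []
    for row, y in zip(logits, labels):
        m = max(row)  # subtract the row max so exp() cannot overflow
        s = m + math.log(sum(math.exp(x - m) for x in row))
        lse.append(s)
        losses.append(s - row[y])  # equals -log(softmax(row)[y])
    return losses, lse


def xentropy_backward_inplace(
    grad_losses: List[float],
    logits: List[List[float]],
    lse: List[float],
    labels: List[int],
) -> List[List[float]]:
    """Overwrite `logits` with dLoss/dlogits and return the same list object."""
    for i, (row, y) in enumerate(zip(logits, labels)):
        g = grad_losses[i]
        for j in range(len(row)):
            p = math.exp(row[j] - lse[i])  # softmax entry, recomputed from saved lse
            row[j] = g * (p - (1.0 if j == y else 0.0))  # reuse the logits storage
    return logits
```

Overwriting is safe here because each gradient entry depends only on its own logit, the row's saved log-sum-exp, and the label, so no entry is read after it has been replaced. For example, calling `xentropy_forward([[0.0, 0.0, 0.0]], [0])` gives a loss of `log(3)`, and the backward pass with an upstream gradient of `1.0` rewrites that row to `[-2/3, 1/3, 1/3]`.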

```sh
cd csrc/xentropy && pip install .
```