This CUDA extension implements an optimized cross-entropy loss, adapted from Apex's Xentropy. We extend it to support bfloat16 and add an in-place backward pass to save memory.

It has only been tested on A100s.

```sh
cd csrc/xentropy && pip install .
```
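
Once built, the extension can be smoke-tested against PyTorch's own cross-entropy. The module name `xentropy_cuda_lib` and the `forward(logits, labels, smoothing)` binding below are assumptions based on this directory's `setup.py` and `interface.cpp`; check those files for the actual names in your build.

```python
import torch
import torch.nn.functional as F

import xentropy_cuda_lib  # name assumed; see setup.py for the actual module name

logits = torch.randn(4, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
labels = torch.randint(0, 128, (4,), device="cuda")

# Assumed binding: returns per-example losses and the log-sum-exp that the
# backward pass needs (smoothing set to 0.0 here).
losses, lse = xentropy_cuda_lib.forward(logits, labels, 0.0)

# Reference: PyTorch's cross-entropy computed in float32.
ref = F.cross_entropy(logits.float(), labels, reduction="none")
print(torch.allclose(losses.float(), ref, atol=1e-2))
```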

As of 2023-09-15, this extension is no longer used in the FlashAttention repo; we have switched to a Triton-based implementation instead. See the `CrossEntropyLoss` module in `flash_attn.losses.cross_entropy` for details.
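
For reference, a minimal sketch of using the Triton-based replacement. It assumes the module lives at `flash_attn.losses.cross_entropy` and mirrors `torch.nn.CrossEntropyLoss`; the `inplace_backward` keyword is an assumption carried over from this extension's memory-saving feature.

```python
import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss

# inplace_backward is assumed to reuse the logits' gradient buffer, saving memory.
loss_fn = CrossEntropyLoss(inplace_backward=True)

logits = torch.randn(4, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
labels = torch.randint(0, 128, (4,), device="cuda")

loss = loss_fn(logits, labels)  # scalar, mean-reduced like nn.CrossEntropyLoss
loss.backward()
```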