[CE] Implement CrossEntropyLoss in Triton

This commit is contained in:
Tri Dao 2023-09-15 19:27:18 -07:00
parent 56b7fc6ee0
commit 5400fdc4ac
5 changed files with 370 additions and 135 deletions

View File

@ -7,3 +7,8 @@ It has only been tested on A100s.
```sh
cd csrc/xentropy && pip install .
```
As of 2023-09-15, this extension is no longer used in the FlashAttention repo.
We've instead switched to a Triton-based
[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/cross_entropy.py).
See the CrossEntropyLoss [module](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/losses/cross_entropy.py) for more details.

View File

@ -1,116 +1,9 @@
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
# the losses we can get the global loss. There's no need to do it step by step
# (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
# Copyright (c) 2023, Tri Dao.
import torch
import torch.nn as nn
import xentropy_cuda_lib
# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 2 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
@staticmethod
def forward(
ctx,
logits,
labels,
smoothing=0.0,
ignored_index=-100,
inplace_backward=False,
process_group=None,
):
"""
logits: (batch, vocab_size)
labels: (batch,)
If process_group is not None, we're doing Tensor Parallel: each process is responsible for
one part of the vocab. The loss needs to be aggregated across processes.
"""
batch, vocab_size = logits.shape
assert labels.shape == (batch,)
world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
ctx.total_classes = world_size * vocab_size
if world_size == 1:
losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing)
losses.masked_fill_(labels == ignored_index, 0)
labels_local = labels
else:
rank = torch.distributed.get_rank(process_group)
vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
# Create a mask of valid vocab ids (1 means it needs to be masked).
labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
ignored_mask = labels == ignored_index
labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)
# For tensor parallel cross entropy with smoothing, we want to pass in the total number
# of classes so that smoothing can be applied correctly. If total_classes=-1, use the
# last dimension of the input tensor.
losses, lse_local = xentropy_cuda_lib.forward(
logits, labels_local, smoothing, world_size * vocab_size
)
assert lse_local.shape == (batch,)
assert losses.shape == (batch,)
losses.masked_fill_(ignored_mask, 0)
# For labels == ignored_index, the loss is always 0.
# If there's no smoothing, if labels are in the vocab of this partition, losses contains
# lse_local - predicted logit, and 0 otherwise.
# If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
# 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes)
# For labels not in the vocab of this partition, losses contains
# 0.1 * (lse_local - sum logit / total_classes).
lse_allgather = torch.empty(
world_size, batch, dtype=lse_local.dtype, device=lse_local.device
)
torch.distributed.all_gather_into_tensor(
lse_allgather, lse_local.contiguous(), group=process_group
)
handle_losses = torch.distributed.all_reduce(
losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
)
lse = torch.logsumexp(lse_allgather, dim=0)
# If there's no smoothing, the total losses are lse_local - predicted_logit,
# we just have to subtract the lse_local and add the lse (global).
# If there's smoothing=0.1, the total losses are
# 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes)
# We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes).
rank_per_sample = torch.div(labels, vocab_size, rounding_mode="floor")
lse_local = lse_allgather[
rank_per_sample, torch.arange(batch, device=lse_allgather.device)
]
handle_losses.wait()
if smoothing == 0.0:
losses += lse - lse_local
else:
losses += (1 - smoothing) * (lse - lse_local) + smoothing * (
lse - lse_allgather.sum(dim=0)
)
losses.masked_fill_(ignored_mask, 0)
ctx.save_for_backward(logits, lse, labels_local)
ctx.smoothing = smoothing
ctx.ignored_index = ignored_index
ctx.inplace_backward = inplace_backward
return losses
@staticmethod
def backward(ctx, grad_loss):
logits, lse, labels = ctx.saved_tensors
grad_loss = grad_loss.contiguous()
grad_loss.masked_fill_(labels == ctx.ignored_index, 0)
grad_logits = xentropy_cuda_lib.backward(
grad_loss, logits, lse, labels, ctx.smoothing, ctx.inplace_backward, ctx.total_classes
)
return grad_logits, None, None, None, None, None, None
from flash_attn.ops.triton.cross_entropy import cross_entropy_loss
class CrossEntropyLoss(nn.Module):
@ -119,30 +12,52 @@ class CrossEntropyLoss(nn.Module):
ignore_index=-100,
reduction="mean",
label_smoothing=0.0,
lse_square_scale=0.0,
inplace_backward=False,
process_group=None,
):
"""
Arguments:
ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
label_smoothing: float
lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
This is also referred to as "z-loss".
inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
This saves memory.
process_group: if not None, we're doing Tensor Parallel: each process is responsible for
one part of the vocab. The loss will be aggregated across processes.
"""
super().__init__()
if reduction not in ["mean", "none"]:
raise NotImplementedError("Only support reduction = 'mean' or 'none'")
if reduction not in ["mean", "none", "sum"]:
raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'")
self.ignore_index = ignore_index
self.reduction = reduction
self.label_smoothing = label_smoothing
self.lse_square_scale = lse_square_scale
self.inplace_backward = inplace_backward
self.process_group = process_group
def forward(self, input, target):
assert input.is_cuda and target.is_cuda
# SoftmaxCrossEntropyLoss implicitly casts to float
loss = SoftmaxCrossEntropyLossFn.apply(
"""
Arguments:
input: (batch, vocab_size)
target: (batch,)
Returns:
losses: (batch,) if reduction is 'none', else (1,), dtype float
"""
assert input.is_cuda and target.is_cuda, "Only support CUDA tensors"
loss = cross_entropy_loss(
input,
target,
self.label_smoothing,
self.ignore_index,
self.inplace_backward,
self.process_group,
label_smoothing=self.label_smoothing,
lse_square_scale=self.lse_square_scale,
ignored_index=self.ignore_index,
inplace_backward=self.inplace_backward,
process_group=self.process_group,
)
if self.reduction == "mean":
return loss.sum() / (target != self.ignore_index).sum()
elif self.reduction == "sum":
return loss.sum()
else:
return loss

View File

@ -0,0 +1,293 @@
# Copyright (c) 2023, Tri Dao.
from typing import Tuple, Optional, Union
import torch
from einops import rearrange
import triton
import triton.language as tl
# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 2 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
@triton.heuristics(
{
"HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
}
)
@triton.jit
def cross_entropy_fwd_kernel(
loss_ptr, # data ptrs
lse_ptr,
logits_ptr,
labels_ptr,
smoothing,
lse_square_scale,
ignored_index,
total_classes,
class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
n_cols, # shapes
n_rows,
logits_row_stride, # strides
BLOCK_SIZE: tl.constexpr,
HAS_SMOOTHING: tl.constexpr,
# if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
SPLIT: tl.constexpr,
):
row_idx = tl.program_id(0)
col_block_idx = tl.program_id(1)
logits_ptr = logits_ptr + row_idx * logits_row_stride
col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
label_idx = tl.load(labels_ptr + row_idx)
logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
tl.float32
)
max_logits = tl.max(logits, 0)
if HAS_SMOOTHING:
sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
if label_idx == ignored_index:
loss = 0.0
else:
label_idx -= class_start_idx
if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
n_cols, (col_block_idx + 1) * BLOCK_SIZE
):
logits_label = tl.load(logits_ptr + label_idx)
if HAS_SMOOTHING:
loss = (
(lse if not SPLIT else 0.0)
- smoothing * sum_logits / total_classes
- (1 - smoothing) * logits_label
)
else:
loss = (lse if not SPLIT else 0.0) - logits_label
else:
# If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss
if HAS_SMOOTHING:
loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
else:
loss = 0.0
if not SPLIT:
loss += lse_square_scale * lse * lse
tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)
@triton.heuristics(
{
"HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
}
)
@triton.jit
def cross_entropy_bwd_kernel(
dlogits_ptr, # data ptrs
dloss_ptr,
logits_ptr,
lse_ptr,
labels_ptr,
smoothing,
lse_square_scale,
ignored_index,
total_classes,
class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
n_cols, # shapes
logits_row_stride, # strides
dlogits_row_stride,
dloss_row_stride,
BLOCK_SIZE: tl.constexpr,
HAS_SMOOTHING: tl.constexpr,
):
row_idx = tl.program_id(0)
col_block_idx = tl.program_id(1)
logits_ptr = logits_ptr + row_idx * logits_row_stride
dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride
col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
label_idx = tl.load(labels_ptr + row_idx)
if label_idx != ignored_index:
dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
else:
dloss = 0.0
logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
tl.float32
)
lse = tl.load(lse_ptr + row_idx)
probs = tl.exp(logits - lse)
probs += 2.0 * lse_square_scale * lse * probs
label_idx -= class_start_idx
if HAS_SMOOTHING:
smooth_positive = 1.0 - smoothing
smooth_negative = smoothing / total_classes
probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative
else:
probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
tl.store(dlogits_ptr + col_offsets, dloss * probs, mask=col_offsets < n_cols)
class CrossEntropyLoss(torch.autograd.Function):
@staticmethod
def forward(
ctx,
logits,
labels,
smoothing,
lse_square_scale=0.0,
ignored_index=-100,
inplace_backward=False,
process_group=None,
):
n_rows, n_cols = logits.shape
assert labels.shape == (n_rows,)
world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
total_classes = world_size * n_cols
rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
class_start_idx = rank * n_cols
if logits.stride(-1) != 1:
logits = logits.contiguous()
# Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py
MAX_BLOCK_SIZE = 64 * 1024
BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
num_warps = (
4
if BLOCK_SIZE < 2048
else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
)
# We may split the lse computation across multiple blocks, then do a reduction
# lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k)
# where having just one thread block processing more than 64k elements is slow.
split = world_size > 1 or n_cols > MAX_BLOCK_SIZE
n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE
loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
# Need this, otherwise Triton tries to launch from cuda:0 and we get
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
with torch.cuda.device(logits.device.index):
cross_entropy_fwd_kernel[(n_rows, n_splits)](
losses, # data ptrs
lse,
logits,
labels,
smoothing,
lse_square_scale,
ignored_index,
total_classes,
class_start_idx,
n_cols, # shapes
n_rows,
logits.stride(0), # strides
BLOCK_SIZE=BLOCK_SIZE, # constants
num_warps=num_warps,
SPLIT=split,
)
if split:
# If there's no smoothing, if labels are in the vocab of this partition, losses contains
# - predicted logit, and 0 otherwise.
# If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
# -0.9 * predicted logit - 0.1 * sum logit / total_classes.
# For labels not in the vocab of this partition, losses contains
# -0.1 * sum logit / total_classes.
if world_size > 1:
lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
handle_losses = torch.distributed.all_reduce(
losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
)
lse = torch.logsumexp(lse_allgather, dim=0)
handle_losses.wait()
else:
lse = torch.logsumexp(lse, dim=0)
losses = losses.sum(dim=0)
# After the allreduce, if there's no smoothing, the total losses are - predicted_logit,
# we just have to add the (global) lse.
# If there's smoothing=0.1, the total losses are
# -0.9 * predicted_logit - 0.1 * sum logit / total_classes.
# Again, we just have to add the (global) lse.
losses += lse
if lse_square_scale != 0.0:
losses += lse_square_scale * lse.square()
losses.masked_fill_(labels == ignored_index, 0.0)
ctx.save_for_backward(logits, lse, labels)
ctx.smoothing = smoothing
ctx.lse_square_scale = lse_square_scale
ctx.ignored_index = ignored_index
ctx.total_classes = total_classes
ctx.class_start_idx = class_start_idx
ctx.inplace_backward = inplace_backward
return losses
@staticmethod
def backward(ctx, grad_losses):
logits, lse, labels = ctx.saved_tensors
dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
n_rows, n_cols = logits.shape
BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa
# Need this, otherwise Triton tries to launch from cuda:0 and we get
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
with torch.cuda.device(logits.device.index):
cross_entropy_bwd_kernel[grid](
dlogits, # data ptrs
grad_losses,
logits,
lse,
labels,
ctx.smoothing,
ctx.lse_square_scale,
ctx.ignored_index,
ctx.total_classes,
ctx.class_start_idx,
n_cols, # shapes
logits.stride(0), # strides
dlogits.stride(0),
grad_losses.stride(0),
BLOCK_SIZE=BLOCK_SIZE, # constants
num_warps=num_warps,
)
return dlogits, None, None, None, None, None, None, None
def cross_entropy_loss(
logits: torch.Tensor,
labels: torch.Tensor,
label_smoothing: float = 0.0,
lse_square_scale: float = 0.0,
ignored_index=-100,
inplace_backward: bool = False,
process_group=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
logits: (batch, vocab_size)
labels: (batch,)
label_smoothing: float
lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
This is also referred to as "z-loss".
ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
This saves memory.
process_group: if not None, we're doing Tensor Parallel: each process is responsible for
one part of the vocab. The loss will be aggregated across processes.
Returns:
losses: (batch,), float
"""
return CrossEntropyLoss.apply(
logits,
labels,
label_smoothing,
lse_square_scale,
ignored_index,
inplace_backward,
process_group,
)

View File

@ -4,7 +4,7 @@ import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.losses.cross_entropy import CrossEntropyLossApex
from flash_attn.losses.cross_entropy import CrossEntropyLoss
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@ -12,12 +12,16 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize(
"dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
)
# @pytest.mark.parametrize('dtype', [torch.float16])
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("inplace_backward", [False, True])
# @pytest.mark.parametrize('inplace_backward', [False])
# @pytest.mark.parametrize("inplace_backward", [False])
@pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
# @pytest.mark.parametrize("lse_square_scale", [1e-2])
@pytest.mark.parametrize("smoothing", [0.0, 0.9])
@pytest.mark.parametrize("vocab_size", [50257])
def test_cross_entropy_loss_apex(vocab_size, smoothing, inplace_backward, dtype):
# @pytest.mark.parametrize("smoothing", [0.0])
@pytest.mark.parametrize("vocab_size", [50257, 128 * 1024]) # test vocab larger than 64k for split
# @pytest.mark.parametrize("vocab_size", [12])
def test_cross_entropy_loss(vocab_size, smoothing, lse_square_scale, inplace_backward, dtype):
device = "cuda"
rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
# set seed
@ -29,12 +33,20 @@ def test_cross_entropy_loss_apex(vocab_size, smoothing, inplace_backward, dtype)
)
x = x_pt.detach().clone().requires_grad_()
y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
y[torch.randperm(batch_size * seqlen)[:10]] = -100
if batch_size * seqlen > 10:
y[torch.randperm(batch_size * seqlen)[:10]] = -100
model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing)
model = CrossEntropyLossApex(label_smoothing=smoothing, inplace_backward=inplace_backward)
model = CrossEntropyLoss(
label_smoothing=smoothing,
lse_square_scale=lse_square_scale,
inplace_backward=inplace_backward,
)
out = model(x, y)
out_pt = model_pt(x_pt.float(), y)
assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
if lse_square_scale > 0.0:
lse_pt = torch.logsumexp(x_pt.float(), dim=-1)
out_pt += lse_square_scale * (lse_pt[y != -100] ** 2).mean()
assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
g = torch.randn_like(out)
out_pt.backward(g)

View File

@ -1,5 +1,5 @@
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/losses/test_cross_entropy_parallel.py
# torchrun --no_python --nproc_per_node=4 pytest -q -s tests/losses/test_cross_entropy_parallel.py
import math
@ -15,15 +15,20 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize(
"dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
)
# @pytest.mark.parametrize('dtype', [torch.float16])
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("inplace_backward", [False, True])
# @pytest.mark.parametrize('inplace_backward', [False])
# @pytest.mark.parametrize("inplace_backward", [False])
@pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
# @pytest.mark.parametrize("lse_square_scale", [1e-2])
@pytest.mark.parametrize("smoothing", [0.0, 0.9])
# @pytest.mark.parametrize('smoothing', [0.9])
@pytest.mark.parametrize("vocab_size", [50264])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
def test_cross_entropy_loss_parallel(vocab_size, world_size, smoothing, inplace_backward, dtype):
# @pytest.mark.parametrize("smoothing", [0.0])
@pytest.mark.parametrize("vocab_size", [50264, 128 * 1024]) # test vocab larger than 64k for split
# @pytest.mark.parametrize("vocab_size", [50264]) # test vocab larger than 64k for split
@pytest.mark.parametrize("world_size", [1, 2, 4])
# @pytest.mark.parametrize("world_size", [2])
def test_cross_entropy_loss_parallel(
vocab_size, world_size, smoothing, lse_square_scale, inplace_backward, dtype
):
assert vocab_size % world_size == 0
rtol, atol = (
(1e-5, 1e-6)
@ -56,11 +61,16 @@ def test_cross_entropy_loss_parallel(vocab_size, world_size, smoothing, inplace_
model = CrossEntropyLoss(
label_smoothing=smoothing,
reduction="none",
lse_square_scale=lse_square_scale,
inplace_backward=inplace_backward,
process_group=parallel_state.get_tensor_model_parallel_group(),
)
out = model(x, y)
out_pt = model_pt(x_pt.float(), y)
if lse_square_scale > 0.0:
lse_pt = torch.logsumexp(x_pt.float(), dim=-1)
out_pt += lse_square_scale * lse_pt.square()
out_pt.masked_fill_(y == -100, 0.0)
assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
g = torch.randn_like(out)