[CE] Implement CrossEntropyLoss in Triton
parent 56b7fc6ee0
commit 5400fdc4ac
@@ -7,3 +7,8 @@ It has only been tested on A100s.
 ```sh
 cd csrc/xentropy && pip install .
 ```
+
+As of 2023-09-15, this extension is no longer used in the FlashAttention repo.
+We've instead switched to a Triton-based
+[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/cross_entropy.py).
+See the CrossEntropyLoss [module](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/losses/cross_entropy.py) for more details.
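As a reading aid (not part of this commit): a minimal usage sketch of the Triton-backed module referenced above. Tensor shapes and hyperparameter values here are illustrative only.

```python
import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss

# Logits in fp16/bf16 are fine; the kernel accumulates in fp32 internally.
logits = torch.randn(8, 50257, device="cuda", dtype=torch.float16, requires_grad=True)
labels = torch.randint(0, 50257, (8,), device="cuda")

loss_fn = CrossEntropyLoss(label_smoothing=0.1, lse_square_scale=1e-4, inplace_backward=True)
loss = loss_fn(logits, labels)  # scalar fp32 loss, mean over non-ignored labels
loss.backward()                 # with inplace_backward=True, the gradient reuses the logits buffer
```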
@@ -1,116 +1,9 @@
-# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
-# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
-# the losses we can get the global loss. There's no need to do it step by step
-# (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
-# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
+# Copyright (c) 2023, Tri Dao.
+
 import torch
 import torch.nn as nn
-import xentropy_cuda_lib
 
-# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
-# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
-# version of PyTorch. The following 2 lines are for backward compatibility with
-# older PyTorch.
-if "all_gather_into_tensor" not in dir(torch.distributed):
-    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
+from flash_attn.ops.triton.cross_entropy import cross_entropy_loss
 
 
-class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
-    @staticmethod
-    def forward(
-        ctx,
-        logits,
-        labels,
-        smoothing=0.0,
-        ignored_index=-100,
-        inplace_backward=False,
-        process_group=None,
-    ):
-        """
-        logits: (batch, vocab_size)
-        labels: (batch,)
-        If process_group is not None, we're doing Tensor Parallel: each process is responsible for
-        one part of the vocab. The loss needs to be aggregated across processes.
-        """
-        batch, vocab_size = logits.shape
-        assert labels.shape == (batch,)
-        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
-        ctx.total_classes = world_size * vocab_size
-
-        if world_size == 1:
-            losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing)
-            losses.masked_fill_(labels == ignored_index, 0)
-            labels_local = labels
-        else:
-            rank = torch.distributed.get_rank(process_group)
-            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
-
-            # Create a mask of valid vocab ids (1 means it needs to be masked).
-            labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
-            ignored_mask = labels == ignored_index
-            labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)
-
-            # For tensor parallel cross entropy with smoothing, we want to pass in the total number
-            # of classes so that smoothing can be applied correctly. If total_classes=-1, use the
-            # last dimension of the input tensor.
-            losses, lse_local = xentropy_cuda_lib.forward(
-                logits, labels_local, smoothing, world_size * vocab_size
-            )
-            assert lse_local.shape == (batch,)
-            assert losses.shape == (batch,)
-            losses.masked_fill_(ignored_mask, 0)
-            # For labels == ignored_index, the loss is always 0.
-            # If there's no smoothing, if labels are in the vocab of this partition, losses contains
-            # lse_local - predicted logit, and 0 otherwise.
-            # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
-            # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes)
-            # For labels not in the vocab of this partition, losses contains
-            # 0.1 * (lse_local - sum logit / total_classes).
-
-            lse_allgather = torch.empty(
-                world_size, batch, dtype=lse_local.dtype, device=lse_local.device
-            )
-            torch.distributed.all_gather_into_tensor(
-                lse_allgather, lse_local.contiguous(), group=process_group
-            )
-            handle_losses = torch.distributed.all_reduce(
-                losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
-            )
-            lse = torch.logsumexp(lse_allgather, dim=0)
-            # If there's no smoothing, the total losses are lse_local - predicted_logit,
-            # we just have to subtract the lse_local and add the lse (global).
-            # If there's smoothing=0.1, the total losses are
-            # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes)
-            # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes).
-            rank_per_sample = torch.div(labels, vocab_size, rounding_mode="floor")
-            lse_local = lse_allgather[
-                rank_per_sample, torch.arange(batch, device=lse_allgather.device)
-            ]
-
-            handle_losses.wait()
-            if smoothing == 0.0:
-                losses += lse - lse_local
-            else:
-                losses += (1 - smoothing) * (lse - lse_local) + smoothing * (
-                    lse - lse_allgather.sum(dim=0)
-                )
-            losses.masked_fill_(ignored_mask, 0)
-
-        ctx.save_for_backward(logits, lse, labels_local)
-        ctx.smoothing = smoothing
-        ctx.ignored_index = ignored_index
-        ctx.inplace_backward = inplace_backward
-        return losses
-
-    @staticmethod
-    def backward(ctx, grad_loss):
-        logits, lse, labels = ctx.saved_tensors
-        grad_loss = grad_loss.contiguous()
-        grad_loss.masked_fill_(labels == ctx.ignored_index, 0)
-        grad_logits = xentropy_cuda_lib.backward(
-            grad_loss, logits, lse, labels, ctx.smoothing, ctx.inplace_backward, ctx.total_classes
-        )
-        return grad_logits, None, None, None, None, None, None
-
-
 class CrossEntropyLoss(nn.Module):
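The removed comments above describe the key trick that carries over to the Triton version: each vocab shard computes its local loss and local log-sum-exp (LSE), and the global loss is recovered by combining local LSEs with a logsumexp. A small single-process sketch of that identity (not part of the commit; shapes and names are illustrative):

```python
import torch
import torch.nn.functional as F

# The global LSE equals logsumexp over the per-shard LSEs, so each rank only
# needs to exchange one scalar per row to turn its local loss into the global one.
torch.manual_seed(0)
batch, vocab, world_size = 4, 12, 3
logits = torch.randn(batch, vocab)
labels = torch.randint(0, vocab, (batch,))

shards = logits.chunk(world_size, dim=-1)                        # each rank's vocab slice
lse_local = torch.stack([s.logsumexp(dim=-1) for s in shards])   # (world_size, batch)
lse_global = torch.logsumexp(lse_local, dim=0)                   # == logits.logsumexp(-1)

# Cross entropy is lse - logit of the correct class (owned by exactly one shard).
loss = lse_global - logits[torch.arange(batch), labels]
ref = F.cross_entropy(logits, labels, reduction="none")
assert torch.allclose(loss, ref, atol=1e-5)
assert torch.allclose(lse_global, logits.logsumexp(dim=-1), atol=1e-5)
```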
@@ -119,30 +12,52 @@ class CrossEntropyLoss(nn.Module):
         ignore_index=-100,
         reduction="mean",
         label_smoothing=0.0,
+        lse_square_scale=0.0,
         inplace_backward=False,
         process_group=None,
     ):
+        """
+        Arguments:
+            ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
+            label_smoothing: float
+            lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
+                This is also referred to as "z-loss".
+            inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
+                This saves memory.
+            process_group: if not None, we're doing Tensor Parallel: each process is responsible for
+                one part of the vocab. The loss will be aggregated across processes.
+        """
         super().__init__()
-        if reduction not in ["mean", "none"]:
-            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
+        if reduction not in ["mean", "none", "sum"]:
+            raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'")
         self.ignore_index = ignore_index
         self.reduction = reduction
         self.label_smoothing = label_smoothing
+        self.lse_square_scale = lse_square_scale
         self.inplace_backward = inplace_backward
         self.process_group = process_group
 
     def forward(self, input, target):
-        assert input.is_cuda and target.is_cuda
-        # SoftmaxCrossEntropyLoss implicitly casts to float
-        loss = SoftmaxCrossEntropyLossFn.apply(
+        """
+        Arguments:
+            input: (batch, vocab_size)
+            target: (batch,)
+        Returns:
+            losses: (batch,) if reduction is 'none', else (1,), dtype float
+        """
+        assert input.is_cuda and target.is_cuda, "Only support CUDA tensors"
+        loss = cross_entropy_loss(
             input,
             target,
-            self.label_smoothing,
-            self.ignore_index,
-            self.inplace_backward,
-            self.process_group,
+            label_smoothing=self.label_smoothing,
+            lse_square_scale=self.lse_square_scale,
+            ignored_index=self.ignore_index,
+            inplace_backward=self.inplace_backward,
+            process_group=self.process_group,
         )
         if self.reduction == "mean":
            return loss.sum() / (target != self.ignore_index).sum()
+        elif self.reduction == "sum":
+            return loss.sum()
         else:
             return loss
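As a reading aid before the new Triton file (not part of the commit): the per-token loss the module computes can be written in plain PyTorch. With smoothing eps, K classes and z-loss weight lam, it is lse(x) - (1 - eps) * x[y] - (eps / K) * sum_j x[j] + lam * lse(x)^2. A reference sketch, with values and the helper name chosen only for illustration:

```python
import torch
import torch.nn.functional as F

def reference_loss(logits, labels, smoothing=0.0, lse_square_scale=0.0, ignored_index=-100):
    # Plain-PyTorch reference: cross entropy with label smoothing plus the
    # z-loss term lse_square_scale * logsumexp(logits)^2. K is the full vocab
    # size here (no tensor parallelism in this sketch).
    logits = logits.float()
    lse = torch.logsumexp(logits, dim=-1)
    valid = labels != ignored_index
    labels_ = labels.masked_fill(~valid, 0)  # placeholder index for ignored rows
    z_label = logits.gather(-1, labels_.unsqueeze(-1)).squeeze(-1)
    K = logits.shape[-1]
    loss = lse - (1 - smoothing) * z_label - smoothing * logits.sum(-1) / K
    loss = loss + lse_square_scale * lse.square()
    return loss.masked_fill(~valid, 0.0)

logits = torch.randn(4, 11)
labels = torch.tensor([1, 5, -100, 7])
ref = F.cross_entropy(logits, labels, label_smoothing=0.1, reduction="none", ignore_index=-100)
assert torch.allclose(reference_loss(logits, labels, smoothing=0.1), ref, atol=1e-5)
```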
flash_attn/ops/triton/cross_entropy.py (new file, 293 lines)
@@ -0,0 +1,293 @@
# Copyright (c) 2023, Tri Dao.

from typing import Tuple, Optional, Union

import torch

from einops import rearrange

import triton
import triton.language as tl

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 2 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base


@triton.heuristics(
    {
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_fwd_kernel(
    loss_ptr,  # data ptrs
    lse_ptr,
    logits_ptr,
    labels_ptr,
    smoothing,
    lse_square_scale,
    ignored_index,
    total_classes,
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
    n_cols,  # shapes
    n_rows,
    logits_row_stride,  # strides
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
    # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
    SPLIT: tl.constexpr,
):
    row_idx = tl.program_id(0)
    col_block_idx = tl.program_id(1)
    logits_ptr = logits_ptr + row_idx * logits_row_stride
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    label_idx = tl.load(labels_ptr + row_idx)
    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
        tl.float32
    )
    max_logits = tl.max(logits, 0)
    if HAS_SMOOTHING:
        sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
    lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
    tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
    if label_idx == ignored_index:
        loss = 0.0
    else:
        label_idx -= class_start_idx
        if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
            n_cols, (col_block_idx + 1) * BLOCK_SIZE
        ):
            logits_label = tl.load(logits_ptr + label_idx)
            if HAS_SMOOTHING:
                loss = (
                    (lse if not SPLIT else 0.0)
                    - smoothing * sum_logits / total_classes
                    - (1 - smoothing) * logits_label
                )
            else:
                loss = (lse if not SPLIT else 0.0) - logits_label
        else:
            # If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss
            if HAS_SMOOTHING:
                loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
            else:
                loss = 0.0
        if not SPLIT:
            loss += lse_square_scale * lse * lse
    tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)
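The `max_logits` / `lse` lines above are the usual numerically stable log-sum-exp: subtract the row max before exponentiating so the `exp` never overflows. A quick check of the identity the kernel relies on (illustrative values, not part of the commit):

```python
import torch

x = torch.tensor([1000.0, 1001.0, 999.0])   # naive exp(x) overflows in fp32
naive = torch.log(torch.exp(x).sum())        # -> inf
m = x.max()
stable = torch.log(torch.exp(x - m).sum()) + m
assert torch.isinf(naive)
assert torch.allclose(stable, torch.logsumexp(x, dim=0))  # matches torch.logsumexp
```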
@triton.heuristics(
    {
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_bwd_kernel(
    dlogits_ptr,  # data ptrs
    dloss_ptr,
    logits_ptr,
    lse_ptr,
    labels_ptr,
    smoothing,
    lse_square_scale,
    ignored_index,
    total_classes,
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
    n_cols,  # shapes
    logits_row_stride,  # strides
    dlogits_row_stride,
    dloss_row_stride,
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
):
    row_idx = tl.program_id(0)
    col_block_idx = tl.program_id(1)
    logits_ptr = logits_ptr + row_idx * logits_row_stride
    dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    label_idx = tl.load(labels_ptr + row_idx)
    if label_idx != ignored_index:
        dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
    else:
        dloss = 0.0
    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
        tl.float32
    )
    lse = tl.load(lse_ptr + row_idx)
    probs = tl.exp(logits - lse)
    probs += 2.0 * lse_square_scale * lse * probs
    label_idx -= class_start_idx
    if HAS_SMOOTHING:
        smooth_positive = 1.0 - smoothing
        smooth_negative = smoothing / total_classes
        probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative
    else:
        probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
    tl.store(dlogits_ptr + col_offsets, dloss * probs, mask=col_offsets < n_cols)
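The backward kernel implements the closed-form gradient: with softmax p, smoothing eps, K total classes and z-loss weight lam, d loss / d logits = dloss * [(1 + 2 * lam * lse) * p - (1 - eps) * onehot(y) - eps / K]. A small autograd cross-check of that formula (illustrative, single row, no tensor parallelism; not part of the commit):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
K, eps, lam = 7, 0.1, 1e-2
logits = torch.randn(K, dtype=torch.float64, requires_grad=True)
y = torch.tensor(3)

lse = torch.logsumexp(logits, dim=0)
loss = lse - (1 - eps) * logits[y] - eps * logits.sum() / K + lam * lse**2
loss.backward()

# Closed form used by the kernel: scale softmax by (1 + 2*lam*lse), then subtract
# (1 - eps) at the label and eps/K everywhere.
p = torch.softmax(logits.detach(), dim=0)
onehot = F.one_hot(y, K).double()
manual = (1 + 2 * lam * lse.detach()) * p - (1 - eps) * onehot - eps / K
assert torch.allclose(logits.grad, manual, atol=1e-10)
```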
class CrossEntropyLoss(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        logits,
        labels,
        smoothing,
        lse_square_scale=0.0,
        ignored_index=-100,
        inplace_backward=False,
        process_group=None,
    ):
        n_rows, n_cols = logits.shape
        assert labels.shape == (n_rows,)
        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
        total_classes = world_size * n_cols
        rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
        class_start_idx = rank * n_cols

        if logits.stride(-1) != 1:
            logits = logits.contiguous()
        # Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py
        MAX_BLOCK_SIZE = 64 * 1024
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
        num_warps = (
            4
            if BLOCK_SIZE < 2048
            else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
        )
        # We may split the lse computation across multiple blocks, then do a reduction
        # lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k)
        # where having just one thread block processing more than 64k elements is slow.
        split = world_size > 1 or n_cols > MAX_BLOCK_SIZE
        n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE
        loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
        losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
        lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
        # Need this, otherwise Triton tries to launch from cuda:0 and we get
        # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
        with torch.cuda.device(logits.device.index):
            cross_entropy_fwd_kernel[(n_rows, n_splits)](
                losses,  # data ptrs
                lse,
                logits,
                labels,
                smoothing,
                lse_square_scale,
                ignored_index,
                total_classes,
                class_start_idx,
                n_cols,  # shapes
                n_rows,
                logits.stride(0),  # strides
                BLOCK_SIZE=BLOCK_SIZE,  # constants
                num_warps=num_warps,
                SPLIT=split,
            )

        if split:
            # If there's no smoothing, if labels are in the vocab of this partition, losses contains
            # - predicted logit, and 0 otherwise.
            # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
            # -0.9 * predicted logit - 0.1 * sum logit / total_classes.
            # For labels not in the vocab of this partition, losses contains
            # -0.1 * sum logit / total_classes.
            if world_size > 1:
                lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
                torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
                handle_losses = torch.distributed.all_reduce(
                    losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
                )
                lse = torch.logsumexp(lse_allgather, dim=0)
                handle_losses.wait()
            else:
                lse = torch.logsumexp(lse, dim=0)
                losses = losses.sum(dim=0)
            # After the allreduce, if there's no smoothing, the total losses are - predicted_logit,
            # we just have to add the (global) lse.
            # If there's smoothing=0.1, the total losses are
            # -0.9 * predicted_logit - 0.1 * sum logit / total_classes.
            # Again, we just have to add the (global) lse.
            losses += lse
            if lse_square_scale != 0.0:
                losses += lse_square_scale * lse.square()
            losses.masked_fill_(labels == ignored_index, 0.0)

        ctx.save_for_backward(logits, lse, labels)
        ctx.smoothing = smoothing
        ctx.lse_square_scale = lse_square_scale
        ctx.ignored_index = ignored_index
        ctx.total_classes = total_classes
        ctx.class_start_idx = class_start_idx
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_losses):
        logits, lse, labels = ctx.saved_tensors
        dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
        n_rows, n_cols = logits.shape
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
        num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
        grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"]))  # noqa
        # Need this, otherwise Triton tries to launch from cuda:0 and we get
        # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
        with torch.cuda.device(logits.device.index):
            cross_entropy_bwd_kernel[grid](
                dlogits,  # data ptrs
                grad_losses,
                logits,
                lse,
                labels,
                ctx.smoothing,
                ctx.lse_square_scale,
                ctx.ignored_index,
                ctx.total_classes,
                ctx.class_start_idx,
                n_cols,  # shapes
                logits.stride(0),  # strides
                dlogits.stride(0),
                grad_losses.stride(0),
                BLOCK_SIZE=BLOCK_SIZE,  # constants
                num_warps=num_warps,
            )
        return dlogits, None, None, None, None, None, None, None


def cross_entropy_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    label_smoothing: float = 0.0,
    lse_square_scale: float = 0.0,
    ignored_index=-100,
    inplace_backward: bool = False,
    process_group=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        logits: (batch, vocab_size)
        labels: (batch,)
        label_smoothing: float
        lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
            This is also referred to as "z-loss".
        ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
        inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
            This saves memory.
        process_group: if not None, we're doing Tensor Parallel: each process is responsible for
            one part of the vocab. The loss will be aggregated across processes.
    Returns:
        losses: (batch,), float
    """
    return CrossEntropyLoss.apply(
        logits,
        labels,
        label_smoothing,
        lse_square_scale,
        ignored_index,
        inplace_backward,
        process_group,
    )
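A minimal sketch of calling the functional interface above directly (not part of the commit; values are illustrative, and the test below does the same comparison more thoroughly):

```python
import torch
import torch.nn.functional as F
from flash_attn.ops.triton.cross_entropy import cross_entropy_loss

logits = torch.randn(4, 50257, device="cuda", requires_grad=True)
labels = torch.randint(0, 50257, (4,), device="cuda")

losses = cross_entropy_loss(logits, labels, label_smoothing=0.1)  # (batch,) fp32, unreduced
ref = F.cross_entropy(logits.float(), labels, label_smoothing=0.1, reduction="none")
print(torch.allclose(losses, ref, rtol=1e-4, atol=1e-5))
losses.mean().backward()
```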
@@ -4,7 +4,7 @@ import pytest
 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from flash_attn.losses.cross_entropy import CrossEntropyLossApex
+from flash_attn.losses.cross_entropy import CrossEntropyLoss
 
 is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 
@@ -12,12 +12,16 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 @pytest.mark.parametrize(
     "dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
 )
-# @pytest.mark.parametrize('dtype', [torch.float16])
+# @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.parametrize("inplace_backward", [False, True])
-# @pytest.mark.parametrize('inplace_backward', [False])
+# @pytest.mark.parametrize("inplace_backward", [False])
+@pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
+# @pytest.mark.parametrize("lse_square_scale", [1e-2])
 @pytest.mark.parametrize("smoothing", [0.0, 0.9])
-@pytest.mark.parametrize("vocab_size", [50257])
-def test_cross_entropy_loss_apex(vocab_size, smoothing, inplace_backward, dtype):
+# @pytest.mark.parametrize("smoothing", [0.0])
+@pytest.mark.parametrize("vocab_size", [50257, 128 * 1024])  # test vocab larger than 64k for split
+# @pytest.mark.parametrize("vocab_size", [12])
+def test_cross_entropy_loss(vocab_size, smoothing, lse_square_scale, inplace_backward, dtype):
     device = "cuda"
     rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
     # set seed
@@ -29,12 +33,20 @@ def test_cross_entropy_loss_apex(vocab_size, smoothing, inplace_backward, dtype)
     )
     x = x_pt.detach().clone().requires_grad_()
     y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
-    y[torch.randperm(batch_size * seqlen)[:10]] = -100
+    if batch_size * seqlen > 10:
+        y[torch.randperm(batch_size * seqlen)[:10]] = -100
     model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing)
-    model = CrossEntropyLossApex(label_smoothing=smoothing, inplace_backward=inplace_backward)
+    model = CrossEntropyLoss(
+        label_smoothing=smoothing,
+        lse_square_scale=lse_square_scale,
+        inplace_backward=inplace_backward,
+    )
     out = model(x, y)
     out_pt = model_pt(x_pt.float(), y)
-    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
+    if lse_square_scale > 0.0:
+        lse_pt = torch.logsumexp(x_pt.float(), dim=-1)
+        out_pt += lse_square_scale * (lse_pt[y != -100] ** 2).mean()
+    assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
 
     g = torch.randn_like(out)
     out_pt.backward(g)
@@ -1,5 +1,5 @@
 # Run test with:
-# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/losses/test_cross_entropy_parallel.py
+# torchrun --no_python --nproc_per_node=4 pytest -q -s tests/losses/test_cross_entropy_parallel.py
 
 import math
 
@@ -15,15 +15,20 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 @pytest.mark.parametrize(
     "dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
 )
-# @pytest.mark.parametrize('dtype', [torch.float16])
+# @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.parametrize("inplace_backward", [False, True])
-# @pytest.mark.parametrize('inplace_backward', [False])
+# @pytest.mark.parametrize("inplace_backward", [False])
+@pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
+# @pytest.mark.parametrize("lse_square_scale", [1e-2])
 @pytest.mark.parametrize("smoothing", [0.0, 0.9])
-# @pytest.mark.parametrize('smoothing', [0.9])
-@pytest.mark.parametrize("vocab_size", [50264])
-@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
-# @pytest.mark.parametrize('world_size', [2])
-def test_cross_entropy_loss_parallel(vocab_size, world_size, smoothing, inplace_backward, dtype):
+# @pytest.mark.parametrize("smoothing", [0.0])
+@pytest.mark.parametrize("vocab_size", [50264, 128 * 1024])  # test vocab larger than 64k for split
+# @pytest.mark.parametrize("vocab_size", [50264])  # test vocab larger than 64k for split
+@pytest.mark.parametrize("world_size", [1, 2, 4])
+# @pytest.mark.parametrize("world_size", [2])
+def test_cross_entropy_loss_parallel(
+    vocab_size, world_size, smoothing, lse_square_scale, inplace_backward, dtype
+):
     assert vocab_size % world_size == 0
     rtol, atol = (
         (1e-5, 1e-6)
@@ -56,11 +61,16 @@ def test_cross_entropy_loss_parallel(vocab_size, world_size, smoothing, inplace_
     model = CrossEntropyLoss(
         label_smoothing=smoothing,
         reduction="none",
+        lse_square_scale=lse_square_scale,
         inplace_backward=inplace_backward,
         process_group=parallel_state.get_tensor_model_parallel_group(),
     )
     out = model(x, y)
     out_pt = model_pt(x_pt.float(), y)
+    if lse_square_scale > 0.0:
+        lse_pt = torch.logsumexp(x_pt.float(), dim=-1)
+        out_pt += lse_square_scale * lse_pt.square()
+        out_pt.masked_fill_(y == -100, 0.0)
     assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
 
     g = torch.randn_like(out)
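The parallel test drives the loss through apex's tensor-parallel process group. For readers without apex, a hedged sketch of the same tensor-parallel call using plain torch.distributed (not part of the commit; the script name, sizes, and group choice are illustrative; launch with e.g. `torchrun --nproc_per_node=2 tp_xent_demo.py`):

```python
# Illustrative only: each rank owns a contiguous slice of the vocab, labels are
# global indices, and the returned per-token losses are identical on every rank.
import torch
import torch.distributed as dist
from flash_attn.ops.triton.cross_entropy import cross_entropy_loss

dist.init_process_group(backend="nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

batch, vocab_per_rank = 8, 1024
torch.manual_seed(0)  # same data on every rank, then slice the vocab dimension
full_logits = torch.randn(batch, vocab_per_rank * world_size, device="cuda")
labels = torch.randint(0, vocab_per_rank * world_size, (batch,), device="cuda")
logits_local = (
    full_logits[:, rank * vocab_per_rank : (rank + 1) * vocab_per_rank]
    .contiguous()
    .requires_grad_()
)

losses = cross_entropy_loss(logits_local, labels, process_group=dist.group.WORLD)  # (batch,)
losses.sum().backward()  # gradient only for this rank's slice of the logits
```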