[CrossEntropy] Use online softmax to simplify implementation
This commit is contained in:
parent
32792d37ec
commit
d79f9b41a8
@ -34,7 +34,6 @@ def cross_entropy_fwd_kernel(
|
|||||||
total_classes,
|
total_classes,
|
||||||
class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
|
class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
|
||||||
n_cols, # shapes
|
n_cols, # shapes
|
||||||
n_rows,
|
|
||||||
logits_row_stride, # strides
|
logits_row_stride, # strides
|
||||||
BLOCK_SIZE: tl.constexpr,
|
BLOCK_SIZE: tl.constexpr,
|
||||||
HAS_SMOOTHING: tl.constexpr,
|
HAS_SMOOTHING: tl.constexpr,
|
||||||
@ -42,26 +41,30 @@ def cross_entropy_fwd_kernel(
|
|||||||
SPLIT: tl.constexpr,
|
SPLIT: tl.constexpr,
|
||||||
):
|
):
|
||||||
row_idx = tl.program_id(0)
|
row_idx = tl.program_id(0)
|
||||||
col_block_idx = tl.program_id(1)
|
|
||||||
logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
|
logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
|
||||||
col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
sum_logits = 0.0 # For smoothing
|
||||||
|
# Statistics for online softmax
|
||||||
|
m_i = -float("inf")
|
||||||
|
l_i = 0.0
|
||||||
|
for col_offset in range(0, n_cols, BLOCK_SIZE):
|
||||||
|
cols = col_offset + tl.arange(0, BLOCK_SIZE)
|
||||||
|
logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to(
|
||||||
|
tl.float32
|
||||||
|
) * logit_scale
|
||||||
|
if HAS_SMOOTHING:
|
||||||
|
sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0))
|
||||||
|
m_i_new = tl.maximum(m_i, tl.max(logits))
|
||||||
|
l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new))
|
||||||
|
m_i = m_i_new
|
||||||
|
lse = tl.log(l_i) + m_i
|
||||||
|
tl.store(lse_ptr + row_idx, lse)
|
||||||
label_idx = tl.load(labels_ptr + row_idx)
|
label_idx = tl.load(labels_ptr + row_idx)
|
||||||
logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
|
|
||||||
tl.float32
|
|
||||||
) * logit_scale
|
|
||||||
max_logits = tl.max(logits, 0)
|
|
||||||
if HAS_SMOOTHING:
|
|
||||||
sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
|
|
||||||
lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
|
|
||||||
tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
|
|
||||||
if label_idx == ignore_index:
|
if label_idx == ignore_index:
|
||||||
loss = 0.0
|
loss = 0.0
|
||||||
z_loss = 0.0
|
z_loss = 0.0
|
||||||
else:
|
else:
|
||||||
label_idx -= class_start_idx
|
label_idx -= class_start_idx
|
||||||
if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
|
if label_idx >= 0 and label_idx < n_cols:
|
||||||
n_cols, (col_block_idx + 1) * BLOCK_SIZE
|
|
||||||
):
|
|
||||||
logits_label = tl.load(logits_ptr + label_idx) * logit_scale
|
logits_label = tl.load(logits_ptr + label_idx) * logit_scale
|
||||||
if HAS_SMOOTHING:
|
if HAS_SMOOTHING:
|
||||||
loss = (
|
loss = (
|
||||||
@ -82,9 +85,9 @@ def cross_entropy_fwd_kernel(
|
|||||||
loss += z_loss
|
loss += z_loss
|
||||||
else:
|
else:
|
||||||
z_loss = 0.0
|
z_loss = 0.0
|
||||||
tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)
|
tl.store(loss_ptr + row_idx, loss)
|
||||||
if not SPLIT:
|
if not SPLIT:
|
||||||
tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss)
|
tl.store(z_loss_ptr + row_idx, z_loss)
|
||||||
|
|
||||||
|
|
||||||
@triton.heuristics(
|
@triton.heuristics(
|
||||||
@ -161,27 +164,20 @@ class CrossEntropyLoss(torch.autograd.Function):
|
|||||||
|
|
||||||
if logits.stride(-1) != 1:
|
if logits.stride(-1) != 1:
|
||||||
logits = logits.contiguous()
|
logits = logits.contiguous()
|
||||||
# Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py
|
MAX_BLOCK_SIZE = 16 * 1024
|
||||||
MAX_BLOCK_SIZE = 64 * 1024
|
|
||||||
BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
|
BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
|
||||||
num_warps = (
|
num_warps = (
|
||||||
4
|
4
|
||||||
if BLOCK_SIZE < 2048
|
if BLOCK_SIZE < 2048
|
||||||
else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
|
else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
|
||||||
)
|
)
|
||||||
# We may split the lse computation across multiple blocks, then do a reduction
|
losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
|
||||||
# lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k)
|
lse = torch.empty(n_rows, dtype=torch.float, device=logits.device)
|
||||||
# where having just one thread block processing more than 64k elements is slow.
|
z_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
|
||||||
split = world_size > 1 or n_cols > MAX_BLOCK_SIZE
|
|
||||||
n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE
|
|
||||||
loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
|
|
||||||
losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
|
|
||||||
lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
|
|
||||||
z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
|
|
||||||
# Need this, otherwise Triton tries to launch from cuda:0 and we get
|
# Need this, otherwise Triton tries to launch from cuda:0 and we get
|
||||||
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
|
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
|
||||||
with torch.cuda.device(logits.device.index):
|
with torch.cuda.device(logits.device.index):
|
||||||
cross_entropy_fwd_kernel[(n_rows, n_splits)](
|
cross_entropy_fwd_kernel[(n_rows,)](
|
||||||
losses, # data ptrs
|
losses, # data ptrs
|
||||||
lse,
|
lse,
|
||||||
z_losses,
|
z_losses,
|
||||||
@ -194,23 +190,19 @@ class CrossEntropyLoss(torch.autograd.Function):
|
|||||||
total_classes,
|
total_classes,
|
||||||
class_start_idx,
|
class_start_idx,
|
||||||
n_cols, # shapes
|
n_cols, # shapes
|
||||||
n_rows,
|
|
||||||
logits.stride(0), # strides
|
logits.stride(0), # strides
|
||||||
BLOCK_SIZE=BLOCK_SIZE, # constants
|
BLOCK_SIZE=BLOCK_SIZE, # constants
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
SPLIT=split,
|
SPLIT=world_size > 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
if split:
|
if world_size > 1:
|
||||||
# If there's no smoothing, if labels are in the vocab of this partition, losses contains
|
# If there's no smoothing, if labels are in the vocab of this partition, losses contains
|
||||||
# - predicted logit, and 0 otherwise.
|
# - predicted logit, and 0 otherwise.
|
||||||
# If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
|
# If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
|
||||||
# -0.9 * predicted logit - 0.1 * sum logit / total_classes.
|
# -0.9 * predicted logit - 0.1 * sum logit / total_classes.
|
||||||
# For labels not in the vocab of this partition, losses contains
|
# For labels not in the vocab of this partition, losses contains
|
||||||
# -0.1 * sum logit / total_classes.
|
# -0.1 * sum logit / total_classes.
|
||||||
if n_splits > 1:
|
|
||||||
lse = torch.logsumexp(lse, dim=0)
|
|
||||||
losses = losses.sum(dim=0)
|
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
|
lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
|
||||||
torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
|
torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
|
||||||
@ -243,6 +235,7 @@ class CrossEntropyLoss(torch.autograd.Function):
|
|||||||
ctx.class_start_idx = class_start_idx
|
ctx.class_start_idx = class_start_idx
|
||||||
ctx.inplace_backward = inplace_backward
|
ctx.inplace_backward = inplace_backward
|
||||||
|
|
||||||
|
|
||||||
return losses, z_losses
|
return losses, z_losses
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user