diff --git a/flash_attn/ops/triton/cross_entropy.py b/flash_attn/ops/triton/cross_entropy.py index 1782338..e7bb686 100644 --- a/flash_attn/ops/triton/cross_entropy.py +++ b/flash_attn/ops/triton/cross_entropy.py @@ -34,7 +34,6 @@ def cross_entropy_fwd_kernel( total_classes, class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes n_cols, # shapes - n_rows, logits_row_stride, # strides BLOCK_SIZE: tl.constexpr, HAS_SMOOTHING: tl.constexpr, @@ -42,26 +41,30 @@ def cross_entropy_fwd_kernel( SPLIT: tl.constexpr, ): row_idx = tl.program_id(0) - col_block_idx = tl.program_id(1) logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64) - col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + sum_logits = 0.0 # For smoothing + # Statistics for online softmax + m_i = -float("inf") + l_i = 0.0 + for col_offset in range(0, n_cols, BLOCK_SIZE): + cols = col_offset + tl.arange(0, BLOCK_SIZE) + logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to( + tl.float32 + ) * logit_scale + if HAS_SMOOTHING: + sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0)) + m_i_new = tl.maximum(m_i, tl.max(logits)) + l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new)) + m_i = m_i_new + lse = tl.log(l_i) + m_i + tl.store(lse_ptr + row_idx, lse) label_idx = tl.load(labels_ptr + row_idx) - logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to( - tl.float32 - ) * logit_scale - max_logits = tl.max(logits, 0) - if HAS_SMOOTHING: - sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0) - lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits - tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse) if label_idx == ignore_index: loss = 0.0 z_loss = 0.0 else: label_idx -= class_start_idx - if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min( - n_cols, (col_block_idx + 1) * BLOCK_SIZE - ): + if label_idx >= 0 and label_idx < n_cols: logits_label = tl.load(logits_ptr + label_idx) * logit_scale if HAS_SMOOTHING: loss = ( @@ -82,9 +85,9 @@ def cross_entropy_fwd_kernel( loss += z_loss else: z_loss = 0.0 - tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss) + tl.store(loss_ptr + row_idx, loss) if not SPLIT: - tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss) + tl.store(z_loss_ptr + row_idx, z_loss) @triton.heuristics( @@ -161,27 +164,20 @@ class CrossEntropyLoss(torch.autograd.Function): if logits.stride(-1) != 1: logits = logits.contiguous() - # Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py - MAX_BLOCK_SIZE = 64 * 1024 + MAX_BLOCK_SIZE = 16 * 1024 BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE) num_warps = ( 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32)) ) - # We may split the lse computation across multiple blocks, then do a reduction - # lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k) - # where having just one thread block processing more than 64k elements is slow. - split = world_size > 1 or n_cols > MAX_BLOCK_SIZE - n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE - loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,) - losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device) - lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device) - z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device) + losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) + lse = torch.empty(n_rows, dtype=torch.float, device=logits.device) + z_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) # Need this, otherwise Triton tries to launch from cuda:0 and we get # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) with torch.cuda.device(logits.device.index): - cross_entropy_fwd_kernel[(n_rows, n_splits)]( + cross_entropy_fwd_kernel[(n_rows,)]( losses, # data ptrs lse, z_losses, @@ -194,23 +190,19 @@ class CrossEntropyLoss(torch.autograd.Function): total_classes, class_start_idx, n_cols, # shapes - n_rows, logits.stride(0), # strides BLOCK_SIZE=BLOCK_SIZE, # constants num_warps=num_warps, - SPLIT=split, + SPLIT=world_size > 1, ) - if split: + if world_size > 1: # If there's no smoothing, if labels are in the vocab of this partition, losses contains # - predicted logit, and 0 otherwise. # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains # -0.9 * predicted logit - 0.1 * sum logit / total_classes. # For labels not in the vocab of this partition, losses contains # -0.1 * sum logit / total_classes. - if n_splits > 1: - lse = torch.logsumexp(lse, dim=0) - losses = losses.sum(dim=0) if world_size > 1: lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device) torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group) @@ -243,6 +235,7 @@ class CrossEntropyLoss(torch.autograd.Function): ctx.class_start_idx = class_start_idx ctx.inplace_backward = inplace_backward + return losses, z_losses @staticmethod