[CrossEntropy] Support precomputed LSE

Tri Dao 2024-09-08 09:24:18 -07:00
parent e371bea04f
commit c7f32a8409
4 changed files with 79 additions and 33 deletions

View File

@@ -1,4 +1,4 @@
# Copyright (c) 2023, Tri Dao.
# Copyright (c) 2024, Tri Dao.
import torch
import torch.nn as nn
@@ -44,7 +44,7 @@ class CrossEntropyLoss(nn.Module):
self.process_group = process_group
self.return_z_loss = return_z_loss
def forward(self, input, target):
def forward(self, input, target, precomputed_lse=None):
"""
Arguments:
input: (batch, vocab_size)
@@ -57,6 +57,7 @@ class CrossEntropyLoss(nn.Module):
loss, z_loss = cross_entropy_loss(
input,
target,
precomputed_lse=precomputed_lse,
label_smoothing=self.label_smoothing,
logit_scale=self.logit_scale,
lse_square_scale=self.lse_square_scale,
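For orientation, a minimal usage sketch of the new forward signature through the module API (the import matches the test file further below; shapes, device, and dtype are illustrative, and the supplied LSE is only reused when label smoothing is off and logit_scale is 1.0):

import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss

batch, vocab_size = 8, 50257  # illustrative sizes only
logits = torch.randn(batch, vocab_size, device="cuda", dtype=torch.float16, requires_grad=True)
labels = torch.randint(0, vocab_size, (batch,), device="cuda")

loss_fn = CrossEntropyLoss()  # defaults: label_smoothing=0.0, logit_scale=1.0

# If the log-sum-exp of the logits is already available (e.g. cached from another
# computation), pass it in so the forward kernel skips recomputing it.
with torch.no_grad():
    lse = torch.logsumexp(logits.float(), dim=-1)

loss = loss_fn(logits, labels, precomputed_lse=lse)
loss.backward()  # assuming the default reduction yields a scalar loss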

View File

@@ -39,25 +39,29 @@ def cross_entropy_fwd_kernel(
HAS_SMOOTHING: tl.constexpr,
# if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
SPLIT: tl.constexpr,
PRECOMPUTED_LSE: tl.constexpr, # If LSE is already computed (also no smoothing and logit_scale == 1.0)
):
row_idx = tl.program_id(0)
logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
sum_logits = 0.0 # For smoothing
# Statistics for online softmax
m_i = -float("inf")
l_i = 0.0
for col_offset in range(0, n_cols, BLOCK_SIZE):
cols = col_offset + tl.arange(0, BLOCK_SIZE)
logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to(
tl.float32
) * logit_scale
if HAS_SMOOTHING:
sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0))
m_i_new = tl.maximum(m_i, tl.max(logits))
l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new))
m_i = m_i_new
lse = tl.log(l_i) + m_i
tl.store(lse_ptr + row_idx, lse)
if not PRECOMPUTED_LSE:
# Statistics for online softmax
m_i = -float("inf")
l_i = 0.0
for col_offset in range(0, n_cols, BLOCK_SIZE):
cols = col_offset + tl.arange(0, BLOCK_SIZE)
logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to(
tl.float32
) * logit_scale
if HAS_SMOOTHING:
sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0))
m_i_new = tl.maximum(m_i, tl.max(logits))
l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new))
m_i = m_i_new
lse = tl.log(l_i) + m_i
tl.store(lse_ptr + row_idx, lse)
else:
lse = tl.load(lse_ptr + row_idx)
label_idx = tl.load(labels_ptr + row_idx)
if label_idx == ignore_index:
loss = 0.0
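The loop inside the `if not PRECOMPUTED_LSE:` branch above is an online (streaming) logsumexp: the row is scanned one BLOCK_SIZE chunk at a time while a running maximum m_i and a running sum of exponentials l_i are maintained, so lse = log(l_i) + m_i equals the logsumexp of the (scaled) row without ever materializing it in full. With PRECOMPUTED_LSE set, that scan is skipped and the value is simply read back from lse_ptr. A rough PyTorch sketch of the same recurrence, for intuition only (the function name and block size here are made up):

import torch

def online_logsumexp(row: torch.Tensor, block_size: int = 4) -> torch.Tensor:
    # Streaming logsumexp over one row, mirroring the kernel's m_i / l_i recurrence.
    m_i = torch.tensor(float("-inf"))
    l_i = torch.tensor(0.0)
    for start in range(0, row.numel(), block_size):
        block = row[start:start + block_size].float()
        m_new = torch.maximum(m_i, block.max())
        l_i = torch.exp(m_i - m_new) * l_i + torch.exp(block - m_new).sum()
        m_i = m_new
    return torch.log(l_i) + m_i

row = torch.randn(13)
assert torch.allclose(online_logsumexp(row), torch.logsumexp(row.float(), dim=-1), atol=1e-6)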
@@ -135,7 +139,7 @@ def cross_entropy_bwd_kernel(
if HAS_SMOOTHING:
smooth_positive = 1.0 - smoothing
smooth_negative = smoothing / total_classes
probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative
probs = tl.where(col_offsets == label_idx, probs - smooth_positive, probs) - smooth_negative
else:
probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
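The store above implements the standard smoothed cross-entropy gradient: with softmax probabilities p, smoothing eps, and total_classes C, the per-row gradient is dloss * logit_scale * (p - (1 - eps) * onehot(label) - eps / C); the change just names the (1 - eps) term smooth_positive. A small self-contained check of that closed form against autograd (sizes are illustrative and logit_scale is taken as 1.0):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab, eps = 7, 0.1  # illustrative sizes only
logits = torch.randn(1, vocab, requires_grad=True)
label = torch.tensor([3])

# Autograd gradient of the (sum-reduced) smoothed cross entropy.
F.cross_entropy(logits, label, label_smoothing=eps, reduction="sum").backward()

# Closed form used by the backward kernel: p - (1 - eps) * onehot - eps / C.
p = torch.softmax(logits.detach(), dim=-1)
manual = p - (1 - eps) * F.one_hot(label, vocab) - eps / vocab
assert torch.allclose(logits.grad, manual, atol=1e-5)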
@@ -148,6 +152,7 @@ class CrossEntropyLoss(torch.autograd.Function):
ctx,
logits,
labels,
precomputed_lse=None,
smoothing=0.0,
logit_scale=1.0,
lse_square_scale=0.0,
@@ -161,6 +166,7 @@ class CrossEntropyLoss(torch.autograd.Function):
total_classes = world_size * n_cols
rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
class_start_idx = rank * n_cols
use_precomputed_lse = precomputed_lse is not None and logit_scale == 1.0 and smoothing == 0.0
if logits.stride(-1) != 1:
logits = logits.contiguous()
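The use_precomputed_lse guard above only accepts a caller-supplied LSE in the plain case: with logit_scale != 1.0 the kernel needs logsumexp(logit_scale * logits), which cannot be derived from logsumexp(logits), and with label smoothing the forward pass additionally needs the per-row sum of logits, which the online loop accumulates but a bare LSE does not carry. A quick hedged illustration of the scale mismatch (shapes arbitrary):

import torch

logits = torch.randn(4, 11)
lse = torch.logsumexp(logits, dim=-1)  # what a caller would typically have cached

scale = 0.7
# The kernel would need logsumexp(scale * logits); it is not recoverable from lse alone,
# so in that case the forward falls back to recomputing the LSE inside the kernel.
assert not torch.allclose(torch.logsumexp(scale * logits, dim=-1), lse)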
@@ -172,7 +178,11 @@ class CrossEntropyLoss(torch.autograd.Function):
else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
)
losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
lse = torch.empty(n_rows, dtype=torch.float, device=logits.device)
if use_precomputed_lse:
assert precomputed_lse.shape == (n_rows,)
lse = precomputed_lse.contiguous()
else:
lse = torch.empty(n_rows, dtype=torch.float, device=logits.device)
z_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
# Need this, otherwise Triton tries to launch from cuda:0 and we get
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
@@ -192,8 +202,9 @@ class CrossEntropyLoss(torch.autograd.Function):
n_cols, # shapes
logits.stride(0), # strides
BLOCK_SIZE=BLOCK_SIZE, # constants
num_warps=num_warps,
SPLIT=world_size > 1,
PRECOMPUTED_LSE=use_precomputed_lse,
num_warps=num_warps,
)
if world_size > 1:
@@ -270,11 +281,13 @@ class CrossEntropyLoss(torch.autograd.Function):
BLOCK_SIZE=BLOCK_SIZE, # constants
num_warps=num_warps,
)
return dlogits, None, None, None, None, None, None, None, None
return dlogits, None, None, None, None, None, None, None, None, None
def cross_entropy_loss(
logits: torch.Tensor,
labels: torch.Tensor,
precomputed_lse: Optional[torch.Tensor] = None,
label_smoothing: float = 0.0,
logit_scale: float = 1.0,
lse_square_scale: float = 0.0,
@@ -302,6 +315,7 @@ def cross_entropy_loss(
return CrossEntropyLoss.apply(
logits,
labels,
precomputed_lse,
label_smoothing,
logit_scale,
lse_square_scale,
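Because torch.autograd.Function.apply takes positional arguments only, the new input is threaded through apply in the same position as in the signature above, and backward correspondingly returns one extra None (one gradient slot per forward input). A hedged sketch of the functional entry point; the import path is an assumption for illustration, since this diff view does not show file names:

import torch
# Assumed module path, for illustration only; the diff does not name the file.
from flash_attn.ops.triton.cross_entropy import cross_entropy_loss

logits = torch.randn(4, 50257, device="cuda", dtype=torch.bfloat16, requires_grad=True)
labels = torch.randint(0, 50257, (4,), device="cuda")

with torch.no_grad():
    lse = torch.logsumexp(logits.float(), dim=-1)

# Returns per-row losses and z-losses; precomputed_lse is honored only with the
# defaults logit_scale=1.0 and label_smoothing=0.0.
losses, z_losses = cross_entropy_loss(logits, labels, precomputed_lse=lse)
losses.mean().backward()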

View File

@@ -1,9 +1,8 @@
import math
# Copyright (c) 2024, Tri Dao.
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.losses.cross_entropy import CrossEntropyLoss
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@@ -13,6 +12,8 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
"dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
)
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("precompute_lse", [False, True])
# @pytest.mark.parametrize("precompute_lse", [False])
@pytest.mark.parametrize("inplace_backward", [False, True])
# @pytest.mark.parametrize("inplace_backward", [False])
@pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
@@ -22,11 +23,20 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
# @pytest.mark.parametrize("logit_scale", [1.0])
@pytest.mark.parametrize("smoothing", [0.0, 0.9])
# @pytest.mark.parametrize("smoothing", [0.0])
@pytest.mark.parametrize("vocab_size", [50257, 128 * 1024]) # test vocab larger than 64k for split
@pytest.mark.parametrize("vocab_size", [50257, 128256]) # test vocab larger than 64k for split
# @pytest.mark.parametrize("vocab_size", [12])
def test_cross_entropy_loss(
vocab_size, smoothing, logit_scale, lse_square_scale, return_z_loss, inplace_backward, dtype
vocab_size,
smoothing,
logit_scale,
lse_square_scale,
return_z_loss,
inplace_backward,
precompute_lse,
dtype,
):
if precompute_lse and (logit_scale != 1.0 or smoothing != 0.0):
pytest.skip("precompute_lse only works with logit_scale=1.0 and smoothing=0.0")
device = "cuda"
rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
# set seed
@@ -48,10 +58,15 @@ def test_cross_entropy_loss(
return_z_loss=return_z_loss,
inplace_backward=inplace_backward,
)
if return_z_loss:
out, out_z_loss = model(x, y)
if precompute_lse:
with torch.no_grad():
lse = torch.logsumexp(x.float(), dim=-1)
else:
out = model(x, y)
lse = None
if return_z_loss:
out, out_z_loss = model(x, y, precomputed_lse=lse)
else:
out = model(x, y, precomputed_lse=lse)
x_pt_scaled = (x_pt.float() * logit_scale) if logit_scale != 1.0 else x_pt.float()
out_pt = model_pt(x_pt_scaled, y)
if lse_square_scale > 0.0:
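The new precompute_lse parametrization runs every existing case a second time with the LSE handed in; informally, the property being exercised is that supplying the exact logsumexp changes nothing numerically. A standalone hedged sanity check of that property (imports as in the test, everything else illustrative):

import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss

torch.manual_seed(0)
x = torch.randn(32, 50257, device="cuda", dtype=torch.float32)
y = torch.randint(0, 50257, (32,), device="cuda")

loss_fn = CrossEntropyLoss()  # logit_scale=1.0, label_smoothing=0.0
ref = loss_fn(x, y)  # kernel computes the LSE itself
with torch.no_grad():
    lse = torch.logsumexp(x.float(), dim=-1)
out = loss_fn(x, y, precomputed_lse=lse)
assert torch.allclose(out, ref, rtol=1e-5, atol=1e-6)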

View File

@@ -15,11 +15,13 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
"dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
)
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("precompute_lse", [False, True])
# @pytest.mark.parametrize("precompute_lse", [False])
@pytest.mark.parametrize("inplace_backward", [False, True])
# @pytest.mark.parametrize("inplace_backward", [False])
@pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
# @pytest.mark.parametrize("lse_square_scale", [0.0])
@pytest.mark.parametrize("logit_scale", [0.7])
# @pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
@pytest.mark.parametrize("lse_square_scale", [1e-2])
@pytest.mark.parametrize("logit_scale", [1.0, 0.7])
# @pytest.mark.parametrize("logit_scale", [1.0])
@pytest.mark.parametrize("smoothing", [0.0, 0.9])
# @pytest.mark.parametrize("smoothing", [0.0])
@@ -28,8 +30,17 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
# @pytest.mark.parametrize("world_size", [1, 2])
@pytest.mark.parametrize("world_size", [2])
def test_cross_entropy_loss_parallel(
vocab_size, world_size, smoothing, logit_scale, lse_square_scale, inplace_backward, dtype
vocab_size,
world_size,
smoothing,
logit_scale,
lse_square_scale,
inplace_backward,
precompute_lse,
dtype,
):
if precompute_lse and (logit_scale != 1.0 or smoothing != 0.0):
pytest.skip("precompute_lse only works with logit_scale=1.0 and smoothing=0.0")
assert vocab_size % world_size == 0
rtol, atol = (
(1e-5, 2e-5)
@@ -67,7 +78,12 @@ def test_cross_entropy_loss_parallel(
inplace_backward=inplace_backward,
process_group=parallel_state.get_tensor_model_parallel_group(),
)
out = model(x, y)
if precompute_lse:
with torch.no_grad():
lse = torch.logsumexp(x.float(), dim=-1)
else:
lse = None
out = model(x, y, precomputed_lse=lse)
out_pt = model_pt(x_pt.float() * logit_scale, y)
if lse_square_scale > 0.0:
lse_pt = torch.logsumexp(x_pt.float() * logit_scale, dim=-1)