[CrossEntropy] Support precomputed LSE
parent: e371bea04f
commit: c7f32a8409
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, Tri Dao.
+# Copyright (c) 2024, Tri Dao.
 
 import torch
 import torch.nn as nn
@@ -44,7 +44,7 @@ class CrossEntropyLoss(nn.Module):
         self.process_group = process_group
         self.return_z_loss = return_z_loss
 
-    def forward(self, input, target):
+    def forward(self, input, target, precomputed_lse=None):
         """
         Arguments:
             input: (batch, vocab_size)
@@ -57,6 +57,7 @@ class CrossEntropyLoss(nn.Module):
         loss, z_loss = cross_entropy_loss(
             input,
             target,
+            precomputed_lse=precomputed_lse,
             label_smoothing=self.label_smoothing,
             logit_scale=self.logit_scale,
             lse_square_scale=self.lse_square_scale,
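For context, a minimal usage sketch of the new argument (the tensor shapes and vocabulary size are illustrative, not from the diff): the caller computes the log-sum-exp once and hands it to forward, which threads it through to cross_entropy_loss as shown above.

import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss

loss_fn = CrossEntropyLoss()  # defaults: no label smoothing, logit_scale=1.0
logits = torch.randn(4, 32000, device="cuda", dtype=torch.bfloat16)
labels = torch.randint(0, 32000, (4,), device="cuda")

# One LSE per row, computed in float32, the same way the tests below do it.
lse = torch.logsumexp(logits.float(), dim=-1)
loss = loss_fn(logits, labels, precomputed_lse=lse)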
@@ -39,25 +39,29 @@ def cross_entropy_fwd_kernel(
     HAS_SMOOTHING: tl.constexpr,
     # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
     SPLIT: tl.constexpr,
+    PRECOMPUTED_LSE: tl.constexpr,  # If LSE is already computed (also no smoothing and logit_scale == 1.0)
 ):
     row_idx = tl.program_id(0)
     logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
     sum_logits = 0.0  # For smoothing
-    # Statistics for online softmax
-    m_i = -float("inf")
-    l_i = 0.0
-    for col_offset in range(0, n_cols, BLOCK_SIZE):
-        cols = col_offset + tl.arange(0, BLOCK_SIZE)
-        logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to(
-            tl.float32
-        ) * logit_scale
-        if HAS_SMOOTHING:
-            sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0))
-        m_i_new = tl.maximum(m_i, tl.max(logits))
-        l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new))
-        m_i = m_i_new
-    lse = tl.log(l_i) + m_i
-    tl.store(lse_ptr + row_idx, lse)
+    if not PRECOMPUTED_LSE:
+        # Statistics for online softmax
+        m_i = -float("inf")
+        l_i = 0.0
+        for col_offset in range(0, n_cols, BLOCK_SIZE):
+            cols = col_offset + tl.arange(0, BLOCK_SIZE)
+            logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to(
+                tl.float32
+            ) * logit_scale
+            if HAS_SMOOTHING:
+                sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0))
+            m_i_new = tl.maximum(m_i, tl.max(logits))
+            l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new))
+            m_i = m_i_new
+        lse = tl.log(l_i) + m_i
+        tl.store(lse_ptr + row_idx, lse)
+    else:
+        lse = tl.load(lse_ptr + row_idx)
     label_idx = tl.load(labels_ptr + row_idx)
     if label_idx == ignore_index:
         loss = 0.0
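For reference, the loop that the new PRECOMPUTED_LSE branch skips is a streaming (online-softmax) log-sum-exp. A small PyTorch sketch of the same recurrence, with an illustrative block size:

import torch

def streaming_lse(row: torch.Tensor, block: int = 4096) -> torch.Tensor:
    # Keep a running max m_i and a running sum of exponentials l_i,
    # rescaling l_i whenever the max grows: the same per-block update
    # the Triton loop performs.
    m_i = torch.tensor(float("-inf"))
    l_i = torch.tensor(0.0)
    for start in range(0, row.numel(), block):
        chunk = row[start:start + block].float()
        m_new = torch.maximum(m_i, chunk.max())
        l_i = torch.exp(m_i - m_new) * l_i + torch.exp(chunk - m_new).sum()
        m_i = m_new
    return torch.log(l_i) + m_i  # equals torch.logsumexp(row.float(), dim=-1)

Once the LSE of a row is known, the unsmoothed loss is simply lse - logits[label] (with logit_scale == 1.0), which is why a caller-supplied LSE is enough for the rest of the kernel to proceed.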
@@ -135,7 +139,7 @@ def cross_entropy_bwd_kernel(
     if HAS_SMOOTHING:
         smooth_positive = 1.0 - smoothing
         smooth_negative = smoothing / total_classes
-        probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative
+        probs = tl.where(col_offsets == label_idx, probs - smooth_positive, probs) - smooth_negative
     else:
         probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
     tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
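The backward change is cosmetic (the literal 1 - smoothing becomes the named smooth_positive), but the formula it implements is worth spelling out: the gradient of label-smoothed cross entropy is the softmax probabilities minus the smoothed target distribution, and the kernel then scales this by dloss * logit_scale (chain rule through the incoming gradient and the logit scaling). A hedged single-row reference in plain PyTorch:

import torch

def smoothed_ce_grad(logits: torch.Tensor, label: int, smoothing: float) -> torch.Tensor:
    # d(loss)/d(logits) = softmax(logits)
    #                     - ((1 - smoothing) * one_hot(label) + smoothing / num_classes)
    probs = torch.softmax(logits.float(), dim=-1)
    grad = probs - smoothing / logits.numel()
    grad[label] -= 1.0 - smoothing
    return grad

In the tensor-parallel case the kernel divides the smoothing mass by total_classes, the global vocabulary size, rather than the local shard width.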
@@ -148,6 +152,7 @@ class CrossEntropyLoss(torch.autograd.Function):
         ctx,
         logits,
         labels,
+        precomputed_lse=None,
         smoothing=0.0,
         logit_scale=1.0,
         lse_square_scale=0.0,
@@ -161,6 +166,7 @@ class CrossEntropyLoss(torch.autograd.Function):
         total_classes = world_size * n_cols
         rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
         class_start_idx = rank * n_cols
+        use_precomputed_lse = precomputed_lse is not None and logit_scale == 1.0 and smoothing == 0.0
 
         if logits.stride(-1) != 1:
             logits = logits.contiguous()
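The guard mirrors the kernel comment above: a caller-supplied LSE is only trusted when logit_scale == 1.0 and smoothing == 0.0, because the kernel's own pass computes the LSE of the scaled logits and, with smoothing enabled, also needs sum_logits; outside those conditions the argument is silently ignored and the statistics are recomputed. An illustrative check of what a valid precomputed LSE has to satisfy (the helper name is hypothetical, not part of the diff):

import torch

def check_precomputed_lse(logits: torch.Tensor, precomputed_lse: torch.Tensor) -> None:
    # One float32 value per row, equal to logsumexp over the unscaled logits.
    ref = torch.logsumexp(logits.float(), dim=-1)
    assert precomputed_lse.shape == ref.shape
    assert torch.allclose(precomputed_lse.float(), ref, rtol=1e-4, atol=1e-4)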
@@ -172,7 +178,11 @@ class CrossEntropyLoss(torch.autograd.Function):
             else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
         )
         losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
-        lse = torch.empty(n_rows, dtype=torch.float, device=logits.device)
+        if use_precomputed_lse:
+            assert precomputed_lse.shape == (n_rows,)
+            lse = precomputed_lse.contiguous()
+        else:
+            lse = torch.empty(n_rows, dtype=torch.float, device=logits.device)
         z_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
         # Need this, otherwise Triton tries to launch from cuda:0 and we get
         # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
@@ -192,8 +202,9 @@ class CrossEntropyLoss(torch.autograd.Function):
                 n_cols,  # shapes
                 logits.stride(0),  # strides
                 BLOCK_SIZE=BLOCK_SIZE,  # constants
-                num_warps=num_warps,
                 SPLIT=world_size > 1,
+                PRECOMPUTED_LSE=use_precomputed_lse,
+                num_warps=num_warps,
             )
 
         if world_size > 1:
@@ -270,11 +281,13 @@ class CrossEntropyLoss(torch.autograd.Function):
             BLOCK_SIZE=BLOCK_SIZE,  # constants
             num_warps=num_warps,
         )
-        return dlogits, None, None, None, None, None, None, None, None
+        return dlogits, None, None, None, None, None, None, None, None, None
 
+
 def cross_entropy_loss(
     logits: torch.Tensor,
     labels: torch.Tensor,
+    precomputed_lse: Optional[torch.Tensor] = None,
     label_smoothing: float = 0.0,
     logit_scale: float = 1.0,
     lse_square_scale: float = 0.0,
@@ -302,6 +315,7 @@ def cross_entropy_loss(
     return CrossEntropyLoss.apply(
         logits,
         labels,
        precomputed_lse,
         label_smoothing,
         logit_scale,
         lse_square_scale,
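A minimal sketch of the functional entry point with the new argument; the import path is assumed from where this module usually lives in flash-attn and is not shown in the diff:

import torch
# Assumed path; the diff does not name the file.
from flash_attn.ops.triton.cross_entropy import cross_entropy_loss

logits = torch.randn(8, 50257, device="cuda", requires_grad=True)
labels = torch.randint(0, 50257, (8,), device="cuda")

lse = torch.logsumexp(logits.detach().float(), dim=-1)
loss, z_loss = cross_entropy_loss(logits, labels, precomputed_lse=lse)  # per-row losses
loss.mean().backward()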
@@ -1,9 +1,8 @@
-import math
+# Copyright (c) 2024, Tri Dao.
 
 import pytest
 import torch
 import torch.nn.functional as F
-from einops import rearrange
 from flash_attn.losses.cross_entropy import CrossEntropyLoss
 
 is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@@ -13,6 +12,8 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
     "dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
 )
 # @pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("precompute_lse", [False, True])
+# @pytest.mark.parametrize("precompute_lse", [False])
 @pytest.mark.parametrize("inplace_backward", [False, True])
 # @pytest.mark.parametrize("inplace_backward", [False])
 @pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
@@ -22,11 +23,20 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 # @pytest.mark.parametrize("logit_scale", [1.0])
 @pytest.mark.parametrize("smoothing", [0.0, 0.9])
 # @pytest.mark.parametrize("smoothing", [0.0])
-@pytest.mark.parametrize("vocab_size", [50257, 128 * 1024])  # test vocab larger than 64k for split
+@pytest.mark.parametrize("vocab_size", [50257, 128256])  # test vocab larger than 64k for split
 # @pytest.mark.parametrize("vocab_size", [12])
 def test_cross_entropy_loss(
-    vocab_size, smoothing, logit_scale, lse_square_scale, return_z_loss, inplace_backward, dtype
+    vocab_size,
+    smoothing,
+    logit_scale,
+    lse_square_scale,
+    return_z_loss,
+    inplace_backward,
+    precompute_lse,
+    dtype,
 ):
+    if precompute_lse and (logit_scale != 1.0 or smoothing != 0.0):
+        pytest.skip("precompute_lse only works with logit_scale=1.0 and smoothing=0.0")
     device = "cuda"
     rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
     # set seed
@@ -48,10 +58,15 @@ def test_cross_entropy_loss(
         return_z_loss=return_z_loss,
         inplace_backward=inplace_backward,
     )
-    if return_z_loss:
-        out, out_z_loss = model(x, y)
+    if precompute_lse:
+        with torch.no_grad():
+            lse = torch.logsumexp(x.float(), dim=-1)
     else:
-        out = model(x, y)
+        lse = None
+    if return_z_loss:
+        out, out_z_loss = model(x, y, precomputed_lse=lse)
+    else:
+        out = model(x, y, precomputed_lse=lse)
     x_pt_scaled = (x_pt.float() * logit_scale) if logit_scale != 1.0 else x_pt.float()
     out_pt = model_pt(x_pt_scaled, y)
     if lse_square_scale > 0.0:
@@ -15,11 +15,13 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
     "dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
 )
 # @pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("precompute_lse", [False, True])
+# @pytest.mark.parametrize("precompute_lse", [False])
 @pytest.mark.parametrize("inplace_backward", [False, True])
 # @pytest.mark.parametrize("inplace_backward", [False])
-# @pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
-@pytest.mark.parametrize("lse_square_scale", [1e-2])
-@pytest.mark.parametrize("logit_scale", [0.7])
+@pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
+# @pytest.mark.parametrize("lse_square_scale", [0.0])
+@pytest.mark.parametrize("logit_scale", [1.0, 0.7])
 # @pytest.mark.parametrize("logit_scale", [1.0])
 @pytest.mark.parametrize("smoothing", [0.0, 0.9])
 # @pytest.mark.parametrize("smoothing", [0.0])
@@ -28,8 +30,17 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 # @pytest.mark.parametrize("world_size", [1, 2])
 @pytest.mark.parametrize("world_size", [2])
 def test_cross_entropy_loss_parallel(
-    vocab_size, world_size, smoothing, logit_scale, lse_square_scale, inplace_backward, dtype
+    vocab_size,
+    world_size,
+    smoothing,
+    logit_scale,
+    lse_square_scale,
+    inplace_backward,
+    precompute_lse,
+    dtype,
 ):
+    if precompute_lse and (logit_scale != 1.0 or smoothing != 0.0):
+        pytest.skip("precompute_lse only works with logit_scale=1.0 and smoothing=0.0")
     assert vocab_size % world_size == 0
     rtol, atol = (
         (1e-5, 2e-5)
@@ -67,7 +78,12 @@ def test_cross_entropy_loss_parallel(
         inplace_backward=inplace_backward,
         process_group=parallel_state.get_tensor_model_parallel_group(),
     )
-    out = model(x, y)
+    if precompute_lse:
+        with torch.no_grad():
+            lse = torch.logsumexp(x.float(), dim=-1)
+    else:
+        lse = None
+    out = model(x, y, precomputed_lse=lse)
     out_pt = model_pt(x_pt.float() * logit_scale, y)
     if lse_square_scale > 0.0:
         lse_pt = torch.logsumexp(x_pt.float() * logit_scale, dim=-1)