[CrossEntropy] Implement logit_scale option

Tri Dao 2023-12-16 18:39:37 -08:00
parent 9356a1c038
commit 08124c8f9c
4 changed files with 32 additions and 11 deletions

View File

@@ -12,6 +12,7 @@ class CrossEntropyLoss(nn.Module):
ignore_index=-100,
reduction="mean",
label_smoothing=0.0,
+ logit_scale=1.0,
lse_square_scale=0.0,
inplace_backward=False,
process_group=None,
@@ -33,6 +34,7 @@ class CrossEntropyLoss(nn.Module):
self.ignore_index = ignore_index
self.reduction = reduction
self.label_smoothing = label_smoothing
+ self.logit_scale = logit_scale
self.lse_square_scale = lse_square_scale
self.inplace_backward = inplace_backward
self.process_group = process_group
@@ -50,6 +52,7 @@ class CrossEntropyLoss(nn.Module):
input,
target,
label_smoothing=self.label_smoothing,
+ logit_scale=self.logit_scale,
lse_square_scale=self.lse_square_scale,
ignored_index=self.ignore_index,
inplace_backward=self.inplace_backward,
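Note (not part of the diff): the new constructor argument simply rescales the logits inside the fused kernel, so the module behaves like nn.CrossEntropyLoss applied to logit_scale * logits. A minimal usage sketch, assuming the module is importable from flash_attn.losses.cross_entropy (the file path is not shown in this view):

import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss  # assumed import path

logits = torch.randn(8, 50257, device="cuda", dtype=torch.bfloat16, requires_grad=True)
labels = torch.randint(0, 50257, (8,), device="cuda")

# logit_scale multiplies the logits before the loss is computed, so this should match
# torch.nn.CrossEntropyLoss()(logits.float() * 0.7, labels) up to numerical tolerance.
loss = CrossEntropyLoss(logit_scale=0.7)(logits, labels)
loss.backward()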

View File

@@ -29,6 +29,7 @@ def cross_entropy_fwd_kernel(
logits_ptr,
labels_ptr,
smoothing,
+ logit_scale,
lse_square_scale,
ignored_index,
total_classes,
@@ -48,7 +49,7 @@ def cross_entropy_fwd_kernel(
label_idx = tl.load(labels_ptr + row_idx)
logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
tl.float32
- )
+ ) * logit_scale
max_logits = tl.max(logits, 0)
if HAS_SMOOTHING:
sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
@@ -61,7 +62,7 @@ def cross_entropy_fwd_kernel(
if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
n_cols, (col_block_idx + 1) * BLOCK_SIZE
):
- logits_label = tl.load(logits_ptr + label_idx)
+ logits_label = tl.load(logits_ptr + label_idx) * logit_scale
if HAS_SMOOTHING:
loss = (
(lse if not SPLIT else 0.0)
@@ -94,6 +95,7 @@ def cross_entropy_bwd_kernel(
lse_ptr,
labels_ptr,
smoothing,
+ logit_scale,
lse_square_scale,
ignored_index,
total_classes,
@@ -117,7 +119,7 @@ def cross_entropy_bwd_kernel(
dloss = 0.0
logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
tl.float32
- )
+ ) * logit_scale
lse = tl.load(lse_ptr + row_idx)
probs = tl.exp(logits - lse)
probs += 2.0 * lse_square_scale * lse * probs
@@ -128,16 +130,18 @@ def cross_entropy_bwd_kernel(
probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative
else:
probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
- tl.store(dlogits_ptr + col_offsets, dloss * probs, mask=col_offsets < n_cols)
+ tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
class CrossEntropyLoss(torch.autograd.Function):
@staticmethod
def forward(
ctx,
logits,
labels,
- smoothing,
+ smoothing=0.0,
+ logit_scale=1.0,
lse_square_scale=0.0,
ignored_index=-100,
inplace_backward=False,
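Note (not part of the diff): the (dloss * logit_scale) * probs store in the backward kernel above is the chain rule at work. The loss is evaluated on logit_scale * logits, so the gradient with respect to the original logits picks up an extra factor of logit_scale. A small plain-PyTorch sanity check of that identity, as a sketch:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
scale = 0.7
x = torch.randn(4, 10, requires_grad=True)
y = torch.randint(0, 10, (4,))

# Autograd gradient of the scaled cross-entropy (default "mean" reduction).
F.cross_entropy(scale * x, y).backward()

# Manual gradient: scale * (softmax(scale * x) - one_hot(y)) / batch_size.
probs = torch.softmax(scale * x.detach(), dim=-1)
probs[torch.arange(4), y] -= 1.0
manual = scale * probs / 4

assert torch.allclose(x.grad, manual, atol=1e-6)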
@@ -177,6 +181,7 @@ class CrossEntropyLoss(torch.autograd.Function):
logits,
labels,
smoothing,
+ logit_scale,
lse_square_scale,
ignored_index,
total_classes,
@@ -219,6 +224,7 @@ class CrossEntropyLoss(torch.autograd.Function):
ctx.save_for_backward(logits, lse, labels)
ctx.smoothing = smoothing
+ ctx.logit_scale = logit_scale
ctx.lse_square_scale = lse_square_scale
ctx.ignored_index = ignored_index
ctx.total_classes = total_classes
@@ -244,6 +250,7 @@ class CrossEntropyLoss(torch.autograd.Function):
lse,
labels,
ctx.smoothing,
+ ctx.logit_scale,
ctx.lse_square_scale,
ctx.ignored_index,
ctx.total_classes,
@@ -262,6 +269,7 @@ def cross_entropy_loss(
logits: torch.Tensor,
labels: torch.Tensor,
label_smoothing: float = 0.0,
+ logit_scale: float = 1.0,
lse_square_scale: float = 0.0,
ignored_index=-100,
inplace_backward: bool = False,
@@ -272,6 +280,7 @@ def cross_entropy_loss(
logits: (batch, vocab_size)
labels: (batch,)
label_smoothing: float
+ logit_scale: float. Multiply logits by this scale before calculating the loss.
lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
This is also referred to as "z-loss".
ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
@@ -286,6 +295,7 @@ def cross_entropy_loss(
logits,
labels,
label_smoothing,
+ logit_scale,
lse_square_scale,
ignored_index,
inplace_backward,
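Note (not part of the diff): a rough eager-mode reference for the semantics documented in the docstring above, written from the docstring and the tests below rather than from the kernel, with per-sample (reduction="none") output assumed:

import torch
import torch.nn.functional as F

def reference_cross_entropy(logits, labels, label_smoothing=0.0, logit_scale=1.0,
                            lse_square_scale=0.0, ignored_index=-100):
    # logit_scale is applied to the logits before the loss is computed.
    scaled = logits.float() * logit_scale
    loss = F.cross_entropy(scaled, labels, label_smoothing=label_smoothing,
                           ignore_index=ignored_index, reduction="none")
    if lse_square_scale > 0.0:
        # "z-loss": add lse_square_scale * logsumexp(scaled logits)^2 per row.
        loss = loss + lse_square_scale * torch.logsumexp(scaled, dim=-1).square()
    # Rows with the ignored label contribute 0, matching the documented behavior.
    return loss.masked_fill(labels == ignored_index, 0.0)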

View File

@@ -17,11 +17,15 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
# @pytest.mark.parametrize("inplace_backward", [False])
@pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
# @pytest.mark.parametrize("lse_square_scale", [1e-2])
+ @pytest.mark.parametrize("logit_scale", [1.0, 0.7])
+ # @pytest.mark.parametrize("logit_scale", [1.0])
@pytest.mark.parametrize("smoothing", [0.0, 0.9])
# @pytest.mark.parametrize("smoothing", [0.0])
@pytest.mark.parametrize("vocab_size", [50257, 128 * 1024]) # test vocab larger than 64k for split
# @pytest.mark.parametrize("vocab_size", [12])
- def test_cross_entropy_loss(vocab_size, smoothing, lse_square_scale, inplace_backward, dtype):
+ def test_cross_entropy_loss(
+ vocab_size, smoothing, logit_scale, lse_square_scale, inplace_backward, dtype
+ ):
device = "cuda"
rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
# set seed
@@ -38,13 +42,14 @@ def test_cross_entropy_loss(vocab_size, smoothing, lse_square_bac
model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing)
model = CrossEntropyLoss(
label_smoothing=smoothing,
+ logit_scale=logit_scale,
lse_square_scale=lse_square_scale,
inplace_backward=inplace_backward,
)
out = model(x, y)
- out_pt = model_pt(x_pt.float(), y)
+ out_pt = model_pt(x_pt.float() * logit_scale, y)
if lse_square_scale > 0.0:
- lse_pt = torch.logsumexp(x_pt.float(), dim=-1)
+ lse_pt = torch.logsumexp(x_pt.float() * logit_scale, dim=-1)
out_pt += lse_square_scale * (lse_pt[y != -100] ** 2).mean()
assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)

View File

@@ -19,6 +19,8 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
# @pytest.mark.parametrize("inplace_backward", [False])
@pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
# @pytest.mark.parametrize("lse_square_scale", [0.0])
+ @pytest.mark.parametrize("logit_scale", [0.7])
+ # @pytest.mark.parametrize("logit_scale", [1.0])
@pytest.mark.parametrize("smoothing", [0.0, 0.9])
# @pytest.mark.parametrize("smoothing", [0.0])
@pytest.mark.parametrize("vocab_size", [50264, 256 * 1024]) # test vocab larger than 64k for split
@@ -26,7 +28,7 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
# @pytest.mark.parametrize("world_size", [1, 2])
@pytest.mark.parametrize("world_size", [2])
def test_cross_entropy_loss_parallel(
- vocab_size, world_size, smoothing, lse_square_scale, inplace_backward, dtype
+ vocab_size, world_size, smoothing, logit_scale, lse_square_scale, inplace_backward, dtype
):
assert vocab_size % world_size == 0
rtol, atol = (
@@ -59,15 +61,16 @@ def test_cross_entropy_loss_parallel(
model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing, reduction="none")
model = CrossEntropyLoss(
label_smoothing=smoothing,
+ logit_scale=logit_scale,
reduction="none",
lse_square_scale=lse_square_scale,
inplace_backward=inplace_backward,
process_group=parallel_state.get_tensor_model_parallel_group(),
)
out = model(x, y)
- out_pt = model_pt(x_pt.float(), y)
+ out_pt = model_pt(x_pt.float() * logit_scale, y)
if lse_square_scale > 0.0:
- lse_pt = torch.logsumexp(x_pt.float(), dim=-1)
+ lse_pt = torch.logsumexp(x_pt.float() * logit_scale, dim=-1)
out_pt += lse_square_scale * lse_pt.square()
out_pt.masked_fill_(y == -100, 0.0)
assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)