[CrossEntropy] Implement logit_scale option
commit 08124c8f9c
parent 9356a1c038
@@ -12,6 +12,7 @@ class CrossEntropyLoss(nn.Module):
         ignore_index=-100,
         reduction="mean",
         label_smoothing=0.0,
+        logit_scale=1.0,
         lse_square_scale=0.0,
         inplace_backward=False,
         process_group=None,
@@ -33,6 +34,7 @@ class CrossEntropyLoss(nn.Module):
         self.ignore_index = ignore_index
         self.reduction = reduction
         self.label_smoothing = label_smoothing
+        self.logit_scale = logit_scale
         self.lse_square_scale = lse_square_scale
         self.inplace_backward = inplace_backward
         self.process_group = process_group
@@ -50,6 +52,7 @@ class CrossEntropyLoss(nn.Module):
             input,
             target,
             label_smoothing=self.label_smoothing,
+            logit_scale=self.logit_scale,
             lse_square_scale=self.lse_square_scale,
             ignored_index=self.ignore_index,
             inplace_backward=self.inplace_backward,
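For context, a minimal usage sketch of the new option on the nn.Module wrapper above. The import path assumes the flash-attn package layout; shapes and values are illustrative, not part of the commit:

    import torch
    from flash_attn.losses.cross_entropy import CrossEntropyLoss  # assumed path

    logits = torch.randn(8, 50257, device="cuda", dtype=torch.float16, requires_grad=True)
    labels = torch.randint(0, 50257, (8,), device="cuda")

    # logit_scale multiplies the logits inside the kernel before the softmax,
    # so a fixed temperature need not be materialized on the logits tensor.
    loss = CrossEntropyLoss(logit_scale=0.7)(logits, labels)
    loss.backward()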
@@ -29,6 +29,7 @@ def cross_entropy_fwd_kernel(
     logits_ptr,
     labels_ptr,
     smoothing,
+    logit_scale,
     lse_square_scale,
     ignored_index,
     total_classes,
@@ -48,7 +49,7 @@ def cross_entropy_fwd_kernel(
     label_idx = tl.load(labels_ptr + row_idx)
     logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
         tl.float32
-    )
+    ) * logit_scale
     max_logits = tl.max(logits, 0)
     if HAS_SMOOTHING:
         sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
@@ -61,7 +62,7 @@ def cross_entropy_fwd_kernel(
     if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
         n_cols, (col_block_idx + 1) * BLOCK_SIZE
     ):
-        logits_label = tl.load(logits_ptr + label_idx)
+        logits_label = tl.load(logits_ptr + label_idx) * logit_scale
         if HAS_SMOOTHING:
             loss = (
                 (lse if not SPLIT else 0.0)
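The forward change amounts to computing the loss on pre-scaled logits: both the logsumexp and the label logit pick up the logit_scale factor. A pure-PyTorch sketch of the per-row forward (smoothing and z-loss omitted; names are illustrative):

    import torch

    def ce_fwd_reference(logits, labels, logit_scale=1.0):
        # loss = lse(s * z) - s * z[label], matching the scale-after-load above;
        # max subtraction for stability happens inside torch.logsumexp.
        z = logits.float() * logit_scale
        lse = torch.logsumexp(z, dim=-1)
        return lse - z.gather(-1, labels.unsqueeze(-1)).squeeze(-1)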
@@ -94,6 +95,7 @@ def cross_entropy_bwd_kernel(
     lse_ptr,
     labels_ptr,
     smoothing,
+    logit_scale,
     lse_square_scale,
     ignored_index,
     total_classes,
@@ -117,7 +119,7 @@ def cross_entropy_bwd_kernel(
         dloss = 0.0
     logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
         tl.float32
-    )
+    ) * logit_scale
     lse = tl.load(lse_ptr + row_idx)
     probs = tl.exp(logits - lse)
     probs += 2.0 * lse_square_scale * lse * probs
@@ -128,16 +130,18 @@ def cross_entropy_bwd_kernel(
         probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative
     else:
         probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
-    tl.store(dlogits_ptr + col_offsets, dloss * probs, mask=col_offsets < n_cols)
+    tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)


 class CrossEntropyLoss(torch.autograd.Function):
     @staticmethod
     def forward(
         ctx,
         logits,
         labels,
-        smoothing,
+        smoothing=0.0,
+        logit_scale=1.0,
         lse_square_scale=0.0,
         ignored_index=-100,
         inplace_backward=False,
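The extra logit_scale factor in the store follows from the chain rule: with loss = lse(s * z) - s * z[y], the gradient is dloss/dz = s * (softmax(s * z) - onehot(y)), which is exactly the (dloss * logit_scale) * probs written back above. A quick autograd check of that identity (a sketch, not part of the commit):

    import torch
    import torch.nn.functional as F

    s = 0.7
    z = torch.randn(4, 10, dtype=torch.float64, requires_grad=True)
    y = torch.randint(0, 10, (4,))

    F.cross_entropy(s * z, y).backward()  # mean reduction over the 4 rows

    p = torch.softmax(s * z.detach(), dim=-1)
    manual = s * (p - F.one_hot(y, 10).double()) / 4  # divide by batch for the mean
    assert torch.allclose(z.grad, manual)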
@@ -177,6 +181,7 @@ class CrossEntropyLoss(torch.autograd.Function):
                 logits,
                 labels,
                 smoothing,
+                logit_scale,
                 lse_square_scale,
                 ignored_index,
                 total_classes,
@@ -219,6 +224,7 @@ class CrossEntropyLoss(torch.autograd.Function):

         ctx.save_for_backward(logits, lse, labels)
         ctx.smoothing = smoothing
+        ctx.logit_scale = logit_scale
         ctx.lse_square_scale = lse_square_scale
         ctx.ignored_index = ignored_index
         ctx.total_classes = total_classes
@@ -244,6 +250,7 @@ class CrossEntropyLoss(torch.autograd.Function):
                 lse,
                 labels,
                 ctx.smoothing,
+                ctx.logit_scale,
                 ctx.lse_square_scale,
                 ctx.ignored_index,
                 ctx.total_classes,
@@ -262,6 +269,7 @@ def cross_entropy_loss(
     logits: torch.Tensor,
     labels: torch.Tensor,
     label_smoothing: float = 0.0,
+    logit_scale: float = 1.0,
     lse_square_scale: float = 0.0,
     ignored_index=-100,
     inplace_backward: bool = False,
@@ -272,6 +280,7 @@ def cross_entropy_loss(
         logits: (batch, vocab_size)
         labels: (batch,)
         label_smoothing: float
+        logit_scale: float. Multiply logits by this scale before calculating the loss.
         lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
             This is also referred to as "z-loss".
         ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
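Putting the documented pieces together, the loss this docstring describes can be written in plain PyTorch roughly as follows (a reference sketch under the docstring's definitions, not the kernel):

    import torch
    import torch.nn.functional as F

    def ce_reference(logits, labels, logit_scale=1.0, lse_square_scale=0.0,
                     ignored_index=-100):
        z = logits.float() * logit_scale
        loss = F.cross_entropy(z, labels, ignore_index=ignored_index, reduction="none")
        if lse_square_scale > 0.0:
            lse = torch.logsumexp(z, dim=-1)
            # z-loss term; mask it out for ignored positions
            loss = loss + lse_square_scale * lse.square() * (labels != ignored_index)
        return loss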
@@ -286,6 +295,7 @@ def cross_entropy_loss(
         logits,
         labels,
         label_smoothing,
+        logit_scale,
         lse_square_scale,
         ignored_index,
         inplace_backward,
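A hedged sketch of calling the functional wrapper directly; the import path and the per-row return shape are assumptions based on the wrapper above, not stated by the diff:

    import torch
    from flash_attn.ops.triton.cross_entropy import cross_entropy_loss  # assumed path

    logits = torch.randn(8, 50257, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    labels = torch.randint(0, 50257, (8,), device="cuda")

    losses = cross_entropy_loss(logits, labels, label_smoothing=0.0,
                                logit_scale=0.7, lse_square_scale=1e-2)
    losses.mean().backward()  # the nn.Module wrapper applies the reduction similarly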
@@ -17,11 +17,15 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 # @pytest.mark.parametrize("inplace_backward", [False])
 @pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
 # @pytest.mark.parametrize("lse_square_scale", [1e-2])
+@pytest.mark.parametrize("logit_scale", [1.0, 0.7])
+# @pytest.mark.parametrize("logit_scale", [1.0])
 @pytest.mark.parametrize("smoothing", [0.0, 0.9])
 # @pytest.mark.parametrize("smoothing", [0.0])
 @pytest.mark.parametrize("vocab_size", [50257, 128 * 1024])  # test vocab larger than 64k for split
 # @pytest.mark.parametrize("vocab_size", [12])
-def test_cross_entropy_loss(vocab_size, smoothing, lse_square_scale, inplace_backward, dtype):
+def test_cross_entropy_loss(
+    vocab_size, smoothing, logit_scale, lse_square_scale, inplace_backward, dtype
+):
     device = "cuda"
     rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
     # set seed
@@ -38,13 +42,14 @@ def test_cross_entropy_loss(vocab_size, smoothing, lse_square_scale, inplace_backward, dtype):
     model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing)
     model = CrossEntropyLoss(
         label_smoothing=smoothing,
+        logit_scale=logit_scale,
         lse_square_scale=lse_square_scale,
         inplace_backward=inplace_backward,
     )
     out = model(x, y)
-    out_pt = model_pt(x_pt.float(), y)
+    out_pt = model_pt(x_pt.float() * logit_scale, y)
     if lse_square_scale > 0.0:
-        lse_pt = torch.logsumexp(x_pt.float(), dim=-1)
+        lse_pt = torch.logsumexp(x_pt.float() * logit_scale, dim=-1)
         out_pt += lse_square_scale * (lse_pt[y != -100] ** 2).mean()
     assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)

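The hunk above checks the forward path by feeding pre-scaled logits to torch.nn.CrossEntropyLoss. A backward check would continue inside the same test body along these lines (a hypothetical continuation reusing the test's x, x_pt, out, out_pt, rtol, atol; not shown in the diff):

    g = torch.randn_like(out)
    out.backward(g)
    out_pt.backward(g)
    # the kernel's gradient should match PyTorch's on the pre-scaled reference
    assert torch.allclose(x.grad, x_pt.grad.to(dtype=x.dtype), rtol=rtol, atol=atol)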
@@ -19,6 +19,8 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 # @pytest.mark.parametrize("inplace_backward", [False])
 @pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
 # @pytest.mark.parametrize("lse_square_scale", [0.0])
+@pytest.mark.parametrize("logit_scale", [0.7])
+# @pytest.mark.parametrize("logit_scale", [1.0])
 @pytest.mark.parametrize("smoothing", [0.0, 0.9])
 # @pytest.mark.parametrize("smoothing", [0.0])
 @pytest.mark.parametrize("vocab_size", [50264, 256 * 1024])  # test vocab larger than 64k for split
@@ -26,7 +28,7 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 # @pytest.mark.parametrize("world_size", [1, 2])
 @pytest.mark.parametrize("world_size", [2])
 def test_cross_entropy_loss_parallel(
-    vocab_size, world_size, smoothing, lse_square_scale, inplace_backward, dtype
+    vocab_size, world_size, smoothing, logit_scale, lse_square_scale, inplace_backward, dtype
 ):
     assert vocab_size % world_size == 0
     rtol, atol = (
@@ -59,15 +61,16 @@ def test_cross_entropy_loss_parallel(
     model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing, reduction="none")
     model = CrossEntropyLoss(
         label_smoothing=smoothing,
+        logit_scale=logit_scale,
         reduction="none",
         lse_square_scale=lse_square_scale,
         inplace_backward=inplace_backward,
         process_group=parallel_state.get_tensor_model_parallel_group(),
     )
     out = model(x, y)
-    out_pt = model_pt(x_pt.float(), y)
+    out_pt = model_pt(x_pt.float() * logit_scale, y)
     if lse_square_scale > 0.0:
-        lse_pt = torch.logsumexp(x_pt.float(), dim=-1)
+        lse_pt = torch.logsumexp(x_pt.float() * logit_scale, dim=-1)
         out_pt += lse_square_scale * lse_pt.square()
         out_pt.masked_fill_(y == -100, 0.0)
     assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
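Scaling each vocab shard locally is safe in the tensor-parallel case because logsumexp composes across shards: the full-vocab lse of the scaled logits equals the logsumexp of the per-shard lse values. A small demonstration of that identity (illustrative shapes, unrelated to the test harness):

    import torch

    z = torch.randn(4, 1024) * 0.7  # logits already multiplied by logit_scale
    shards = z.chunk(8, dim=-1)     # stand-in for tensor-parallel vocab shards
    lse_per_shard = torch.stack([torch.logsumexp(s, dim=-1) for s in shards], dim=-1)
    assert torch.allclose(torch.logsumexp(lse_per_shard, dim=-1),
                          torch.logsumexp(z, dim=-1))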