return z_loss (#768)
This commit is contained in:
parent
43ceab630b
commit
d8aacc510c
@ -16,6 +16,7 @@ class CrossEntropyLoss(nn.Module):
|
|||||||
lse_square_scale=0.0,
|
lse_square_scale=0.0,
|
||||||
inplace_backward=False,
|
inplace_backward=False,
|
||||||
process_group=None,
|
process_group=None,
|
||||||
|
return_z_loss=False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Arguments:
|
Arguments:
|
||||||
@ -26,7 +27,10 @@ class CrossEntropyLoss(nn.Module):
|
|||||||
inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
|
inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
|
||||||
This saves memory.
|
This saves memory.
|
||||||
process_group: if not None, we're doing Tensor Parallel: each process is responsible for
|
process_group: if not None, we're doing Tensor Parallel: each process is responsible for
|
||||||
one part of the vocab. The loss will be aggregated across processes.
|
one part of the vocab. The loss will be aggregated across processes.
|
||||||
|
return_z_loss: bool. If True, we return the component of the loss contributed by
|
||||||
|
the lse_square_scale value. This value is only for logging and does not support
|
||||||
|
backprop.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if reduction not in ["mean", "none", "sum"]:
|
if reduction not in ["mean", "none", "sum"]:
|
||||||
@ -38,6 +42,7 @@ class CrossEntropyLoss(nn.Module):
|
|||||||
self.lse_square_scale = lse_square_scale
|
self.lse_square_scale = lse_square_scale
|
||||||
self.inplace_backward = inplace_backward
|
self.inplace_backward = inplace_backward
|
||||||
self.process_group = process_group
|
self.process_group = process_group
|
||||||
|
self.return_z_loss = return_z_loss
|
||||||
|
|
||||||
def forward(self, input, target):
|
def forward(self, input, target):
|
||||||
"""
|
"""
|
||||||
@ -46,9 +51,10 @@ class CrossEntropyLoss(nn.Module):
|
|||||||
target: (batch,)
|
target: (batch,)
|
||||||
Returns:
|
Returns:
|
||||||
losses: (batch,) if reduction is 'none', else (1,), dtype float
|
losses: (batch,) if reduction is 'none', else (1,), dtype float
|
||||||
|
z_loss: (batch,) if reduction is 'none', else (1,), dtype float (if self.return_z_loss)
|
||||||
"""
|
"""
|
||||||
assert input.is_cuda and target.is_cuda, "Only support CUDA tensors"
|
assert input.is_cuda and target.is_cuda, "Only support CUDA tensors"
|
||||||
loss = cross_entropy_loss(
|
loss, z_loss = cross_entropy_loss(
|
||||||
input,
|
input,
|
||||||
target,
|
target,
|
||||||
label_smoothing=self.label_smoothing,
|
label_smoothing=self.label_smoothing,
|
||||||
@ -59,8 +65,20 @@ class CrossEntropyLoss(nn.Module):
|
|||||||
process_group=self.process_group,
|
process_group=self.process_group,
|
||||||
)
|
)
|
||||||
if self.reduction == "mean":
|
if self.reduction == "mean":
|
||||||
return loss.sum() / (target != self.ignore_index).sum()
|
loss = loss.sum() / (target != self.ignore_index).sum()
|
||||||
elif self.reduction == "sum":
|
elif self.reduction == "sum":
|
||||||
return loss.sum()
|
loss = loss.sum()
|
||||||
else:
|
else:
|
||||||
|
loss = loss
|
||||||
|
|
||||||
|
if not self.return_z_loss:
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
if self.reduction == "mean":
|
||||||
|
z_loss = z_loss.sum() / (target != self.ignore_index).sum()
|
||||||
|
elif self.reduction == "sum":
|
||||||
|
z_loss = z_loss.sum()
|
||||||
|
else:
|
||||||
|
z_loss = z_loss
|
||||||
|
|
||||||
|
return loss, z_loss
|
||||||
|
|||||||
@ -26,6 +26,7 @@ if "all_gather_into_tensor" not in dir(torch.distributed):
|
|||||||
def cross_entropy_fwd_kernel(
|
def cross_entropy_fwd_kernel(
|
||||||
loss_ptr, # data ptrs
|
loss_ptr, # data ptrs
|
||||||
lse_ptr,
|
lse_ptr,
|
||||||
|
z_loss_ptr,
|
||||||
logits_ptr,
|
logits_ptr,
|
||||||
labels_ptr,
|
labels_ptr,
|
||||||
smoothing,
|
smoothing,
|
||||||
@ -57,6 +58,7 @@ def cross_entropy_fwd_kernel(
|
|||||||
tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
|
tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
|
||||||
if label_idx == ignored_index:
|
if label_idx == ignored_index:
|
||||||
loss = 0.0
|
loss = 0.0
|
||||||
|
z_loss = 0.0
|
||||||
else:
|
else:
|
||||||
label_idx -= class_start_idx
|
label_idx -= class_start_idx
|
||||||
if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
|
if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
|
||||||
@ -78,8 +80,13 @@ def cross_entropy_fwd_kernel(
|
|||||||
else:
|
else:
|
||||||
loss = 0.0
|
loss = 0.0
|
||||||
if not SPLIT:
|
if not SPLIT:
|
||||||
loss += lse_square_scale * lse * lse
|
z_loss = lse_square_scale * lse * lse
|
||||||
|
loss += z_loss
|
||||||
|
else:
|
||||||
|
z_loss = 0.0
|
||||||
tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)
|
tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)
|
||||||
|
if not SPLIT:
|
||||||
|
tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss)
|
||||||
|
|
||||||
|
|
||||||
@triton.heuristics(
|
@triton.heuristics(
|
||||||
@ -172,12 +179,14 @@ class CrossEntropyLoss(torch.autograd.Function):
|
|||||||
loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
|
loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
|
||||||
losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
|
losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
|
||||||
lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
|
lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
|
||||||
|
z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
|
||||||
# Need this, otherwise Triton tries to launch from cuda:0 and we get
|
# Need this, otherwise Triton tries to launch from cuda:0 and we get
|
||||||
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
|
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
|
||||||
with torch.cuda.device(logits.device.index):
|
with torch.cuda.device(logits.device.index):
|
||||||
cross_entropy_fwd_kernel[(n_rows, n_splits)](
|
cross_entropy_fwd_kernel[(n_rows, n_splits)](
|
||||||
losses, # data ptrs
|
losses, # data ptrs
|
||||||
lse,
|
lse,
|
||||||
|
z_losses,
|
||||||
logits,
|
logits,
|
||||||
labels,
|
labels,
|
||||||
smoothing,
|
smoothing,
|
||||||
@ -219,10 +228,15 @@ class CrossEntropyLoss(torch.autograd.Function):
|
|||||||
# Again, we just have to add the (global) lse.
|
# Again, we just have to add the (global) lse.
|
||||||
losses += lse
|
losses += lse
|
||||||
if lse_square_scale != 0.0:
|
if lse_square_scale != 0.0:
|
||||||
losses += lse_square_scale * lse.square()
|
z_losses = lse_square_scale * lse.square()
|
||||||
|
z_losses.masked_fill_(labels == ignored_index, 0.0)
|
||||||
|
losses += z_losses
|
||||||
|
else:
|
||||||
|
z_losses = torch.zeros_like(losses)
|
||||||
losses.masked_fill_(labels == ignored_index, 0.0)
|
losses.masked_fill_(labels == ignored_index, 0.0)
|
||||||
|
|
||||||
ctx.save_for_backward(logits, lse, labels)
|
ctx.save_for_backward(logits, lse, labels)
|
||||||
|
ctx.mark_non_differentiable(z_losses)
|
||||||
ctx.smoothing = smoothing
|
ctx.smoothing = smoothing
|
||||||
ctx.logit_scale = logit_scale
|
ctx.logit_scale = logit_scale
|
||||||
ctx.lse_square_scale = lse_square_scale
|
ctx.lse_square_scale = lse_square_scale
|
||||||
@ -230,10 +244,13 @@ class CrossEntropyLoss(torch.autograd.Function):
|
|||||||
ctx.total_classes = total_classes
|
ctx.total_classes = total_classes
|
||||||
ctx.class_start_idx = class_start_idx
|
ctx.class_start_idx = class_start_idx
|
||||||
ctx.inplace_backward = inplace_backward
|
ctx.inplace_backward = inplace_backward
|
||||||
return losses
|
|
||||||
|
return losses, z_losses
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, grad_losses):
|
def backward(ctx, grad_losses, grad_z_losses):
|
||||||
|
del grad_z_losses # z_losses are only for logging.
|
||||||
|
|
||||||
logits, lse, labels = ctx.saved_tensors
|
logits, lse, labels = ctx.saved_tensors
|
||||||
dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
|
dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
|
||||||
n_rows, n_cols = logits.shape
|
n_rows, n_cols = logits.shape
|
||||||
@ -262,8 +279,7 @@ class CrossEntropyLoss(torch.autograd.Function):
|
|||||||
BLOCK_SIZE=BLOCK_SIZE, # constants
|
BLOCK_SIZE=BLOCK_SIZE, # constants
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
)
|
)
|
||||||
return dlogits, None, None, None, None, None, None, None
|
return dlogits, None, None, None, None, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
def cross_entropy_loss(
|
def cross_entropy_loss(
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
@ -287,9 +303,10 @@ def cross_entropy_loss(
|
|||||||
inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
|
inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
|
||||||
This saves memory.
|
This saves memory.
|
||||||
process_group: if not None, we're doing Tensor Parallel: each process is responsible for
|
process_group: if not None, we're doing Tensor Parallel: each process is responsible for
|
||||||
one part of the vocab. The loss will be aggregated across processes.
|
one part of the vocab. The loss will be aggregated across processes.
|
||||||
Returns:
|
Returns:
|
||||||
losses: (batch,), float
|
losses: (batch,), float
|
||||||
|
z_losses: (batch,), float
|
||||||
"""
|
"""
|
||||||
return CrossEntropyLoss.apply(
|
return CrossEntropyLoss.apply(
|
||||||
logits,
|
logits,
|
||||||
|
|||||||
@ -16,6 +16,7 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
|
|||||||
@pytest.mark.parametrize("inplace_backward", [False, True])
|
@pytest.mark.parametrize("inplace_backward", [False, True])
|
||||||
# @pytest.mark.parametrize("inplace_backward", [False])
|
# @pytest.mark.parametrize("inplace_backward", [False])
|
||||||
@pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
|
@pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
|
||||||
|
@pytest.mark.parametrize("return_z_loss", [False, True])
|
||||||
# @pytest.mark.parametrize("lse_square_scale", [1e-2])
|
# @pytest.mark.parametrize("lse_square_scale", [1e-2])
|
||||||
@pytest.mark.parametrize("logit_scale", [1.0, 0.7])
|
@pytest.mark.parametrize("logit_scale", [1.0, 0.7])
|
||||||
# @pytest.mark.parametrize("logit_scale", [1.0])
|
# @pytest.mark.parametrize("logit_scale", [1.0])
|
||||||
@ -24,7 +25,7 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
|
|||||||
@pytest.mark.parametrize("vocab_size", [50257, 128 * 1024]) # test vocab larger than 64k for split
|
@pytest.mark.parametrize("vocab_size", [50257, 128 * 1024]) # test vocab larger than 64k for split
|
||||||
# @pytest.mark.parametrize("vocab_size", [12])
|
# @pytest.mark.parametrize("vocab_size", [12])
|
||||||
def test_cross_entropy_loss(
|
def test_cross_entropy_loss(
|
||||||
vocab_size, smoothing, logit_scale, lse_square_scale, inplace_backward, dtype
|
vocab_size, smoothing, logit_scale, lse_square_scale, return_z_loss, inplace_backward, dtype
|
||||||
):
|
):
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
|
rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
|
||||||
@ -44,14 +45,21 @@ def test_cross_entropy_loss(
|
|||||||
label_smoothing=smoothing,
|
label_smoothing=smoothing,
|
||||||
logit_scale=logit_scale,
|
logit_scale=logit_scale,
|
||||||
lse_square_scale=lse_square_scale,
|
lse_square_scale=lse_square_scale,
|
||||||
|
return_z_loss=return_z_loss,
|
||||||
inplace_backward=inplace_backward,
|
inplace_backward=inplace_backward,
|
||||||
)
|
)
|
||||||
out = model(x, y)
|
if return_z_loss:
|
||||||
|
out, out_z_loss = model(x, y)
|
||||||
|
else:
|
||||||
|
out = model(x, y)
|
||||||
x_pt_scaled = (x_pt.float() * logit_scale) if logit_scale != 1.0 else x_pt.float()
|
x_pt_scaled = (x_pt.float() * logit_scale) if logit_scale != 1.0 else x_pt.float()
|
||||||
out_pt = model_pt(x_pt_scaled, y)
|
out_pt = model_pt(x_pt_scaled, y)
|
||||||
if lse_square_scale > 0.0:
|
if lse_square_scale > 0.0:
|
||||||
lse_pt = torch.logsumexp(x_pt_scaled, dim=-1)
|
lse_pt = torch.logsumexp(x_pt_scaled, dim=-1)
|
||||||
out_pt += lse_square_scale * (lse_pt[y != -100] ** 2).mean()
|
z_loss_pt = lse_square_scale * (lse_pt[y != -100] ** 2).mean()
|
||||||
|
if return_z_loss:
|
||||||
|
assert torch.allclose(out_z_loss, z_loss_pt, rtol=rtol, atol=atol)
|
||||||
|
out_pt += z_loss_pt
|
||||||
assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
|
assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
|
||||||
|
|
||||||
g = torch.randn_like(out)
|
g = torch.randn_like(out)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user