flash-attention/tests/losses/test_cross_entropy.py

55 lines
2.1 KiB
Python
Raw Normal View History

2022-11-13 11:49:33 +08:00
import math
2023-08-19 11:59:35 +08:00
import pytest
2022-11-13 11:49:33 +08:00
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.losses.cross_entropy import CrossEntropyLoss
2022-11-13 11:49:33 +08:00
2023-08-19 11:59:35 +08:00
# True when a CUDA device with compute-capability major version >= 8 is present;
# used below to decide whether bfloat16 is included in the dtype parametrization.
# Checking is_available() first keeps module import (and pytest collection) from
# raising on hosts without a usable CUDA device.
is_sm8x = torch.cuda.is_available() and torch.cuda.get_device_capability("cuda")[0] >= 8
2022-11-13 11:49:33 +08:00
2023-08-19 11:59:35 +08:00
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires a CUDA device")
@pytest.mark.parametrize(
    "dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
)
@pytest.mark.parametrize("inplace_backward", [False, True])
@pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
@pytest.mark.parametrize("smoothing", [0.0, 0.9])
@pytest.mark.parametrize("vocab_size", [50257, 128 * 1024])  # test vocab larger than 64k for split
def test_cross_entropy_loss(vocab_size, smoothing, lse_square_scale, inplace_backward, dtype):
    """Check flash_attn's CrossEntropyLoss against torch.nn.CrossEntropyLoss.

    The reference loss is computed in float32 on an independent copy of the
    same logits. When ``lse_square_scale > 0`` the reference additionally adds
    ``lse_square_scale * mean(logsumexp(logits) ** 2)`` over the rows whose
    target is not -100. Both the forward loss value and the gradient with
    respect to the logits are compared.
    """
    device = "cuda"
    # Tolerances for the gradient comparison, which is done in the input dtype.
    rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 128
    n_rows = batch_size * seqlen

    # Two leaf tensors holding identical logits so each loss gets its own .grad.
    x_pt = torch.randn(n_rows, vocab_size, device=device, dtype=dtype, requires_grad=True)
    x = x_pt.detach().clone().requires_grad_()
    y = torch.randint(0, vocab_size, (n_rows,), dtype=torch.long, device=device)
    # Mark a few targets as -100 (torch.nn.CrossEntropyLoss's default ignore_index)
    # to exercise masking; the guard only matters if the sizes above are shrunk.
    if n_rows > 10:
        y[torch.randperm(n_rows)[:10]] = -100

    model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing)
    model = CrossEntropyLoss(
        label_smoothing=smoothing,
        lse_square_scale=lse_square_scale,
        inplace_backward=inplace_backward,
    )

    out = model(x, y)
    out_pt = model_pt(x_pt.float(), y)
    if lse_square_scale > 0.0:
        # Reference for the squared log-sum-exp penalty, averaged over non-ignored rows.
        lse_pt = torch.logsumexp(x_pt.float(), dim=-1)
        out_pt += lse_square_scale * (lse_pt[y != -100] ** 2).mean()
    # The reference is float32 for every input dtype, so the forward check uses
    # fixed float32 tolerances.
    # NOTE(review): this assumes flash_attn's loss output is also float32 — confirm.
    assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)

    g = torch.randn_like(out)
    out_pt.backward(g)
    out.backward(g)
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)