# Inspired by https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/metrics/perplexity.py
# But we compute the perplexity correctly: exp(average(nll)), not average(exp(nll))
# Also adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/text/perplexity.py
# But we pass in the loss to avoid recomputation
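# A quick worked illustration of the difference (hypothetical per-token NLLs):
# for nll = [0.1, 5.0], exp(average(nll)) = exp(2.55) ~= 12.8, whereas
# average(exp(nll)) = (exp(0.1) + exp(5.0)) / 2 ~= 74.8. The two agree only when
# every token has the same NLL; only the former is the standard perplexity.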

from typing import Any, Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torchmetrics import Metric

try:
    from flash_attn.losses.cross_entropy import CrossEntropyLoss
except ImportError:
    # Fall back to the stock PyTorch loss when flash_attn's fused
    # cross-entropy kernel is not installed.
    CrossEntropyLoss = torch.nn.CrossEntropyLoss

__all__ = ['Perplexity']


class Perplexity(Metric):
    r"""
    Perplexity measures how well a language model predicts a text sample. It is computed as the
    exponential of the average negative log-likelihood per token, so lower is better: a perplexity
    of k roughly means the model is as uncertain as if it were choosing uniformly among k tokens.

    Args:
        kwargs:
            Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Examples:
        >>> import torch
        >>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22))
        >>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22))
        >>> target[0, 6:] = -100  # ignored by the loss (ignore_index defaults to -100)
        >>> metric = Perplexity()
        >>> metric(preds, target).shape
        torch.Size([])
    """
    is_differentiable = True
    higher_is_better = False
    full_state_update = False

    total_log_probs: Tensor
    count: Tensor

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        self.add_state("total_log_probs", default=torch.tensor(0.0, dtype=torch.float64),
                       dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum")
        # The default loss; update() can instead take a loss precomputed by the
        # training step so the logits are not rerun through cross-entropy.
        self.loss_fn = CrossEntropyLoss()

    def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None:  # type: ignore
        """Compute and store intermediate statistics for Perplexity.

        Args:
            preds:
                Logits (unnormalized scores) for each token in the sequence, with shape
                [batch_size, seq_len, vocab_size].
            target:
                Ground-truth token ids with shape [batch_size, seq_len].
            loss:
                Optional precomputed mean cross-entropy for this batch; when provided,
                the loss function is not rerun on preds.
        """
        count = target.numel()
        if loss is None:
            # CrossEntropyLoss expects [N, vocab_size] logits and [N] targets,
            # so fold the batch and sequence dimensions together.
            loss = self.loss_fn(preds.reshape(-1, preds.shape[-1]), target.reshape(-1))
        # loss is a mean over tokens, so scale it back up by the token count
        # before accumulating. Note that target.numel() counts every position,
        # including any the loss ignored via ignore_index.
        self.total_log_probs += loss.double() * count
        self.count += count

    def compute(self) -> Tensor:
        """Compute the Perplexity.

        Returns:
            Perplexity
        """
        return torch.exp(self.total_log_probs / self.count)
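

if __name__ == '__main__':
    # A minimal usage sketch under made-up shapes and random data. It exercises
    # both paths: letting the metric run its own loss on the logits, and passing
    # in a loss precomputed by the training step to avoid recomputation. (It uses
    # whichever CrossEntropyLoss was imported above; the flash_attn fused version
    # expects CUDA tensors, while the torch.nn fallback also runs on CPU.)
    torch.manual_seed(0)
    batch_size, seq_len, vocab_size = 2, 8, 5
    logits = torch.randn(batch_size, seq_len, vocab_size)
    labels = torch.randint(vocab_size, (batch_size, seq_len))

    # Path 1: the metric computes the loss itself.
    metric = Perplexity()
    ppl = metric(logits, labels)

    # Path 2: reuse the mean cross-entropy the training step already computed.
    loss = F.cross_entropy(logits.reshape(-1, vocab_size), labels.reshape(-1))
    metric2 = Perplexity()
    metric2.update(logits, labels, loss=loss)

    # For a single batch both reduce to exp(loss), since exp(average(nll)) is
    # taken over the same token count.
    print(ppl, metric2.compute(), torch.exp(loss.double()))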