[CORE] Improvement in ranks code (#4718)
This commit is contained in:
parent
a709e87a4f
commit
a7be4d0072
@ -681,7 +681,9 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
|
||||
indices]
|
||||
return (x > vals[:, None]).long().sum(1).add_(1)
|
||||
result = (x > vals[:, None])
|
||||
del vals
|
||||
return result.sum(1).add_(1)
|
||||
|
||||
|
||||
def _get_logprobs(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user