[CORE] Improvement in ranks code (#4718)
This commit is contained in:
parent
a709e87a4f
commit
a7be4d0072
@ -681,7 +681,9 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
|
|||||||
"""
|
"""
|
||||||
vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
|
vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
|
||||||
indices]
|
indices]
|
||||||
return (x > vals[:, None]).long().sum(1).add_(1)
|
result = (x > vals[:, None])
|
||||||
|
del vals
|
||||||
|
return result.sum(1).add_(1)
|
||||||
|
|
||||||
|
|
||||||
def _get_logprobs(
|
def _get_logprobs(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user