[CORE] Improvement in ranks code (#4718)

This commit is contained in:
Swapnil Parekh 2024-05-12 20:47:47 -04:00 committed by GitHub
parent a709e87a4f
commit a7be4d0072
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -681,7 +681,9 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
"""
vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
indices]
return (x > vals[:, None]).long().sum(1).add_(1)
result = (x > vals[:, None])
del vals
return result.sum(1).add_(1)
def _get_logprobs(