[Bugfix] Validate SamplingParam n is an int (#8548)
This commit is contained in:
parent
2940afa04e
commit
b28298f2f4
@ -273,9 +273,14 @@ class SamplingParams(
|
|||||||
self._all_stop_token_ids = set(self.stop_token_ids)
|
self._all_stop_token_ids = set(self.stop_token_ids)
|
||||||
|
|
||||||
def _verify_args(self) -> None:
|
def _verify_args(self) -> None:
|
||||||
|
if not isinstance(self.n, int):
|
||||||
|
raise ValueError(f"n must be an int, but is of "
|
||||||
|
f"type {type(self.n)}")
|
||||||
if self.n < 1:
|
if self.n < 1:
|
||||||
raise ValueError(f"n must be at least 1, got {self.n}.")
|
raise ValueError(f"n must be at least 1, got {self.n}.")
|
||||||
assert isinstance(self.best_of, int)
|
if not isinstance(self.best_of, int):
|
||||||
|
raise ValueError(f'best_of must be an int, but is of '
|
||||||
|
f'type {type(self.best_of)}')
|
||||||
if self.best_of < self.n:
|
if self.best_of < self.n:
|
||||||
raise ValueError(f"best_of must be greater than or equal to n, "
|
raise ValueError(f"best_of must be greater than or equal to n, "
|
||||||
f"got n={self.n} and best_of={self.best_of}.")
|
f"got n={self.n} and best_of={self.best_of}.")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user