[FIX] Fix a bug in initializing Yarn RoPE (#2983)
This commit is contained in:
parent
fd5dcc5c81
commit
c530e2cfe3
@ -245,13 +245,11 @@ def _yarn_find_correction_range(low_rot: int,
|
|||||||
|
|
||||||
|
|
||||||
def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
|
def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype) -> torch.Tensor:
|
||||||
device: torch.device) -> torch.Tensor:
|
|
||||||
if low == high:
|
if low == high:
|
||||||
high += 0.001 # Prevent singularity
|
high += 0.001 # Prevent singularity
|
||||||
|
|
||||||
linear_func = (torch.arange(dim, dtype=dtype, device=device) -
|
linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
|
||||||
low) / (high - low)
|
|
||||||
ramp_func = torch.clamp(linear_func, 0, 1)
|
ramp_func = torch.clamp(linear_func, 0, 1)
|
||||||
return ramp_func
|
return ramp_func
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user