Use FP32 in RoPE initialization (#1004)
Co-authored-by: One <imone@tuta.io>
commit e67b4f2c2a (parent d6770d1f23)
@@ -133,9 +133,10 @@ def test_rotary_embedding(
         device="cuda")

     # Create the rotary embedding.
-    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
+    inv_freq = 1.0 / (base**(
+        torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
     t = torch.arange(max_position).float()
-    freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
+    freqs = torch.einsum("i,j -> ij", t, inv_freq)
     cos = freqs.cos()
     sin = freqs.sin()
     cos_sin_cache = torch.cat((cos, sin), dim=-1)
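The point of the change in both hunks: without an explicit dtype, torch.arange(0, rotary_dim, 2) / rotary_dim is an integer tensor divided by an integer, so PyTorch promotes the result to the current default float dtype, which may be half precision if the surrounding model-loading code has changed it. Pinning the arange to torch.float keeps the whole cos/sin cache computation in FP32. The sketch below is our own illustration, not part of the commit; the rotary_dim and base values are illustrative.

# Our own illustration, not part of the commit: why the explicit FP32 dtype matters.
import torch

rotary_dim, base = 128, 10000  # illustrative values

# Patched behavior: pinned to FP32 regardless of the default dtype.
inv_freq_fp32 = 1.0 / (base**(
    torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))

# Old behavior under a half-precision default dtype: the int/int division
# is promoted to torch's default float dtype, here float16.
torch.set_default_dtype(torch.half)
inv_freq_old = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
torch.set_default_dtype(torch.float)

print(inv_freq_old.dtype)                          # torch.float16
print((inv_freq_fp32 - inv_freq_old).abs().max())  # nonzero precision loss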
@@ -264,10 +264,10 @@ class PagedAttentionWithRoPE(PagedAttention):
         self.is_neox_style = is_neox_style

         # Create the cos and sin cache.
-        inv_freq = 1.0 / (base**(
-            torch.arange(0, rotary_dim, 2, device="cuda") / rotary_dim))
-        t = torch.arange(max_position, device="cuda").float()
-        freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
+        inv_freq = 1.0 / (base**(torch.arange(
+            0, rotary_dim, 2, dtype=torch.float, device="cuda") / rotary_dim))
+        t = torch.arange(max_position, dtype=torch.float, device="cuda")
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
         cos = freqs.cos()
         sin = freqs.sin()
         cache = torch.cat((cos, sin), dim=-1)
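For reference, a minimal standalone sketch of the FP32 cache construction after this change. The helper name and the CPU usage below are ours for illustration, not vLLM API.

import torch


def build_rope_cos_sin_cache(rotary_dim: int,
                             max_position: int,
                             base: int = 10000,
                             device: str = "cuda") -> torch.Tensor:
    # Everything below is computed in FP32, mirroring the patched initialization.
    inv_freq = 1.0 / (base**(torch.arange(
        0, rotary_dim, 2, dtype=torch.float, device=device) / rotary_dim))
    t = torch.arange(max_position, dtype=torch.float, device=device)
    freqs = torch.einsum("i,j -> ij", t, inv_freq)  # [max_position, rotary_dim // 2]
    return torch.cat((freqs.cos(), freqs.sin()), dim=-1)  # [max_position, rotary_dim]


cache = build_rope_cos_sin_cache(rotary_dim=128, max_position=2048, device="cpu")
print(cache.shape, cache.dtype)  # torch.Size([2048, 128]) torch.float32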