diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py
index 1e591295..0d255900 100644
--- a/tests/kernels/test_pos_encoding.py
+++ b/tests/kernels/test_pos_encoding.py
@@ -133,9 +133,10 @@ def test_rotary_embedding(
                       device="cuda")
 
     # Create the rotary embedding.
-    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
+    inv_freq = 1.0 / (base**(
+        torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
     t = torch.arange(max_position).float()
-    freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
+    freqs = torch.einsum("i,j -> ij", t, inv_freq)
     cos = freqs.cos()
     sin = freqs.sin()
     cos_sin_cache = torch.cat((cos, sin), dim=-1)
diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py
index c35cd8a6..5e9360a3 100644
--- a/vllm/model_executor/layers/attention.py
+++ b/vllm/model_executor/layers/attention.py
@@ -264,10 +264,10 @@ class PagedAttentionWithRoPE(PagedAttention):
         self.is_neox_style = is_neox_style
 
         # Create the cos and sin cache.
-        inv_freq = 1.0 / (base**(
-            torch.arange(0, rotary_dim, 2, device="cuda") / rotary_dim))
-        t = torch.arange(max_position, device="cuda").float()
-        freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
+        inv_freq = 1.0 / (base**(torch.arange(
+            0, rotary_dim, 2, dtype=torch.float, device="cuda") / rotary_dim))
+        t = torch.arange(max_position, dtype=torch.float, device="cuda")
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
         cos = freqs.cos()
         sin = freqs.sin()
         cache = torch.cat((cos, sin), dim=-1)
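
Note: both hunks make the same change in two places: the inverse frequencies are built directly in float32 via dtype=torch.float, instead of starting from an int64 arange and casting afterwards, so the whole cos/sin cache computation stays in one dtype. A minimal standalone sketch of that construction, with hypothetical values for rotary_dim, base, and max_position (the real ones come from the test parameters and the layer constructor) and run on CPU for simplicity:

import torch

# Hypothetical illustration values; in the diff these come from the test
# parameters and the PagedAttentionWithRoPE constructor.
rotary_dim = 64
base = 10000
max_position = 2048

# Build the inverse frequencies directly in float32 (the post-change form),
# rather than creating an int64 arange and converting later.
inv_freq = 1.0 / (base**(
    torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
t = torch.arange(max_position, dtype=torch.float)

# Outer product of positions and inverse frequencies:
# freqs[i, j] = t[i] * inv_freq[j], shape (max_position, rotary_dim // 2).
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()

# Cache layout as in the diff: cos halves followed by sin halves,
# shape (max_position, rotary_dim).
cos_sin_cache = torch.cat((cos, sin), dim=-1)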