From 03ffd0a02251e10c1aa14fca8cb0ab1e4e40b886 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Tue, 26 Sep 2023 10:48:33 -0700
Subject: [PATCH] Add comments on RoPE initialization (#1176)

---
 vllm/model_executor/layers/attention.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py
index 5e9360a3..a60f7b7b 100644
--- a/vllm/model_executor/layers/attention.py
+++ b/vllm/model_executor/layers/attention.py
@@ -264,6 +264,15 @@ class PagedAttentionWithRoPE(PagedAttention):
         self.is_neox_style = is_neox_style
 
         # Create the cos and sin cache.
+        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
+        # However, we use `torch.arange(..., dtype=torch.float)` instead to
+        # avoid numerical issues with large base values (e.g., 10000000).
+        # This may cause a slight numerical difference between the HF
+        # implementation and ours.
+        # NOTE(woosuk): To exactly match the HF implementation, we need to
+        # use CPU to compute the cache and then move it to GPU. However, we
+        # create the cache on GPU for faster initialization. This may cause
+        # a slight numerical difference between the HF implementation and ours.
         inv_freq = 1.0 / (base**(torch.arange(
             0, rotary_dim, 2, dtype=torch.float, device="cuda") / rotary_dim))
         t = torch.arange(max_position, dtype=torch.float, device="cuda")
@@ -274,7 +283,6 @@ class PagedAttentionWithRoPE(PagedAttention):
 
         # FIXME(woosuk): This assumes that we configure the default dtype when
         # initializing the model.
-        # TODO(woosuk): Make it more robust.
         torch_dtype = torch.get_default_dtype()
         cache = cache.to(torch_dtype)
         # Embedding size: [max_position, rotary_dim]
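
For context, a minimal standalone sketch (not part of the patch) of the comparison the new comments describe: building the cos/sin cache HF-style (integer arange followed by .float(), computed on CPU) versus the style used here (float arange, computed on GPU), assuming the standard outer-product construction of the cache. The parameter values (rotary_dim, max_position, base) are illustrative, not taken from any particular model config.

import torch

rotary_dim = 128
max_position = 4096
base = 10000000  # a large base value, as mentioned in the comments

def build_cache(device, use_float_arange):
    # Exponents 0, 2, 4, ... for the inverse frequencies.
    if use_float_arange:
        exponents = torch.arange(0, rotary_dim, 2, dtype=torch.float, device=device)
    else:
        exponents = torch.arange(0, rotary_dim, 2, device=device).float()
    inv_freq = 1.0 / (base ** (exponents / rotary_dim))
    t = torch.arange(max_position, dtype=torch.float, device=device)
    # Outer product of positions and inverse frequencies, then cos/sin cache
    # of shape [max_position, rotary_dim].
    freqs = torch.einsum("i,j->ij", t, inv_freq)
    return torch.cat((freqs.cos(), freqs.sin()), dim=-1)

hf_style = build_cache("cpu", use_float_arange=False)
if torch.cuda.is_available():
    gpu_style = build_cache("cuda", use_float_arange=True).cpu()
    print("max abs difference:", (hf_style - gpu_style).abs().max().item())

Any difference reported this way comes from both the arange dtype and the CPU-vs-GPU computation, so it is an upper bound on either effect alone; without a GPU, only the CPU cache is built.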