diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp
index 6f38e923..365bbd5e 100644
--- a/csrc/cpu/attention.cpp
+++ b/csrc/cpu/attention.cpp
@@ -419,7 +419,8 @@ void paged_attention_v1(torch::Tensor &out, torch::Tensor &query,
                         torch::Tensor &context_lens, int block_size,
                         int max_context_len,
                         const c10::optional<torch::Tensor> &alibi_slopes,
-                        const std::string &kv_cache_dtype) {
+                        const std::string &kv_cache_dtype, float kv_scale) {
+  TORCH_CHECK(kv_scale == 1.0f);
   VLLM_DISPATCH_FLOATING_TYPES(
       query.scalar_type(), "paged_attention_v1_impl", [&] {
         CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
@@ -734,7 +735,8 @@ void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums,
                         torch::Tensor &context_lens, int block_size,
                         int max_context_len,
                         const c10::optional<torch::Tensor> &alibi_slopes,
-                        const std::string &kv_cache_dtype) {
+                        const std::string &kv_cache_dtype, float kv_scale) {
+  TORCH_CHECK(kv_scale == 1.0f);
   VLLM_DISPATCH_FLOATING_TYPES(
       query.scalar_type(), "paged_attention_v2_impl", [&] {
         CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)
diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp
index 94f5affc..7849a5df 100644
--- a/csrc/cpu/cache.cpp
+++ b/csrc/cpu/cache.cpp
@@ -111,7 +111,9 @@ void copy_blocks(std::vector<torch::Tensor> &key_caches,
 void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
                        torch::Tensor &key_cache, torch::Tensor &value_cache,
                        torch::Tensor &slot_mapping,
-                       const std::string &kv_cache_dtype) {
+                       const std::string &kv_cache_dtype, float kv_scale) {
+  TORCH_CHECK(kv_scale == 1.0f);
+
   int num_tokens = key.size(0);
   int num_heads = key.size(1);
   int head_size = key.size(2);
diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py
index 4f69ebef..9706e191 100644
--- a/vllm/attention/backends/torch_sdpa.py
+++ b/vllm/attention/backends/torch_sdpa.py
@@ -114,6 +114,7 @@ class TorchSDPABackendImpl(AttentionImpl):
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
         attn_metadata: TorchSDPAMetadata,
+        kv_scale: float,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.
 
@@ -138,7 +139,8 @@ class TorchSDPABackendImpl(AttentionImpl):
             PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                 value_cache,
                                                 attn_metadata.slot_mapping,
-                                                attn_metadata.kv_cache_dtype)
+                                                attn_metadata.kv_cache_dtype,
+                                                kv_scale)
 
         if attn_metadata.is_prompt:
             if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
@@ -199,6 +201,7 @@ class TorchSDPABackendImpl(AttentionImpl):
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
+                kv_scale,
             )
 
         # Reshape the output tensor.
diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py
index ec2c18dc..256bffdf 100644
--- a/vllm/attention/ops/paged_attn.py
+++ b/vllm/attention/ops/paged_attn.py
@@ -97,7 +97,7 @@ class PagedAttention:
         num_kv_heads: int,
         scale: float,
         alibi_slopes: Optional[torch.Tensor],
-        kv_scale,
+        kv_scale: float,
     ) -> torch.Tensor:
         output = torch.empty_like(query)
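
Note (not part of the diff): below is a minimal, self-contained torch sketch of the role the new kv_scale argument presumably plays once a quantized KV cache path exists, with values scaled down by kv_scale when written to the cache and rescaled on the way out. The helper names, cache shape, and tolerance are invented for illustration only; the CPU kernels in this change do not implement any scaling yet and simply require kv_scale == 1.0 via TORCH_CHECK.

import torch

def write_to_cache_sketch(value: torch.Tensor, cache: torch.Tensor,
                          slot: int, kv_scale: float) -> None:
    # Store the (possibly down-scaled) value; identity when kv_scale == 1.0,
    # which is the only value the CPU kernels above accept.
    cache[slot] = (value / kv_scale).to(cache.dtype)

def read_from_cache_sketch(cache: torch.Tensor, slot: int,
                           kv_scale: float) -> torch.Tensor:
    # Undo the scaling when reading the cached value back.
    return cache[slot].to(torch.float32) * kv_scale

cache = torch.empty(16, 8, dtype=torch.float16)  # hypothetical 16-slot, 8-dim cache
v = torch.randn(8)
write_to_cache_sketch(v, cache, slot=0, kv_scale=1.0)  # CPU backend: must be 1.0
assert torch.allclose(read_from_cache_sketch(cache, 0, kv_scale=1.0), v, atol=1e-2)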