diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 4ecac737..4d023282 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -126,6 +126,7 @@ class FlashInferMetadata(AttentionMetadata): self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( self.device) + self.prefill_wrapper.end_forward() self.prefill_wrapper.begin_forward( self.query_start_loc, self.paged_kv_indptr, self.paged_kv_indices, self.paged_kv_last_page_len, @@ -142,6 +143,7 @@ class FlashInferMetadata(AttentionMetadata): self.device) assert self.decode_wrapper is not None + self.decode_wrapper.end_forward() self.decode_wrapper.begin_forward( self.paged_kv_indptr, self.paged_kv_indices,