From c4059ea54ff36e62b03f1a88baa41ca72dc695e4 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Mon, 1 Jul 2024 16:08:58 -0700
Subject: [PATCH] [Bugfix] Add explicit `end_forward` calls to flashinfer (#6044)

---
 vllm/attention/backends/flashinfer.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index 4ecac737..4d023282 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -126,6 +126,7 @@ class FlashInferMetadata(AttentionMetadata):
             self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
             self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                 self.device)
+            self.prefill_wrapper.end_forward()
             self.prefill_wrapper.begin_forward(
                 self.query_start_loc, self.paged_kv_indptr,
                 self.paged_kv_indices, self.paged_kv_last_page_len,
@@ -142,6 +143,7 @@ class FlashInferMetadata(AttentionMetadata):
                     self.device)

             assert self.decode_wrapper is not None
+            self.decode_wrapper.end_forward()
             self.decode_wrapper.begin_forward(
                 self.paged_kv_indptr,
                 self.paged_kv_indices,