[Bugfix] Add explicit end_forward calls to flashinfer (#6044)

This commit is contained in:
Antoni Baum 2024-07-01 16:08:58 -07:00 committed by GitHub
parent 8e0817c262
commit c4059ea54f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -126,6 +126,7 @@ class FlashInferMetadata(AttentionMetadata):
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device)
self.prefill_wrapper.end_forward()
self.prefill_wrapper.begin_forward(
self.query_start_loc, self.paged_kv_indptr,
self.paged_kv_indices, self.paged_kv_last_page_len,
@@ -142,6 +143,7 @@ class FlashInferMetadata(AttentionMetadata):
self.device)
assert self.decode_wrapper is not None
self.decode_wrapper.end_forward()
self.decode_wrapper.begin_forward(
self.paged_kv_indptr,
self.paged_kv_indices,