[Bugfix] Add explicit end_forward calls to flashinfer (#6044)
This commit is contained in:
parent
8e0817c262
commit
c4059ea54f
@ -126,6 +126,7 @@ class FlashInferMetadata(AttentionMetadata):
|
|||||||
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
|
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
|
||||||
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
|
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
|
||||||
self.device)
|
self.device)
|
||||||
|
self.prefill_wrapper.end_forward()
|
||||||
self.prefill_wrapper.begin_forward(
|
self.prefill_wrapper.begin_forward(
|
||||||
self.query_start_loc, self.paged_kv_indptr,
|
self.query_start_loc, self.paged_kv_indptr,
|
||||||
self.paged_kv_indices, self.paged_kv_last_page_len,
|
self.paged_kv_indices, self.paged_kv_last_page_len,
|
||||||
@ -142,6 +143,7 @@ class FlashInferMetadata(AttentionMetadata):
|
|||||||
self.device)
|
self.device)
|
||||||
|
|
||||||
assert self.decode_wrapper is not None
|
assert self.decode_wrapper is not None
|
||||||
|
self.decode_wrapper.end_forward()
|
||||||
self.decode_wrapper.begin_forward(
|
self.decode_wrapper.begin_forward(
|
||||||
self.paged_kv_indptr,
|
self.paged_kv_indptr,
|
||||||
self.paged_kv_indices,
|
self.paged_kv_indices,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user