[Bugfix] Add explicit end_forward calls to flashinfer (#6044)
This commit is contained in:
parent
8e0817c262
commit
c4059ea54f
@@ -126,6 +126,7 @@ class FlashInferMetadata(AttentionMetadata):
|
||||
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
|
||||
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
|
||||
self.device)
|
||||
self.prefill_wrapper.end_forward()
|
||||
self.prefill_wrapper.begin_forward(
|
||||
self.query_start_loc, self.paged_kv_indptr,
|
||||
self.paged_kv_indices, self.paged_kv_last_page_len,
|
||||
@@ -142,6 +143,7 @@ class FlashInferMetadata(AttentionMetadata):
|
||||
self.device)
|
||||
|
||||
assert self.decode_wrapper is not None
|
||||
self.decode_wrapper.end_forward()
|
||||
self.decode_wrapper.begin_forward(
|
||||
self.paged_kv_indptr,
|
||||
self.paged_kv_indices,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user