[Misc] Compute query_start_loc/seq_start_loc on CPU (#9447)
Co-authored-by: Yang Zheng(SW)(Alex) <you@example.com>
parent b67feb1274
commit 4dbcbbeb09
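The change replaces the GPU-side prefix sums (a temporary query_lens_tensor plus two torch.cumsum kernels writing into zero-initialised device tensors) with itertools.accumulate over the Python length lists, which are then moved to the device in a single async copy per tensor. A minimal sketch of the equivalence; the lengths below are made-up example values, not taken from the diff:

from itertools import accumulate

import torch

# Made-up per-sequence query lengths, for illustration only.
query_lens = [5, 3, 7]

# New approach: prefix sums computed on the CPU as a plain list.
query_start_loc = list(accumulate(query_lens, initial=0))  # [0, 5, 8, 15]

# Old approach: cumsum into a zero-initialised tensor (done on CPU here,
# on `device` in the removed code).
query_lens_tensor = torch.tensor(query_lens, dtype=torch.long)
old_start_loc = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
torch.cumsum(query_lens_tensor, dim=0,
             dtype=old_start_loc.dtype, out=old_start_loc[1:])

assert old_start_loc.tolist() == query_start_loc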
@@ -1,6 +1,7 @@
 """Attention layer with FlashAttention."""
 from collections import defaultdict
 from dataclasses import dataclass
+from itertools import accumulate
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

 import torch
@@ -503,6 +504,8 @@ class FlashAttentionMetadataBuilder(
         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
         max_decode_seq_len = max(self.curr_seq_lens, default=0)
         num_decode_tokens = self.num_decode_tokens
+        query_start_loc = list(accumulate(query_lens, initial=0))
+        seq_start_loc = list(accumulate(seq_lens, initial=0))

         num_seqs = len(seq_lens)
         if use_captured_graph:
@@ -525,29 +528,18 @@ class FlashAttentionMetadataBuilder(
                                               device, self.runner.pin_memory)
         seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                            self.runner.pin_memory)
-        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
-                                             self.runner.pin_memory)
         slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                                device, self.runner.pin_memory)
-        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
-                                      dtype=torch.int32,
-                                      device=device)
-        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
-                                    dtype=torch.int32,
-                                    device=device)
+        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
+                                                  device,
+                                                  self.runner.pin_memory)
+        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
+                                                device, self.runner.pin_memory)
         placeholder_index_maps = {
             modality: placeholder_map.index_map()
             for modality, placeholder_map in
             self.multimodal_placeholder_maps.items()
         }
-        torch.cumsum(seq_lens_tensor,
-                     dim=0,
-                     dtype=seq_start_loc.dtype,
-                     out=seq_start_loc[1:])
-        torch.cumsum(query_lens_tensor,
-                     dim=0,
-                     dtype=query_start_loc.dtype,
-                     out=query_start_loc[1:])

         return FlashAttentionMetadata(
             num_prefills=self.num_prefills,
@@ -561,8 +553,8 @@ class FlashAttentionMetadataBuilder(
             max_decode_query_len=max_decode_query_len,
             max_prefill_seq_len=max_prefill_seq_len,
             max_decode_seq_len=max_decode_seq_len,
-            query_start_loc=query_start_loc,
-            seq_start_loc=seq_start_loc,
+            query_start_loc=query_start_loc_tensor,
+            seq_start_loc=seq_start_loc_tensor,
             context_lens_tensor=context_lens_tensor,
             block_tables=block_tables,
             use_cuda_graph=use_captured_graph,

@@ -1,6 +1,7 @@
 """Attention backend utils"""
 from collections import defaultdict
 from contextlib import contextmanager
+from itertools import accumulate
 from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union

 import numpy as np
@@ -216,6 +217,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
         max_decode_seq_len = max(self.curr_seq_lens, default=0)
         num_decode_tokens = self.num_decode_tokens
+        query_start_loc = list(accumulate(query_lens, initial=0))
+        seq_start_loc = list(accumulate(seq_lens, initial=0))

         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
@@ -244,29 +247,18 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
                                               device, self.runner.pin_memory)
         seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                            self.runner.pin_memory)
-        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
-                                             self.runner.pin_memory)
         slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                                device, self.runner.pin_memory)
-        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
-                                      dtype=torch.int32,
-                                      device=device)
-        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
-                                    dtype=torch.int32,
-                                    device=device)
+        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
+                                                  device,
+                                                  self.runner.pin_memory)
+        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
+                                                device, self.runner.pin_memory)
         placeholder_index_maps = {
             modality: placeholder_map.index_map()
             for modality, placeholder_map in
             self.multimodal_placeholder_maps.items()
         }
-        torch.cumsum(seq_lens_tensor,
-                     dim=0,
-                     dtype=seq_start_loc.dtype,
-                     out=seq_start_loc[1:])
-        torch.cumsum(query_lens_tensor,
-                     dim=0,
-                     dtype=query_start_loc.dtype,
-                     out=query_start_loc[1:])

         return self._metadata_cls(  # type: ignore
             num_prefills=self.num_prefills,
@@ -279,8 +271,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
             max_query_len=max_query_len,
             max_prefill_seq_len=max_prefill_seq_len,
             max_decode_seq_len=max_decode_seq_len,
-            query_start_loc=query_start_loc,
-            seq_start_loc=seq_start_loc,
+            query_start_loc=query_start_loc_tensor,
+            seq_start_loc=seq_start_loc_tensor,
             context_lens_tensor=context_lens_tensor,
             block_tables=block_tables,
             use_cuda_graph=use_captured_graph,
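With the prefix sums already materialised as Python lists, both builders hand them to the same async_tensor_h2d helper that already moves seq_lens and slot_mapping, so the GPU no longer allocates temporary query_lens_tensor buffers or launches two cumsum kernels per batch. A hedged sketch of that copy pattern; to_device_async is a hypothetical stand-in, since the real helper's internals are not part of this diff:

import torch

def to_device_async(values, dtype, device, pin_memory):
    # Hypothetical stand-in for the async_tensor_h2d call used above:
    # build the tensor on the host (optionally in pinned memory) and
    # issue a non-blocking copy to the target device.
    cpu_tensor = torch.tensor(values, dtype=dtype, pin_memory=pin_memory)
    return cpu_tensor.to(device=device, non_blocking=True)

# e.g. (assumes a CUDA device is available):
# query_start_loc_tensor = to_device_async(query_start_loc, torch.int32,
#                                          torch.device("cuda"), True)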