[Misc] Compute query_start_loc/seq_start_loc on CPU (#9447)

Co-authored-by: Yang Zheng(SW)(Alex) <you@example.com>
Yang Zheng 2024-11-04 16:54:37 +08:00 committed by GitHub
parent b67feb1274
commit 4dbcbbeb09
2 changed files with 20 additions and 36 deletions
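Reviewer note: query_start_loc and seq_start_loc are plain prefix sums of the per-sequence query and sequence lengths, so they can be built on the CPU with itertools.accumulate and uploaded once, instead of allocating a zeros tensor on the device and launching two torch.cumsum kernels per batch. A minimal sketch of the equivalence, using toy lengths (illustrative only, not the committed code):

    from itertools import accumulate

    import torch

    query_lens = [3, 1, 4]
    seq_lens = [7, 8, 9]

    # New path: prefix sums on the CPU, with a leading 0 so that entry i
    # is the start offset of sequence i in the packed token tensors.
    query_start_loc = list(accumulate(query_lens, initial=0))  # [0, 3, 4, 8]
    seq_start_loc = list(accumulate(seq_lens, initial=0))      # [0, 7, 15, 24]

    # Old path: the same prefix sums, computed on the device via torch.cumsum.
    seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32)
    old_seq_start_loc = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
    torch.cumsum(seq_lens_tensor, dim=0, dtype=old_seq_start_loc.dtype,
                 out=old_seq_start_loc[1:])

    assert old_seq_start_loc.tolist() == seq_start_loc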

vllm/attention/backends/flash_attn.py

@@ -1,6 +1,7 @@
 """Attention layer with FlashAttention."""
 from collections import defaultdict
 from dataclasses import dataclass
+from itertools import accumulate
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
 
 import torch
@@ -503,6 +504,8 @@ class FlashAttentionMetadataBuilder(
         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
         max_decode_seq_len = max(self.curr_seq_lens, default=0)
         num_decode_tokens = self.num_decode_tokens
+        query_start_loc = list(accumulate(query_lens, initial=0))
+        seq_start_loc = list(accumulate(seq_lens, initial=0))
         num_seqs = len(seq_lens)
 
         if use_captured_graph:
@@ -525,29 +528,18 @@
                                                device, self.runner.pin_memory)
         seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                            self.runner.pin_memory)
-        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
-                                             self.runner.pin_memory)
         slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                                device, self.runner.pin_memory)
-        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
-                                      dtype=torch.int32,
-                                      device=device)
-        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
-                                    dtype=torch.int32,
-                                    device=device)
+        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
+                                                  device,
+                                                  self.runner.pin_memory)
+        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
+                                                device, self.runner.pin_memory)
         placeholder_index_maps = {
             modality: placeholder_map.index_map()
             for modality, placeholder_map in
             self.multimodal_placeholder_maps.items()
         }
-        torch.cumsum(seq_lens_tensor,
-                     dim=0,
-                     dtype=seq_start_loc.dtype,
-                     out=seq_start_loc[1:])
-        torch.cumsum(query_lens_tensor,
-                     dim=0,
-                     dtype=query_start_loc.dtype,
-                     out=query_start_loc[1:])
 
         return FlashAttentionMetadata(
             num_prefills=self.num_prefills,
@@ -561,8 +553,8 @@ class FlashAttentionMetadataBuilder(
             max_decode_query_len=max_decode_query_len,
             max_prefill_seq_len=max_prefill_seq_len,
             max_decode_seq_len=max_decode_seq_len,
-            query_start_loc=query_start_loc,
-            seq_start_loc=seq_start_loc,
+            query_start_loc=query_start_loc_tensor,
+            seq_start_loc=seq_start_loc_tensor,
             context_lens_tensor=context_lens_tensor,
             block_tables=block_tables,
             use_cuda_graph=use_captured_graph,
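Reviewer note: both new uploads reuse async_tensor_h2d, the helper already used for seq_lens and slot_mapping above, and the query_lens_tensor upload is dropped entirely since it only existed to feed the removed cumsum. Roughly, and assuming the helper matches the four-argument call sites in this diff, it stages the list in a pinned CPU tensor and issues a non-blocking host-to-device copy (a sketch, not the library source):

    from typing import List, Union

    import torch

    def async_tensor_h2d(data: List[int], dtype: torch.dtype,
                         target_device: Union[str, torch.device],
                         pin_memory: bool) -> torch.Tensor:
        """Build a tensor on the host, then copy it to the device asynchronously."""
        t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
        # With a pinned (page-locked) source, non_blocking=True lets the H2D
        # copy overlap with the CPU work that follows.
        return t.to(device=target_device, non_blocking=True)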

vllm/attention/backends/utils.py

@@ -1,6 +1,7 @@
 """Attention backend utils"""
 from collections import defaultdict
 from contextlib import contextmanager
+from itertools import accumulate
 from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
 
 import numpy as np
@@ -216,6 +217,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
         max_decode_seq_len = max(self.curr_seq_lens, default=0)
         num_decode_tokens = self.num_decode_tokens
+        query_start_loc = list(accumulate(query_lens, initial=0))
+        seq_start_loc = list(accumulate(seq_lens, initial=0))
 
         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
@@ -244,29 +247,18 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
                                                device, self.runner.pin_memory)
         seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                            self.runner.pin_memory)
-        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
-                                             self.runner.pin_memory)
         slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                                device, self.runner.pin_memory)
-        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
-                                      dtype=torch.int32,
-                                      device=device)
-        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
-                                    dtype=torch.int32,
-                                    device=device)
+        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
+                                                  device,
+                                                  self.runner.pin_memory)
+        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
+                                                device, self.runner.pin_memory)
         placeholder_index_maps = {
             modality: placeholder_map.index_map()
             for modality, placeholder_map in
             self.multimodal_placeholder_maps.items()
         }
-        torch.cumsum(seq_lens_tensor,
-                     dim=0,
-                     dtype=seq_start_loc.dtype,
-                     out=seq_start_loc[1:])
-        torch.cumsum(query_lens_tensor,
-                     dim=0,
-                     dtype=query_start_loc.dtype,
-                     out=query_start_loc[1:])
 
         return self._metadata_cls(  # type: ignore
             num_prefills=self.num_prefills,
@@ -279,8 +271,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
             max_query_len=max_query_len,
             max_prefill_seq_len=max_prefill_seq_len,
             max_decode_seq_len=max_decode_seq_len,
-            query_start_loc=query_start_loc,
-            seq_start_loc=seq_start_loc,
+            query_start_loc=query_start_loc_tensor,
+            seq_start_loc=seq_start_loc_tensor,
             context_lens_tensor=context_lens_tensor,
             block_tables=block_tables,
             use_cuda_graph=use_captured_graph,
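Reviewer note: the utils.py change mirrors flash_attn.py line for line; CommonMetadataBuilder is the shared builder that other attention backends inherit from. For context, the leading-zero prefix-sum layout built here is the cu_seqlens convention that varlen attention kernels use to index packed tokens, sketched below with hypothetical toy values:

    from itertools import accumulate

    import torch

    query_lens = [3, 1, 4]
    query_start_loc = list(accumulate(query_lens, initial=0))  # [0, 3, 4, 8]

    # Stand-in for a packed per-token tensor (all sequences concatenated).
    packed = torch.arange(sum(query_lens))

    # query_start_loc[i]:query_start_loc[i + 1] selects the tokens of
    # sequence i, and the final entry is the total token count.
    for i, qlen in enumerate(query_lens):
        assert packed[query_start_loc[i]:query_start_loc[i + 1]].numel() == qlen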