[FIX] Support non-zero CUDA devices in custom kernels (#1959)

Jee Li 2024-01-03 11:09:59 +08:00 committed by GitHub
parent 4934d49274
commit 77af974b40
12 changed files with 74 additions and 30 deletions
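
Every host-side launcher touched by this commit gains the same two-line preamble: a scoped OptionalCUDAGuard pinned to an input tensor's device, placed before the stream query, so that both the stream lookup and the kernel launch target that tensor's device instead of the implicit default cuda:0. A minimal sketch of the pattern on a hypothetical elementwise op (the op, kernel, and names below are illustrative, not from this commit):

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

__global__ void scale_kernel(float* x, float s, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] *= s;
}

void scale(torch::Tensor& input, float s) {
  const int n = input.numel();
  const dim3 grid((n + 255) / 256);
  const dim3 block(256);
  // Pin the current CUDA device to input's device for this scope;
  // without this, the stream query and launch below hit device 0.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  // getCurrentCUDAStream() reads the *current* device, so it must
  // come after the guard.
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  scale_kernel<<<grid, block, 0, stream>>>(input.data_ptr<float>(), s, n);
}

The per-file hunks below are this preamble dropped into each existing launcher, plus test plumbing to exercise a second GPU when one is available.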

csrc/activation_kernels.cu

@@ -1,5 +1,6 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "dispatch_utils.h"
@@ -36,6 +37,7 @@ void silu_and_mul(
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
@@ -71,6 +73,7 @@ __global__ void activation_kernel(
int64_t num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \
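
In the macro-based launcher above, the two added lines carry trailing backslashes so every activation op the macro expands into inherits them. A note on the guard type: device_of(t) returns an optional<Device> (empty for an undefined tensor), which is why the commit uses OptionalCUDAGuard rather than CUDAGuard; an empty optional makes the guard a no-op instead of an error. A sketch, assuming only standard ATen behavior:

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

void guarded_scope(const at::Tensor& t) {
  // Switches the current CUDA device to t's device for this scope and
  // restores the previous device on destruction; a no-op when
  // device_of(t) is nullopt.
  const at::cuda::OptionalCUDAGuard guard(device_of(t));
  // ... device-dependent work here ...
}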

csrc/attention/attention_kernels.cu

@@ -21,6 +21,7 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "attention_dtypes.h"
#include "attention_utils.cuh"
@@ -616,6 +617,7 @@ void paged_attention_v1_launcher(
dim3 grid(num_heads, num_seqs, 1);
dim3 block(NUM_THREADS);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
switch (head_size) {
// NOTE(woosuk): To reduce the compilation time, we only compile for the
@@ -784,6 +786,7 @@ void paged_attention_v2_launcher(
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
dim3 block(NUM_THREADS);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
switch (head_size) {
// NOTE(woosuk): To reduce the compilation time, we only compile for the
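
Both paged-attention launchers derive the guard from query and assume the key/value caches are colocated with it. A hedged sketch of a cheap colocation check a caller could add (hypothetical helper, not part of this commit):

#include <torch/extension.h>

void check_same_device(const at::Tensor& a, const at::Tensor& b) {
  TORCH_CHECK(a.device() == b.device(),
              "expected tensors on the same device, got ",
              a.device(), " and ", b.device());
}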

csrc/cache_kernels.cu

@@ -1,5 +1,6 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "dispatch_utils.h"
@@ -33,6 +34,7 @@ void swap_blocks(
char *dst_ptr = static_cast<char*>(dst.data_ptr());
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
const at::cuda::OptionalCUDAGuard device_guard(src_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// NOTE(woosuk): This can be slow if the number of blocks is large.
for (const auto& pair : block_mapping) {
@@ -127,6 +129,7 @@ void copy_blocks(
const int numel_per_block = key_caches[0][0].numel();
dim3 grid(num_layers, num_pairs);
dim3 block(std::min(1024, numel_per_block));
const at::cuda::OptionalCUDAGuard device_guard(cache_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
@@ -207,6 +210,7 @@ void reshape_and_cache(
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(),
@@ -367,6 +371,7 @@ void gather_cached_kv(
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(),
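
swap_blocks is the one launcher where device_of(...) on a single tensor would be wrong: the source and destination caches may sit on different devices, so its guard is built from the explicit src_device and the async copies are enqueued on that device's current stream. A minimal sketch of the same shape (hypothetical single-block helper; using cudaMemcpyDefault to let the driver infer the direction is an assumption here, the real code selects an explicit cudaMemcpyKind):

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

void swap_one_block(const torch::Tensor& src, torch::Tensor& dst,
                    int64_t src_block, int64_t dst_block) {
  const int64_t block_bytes = src.element_size() * src[0].numel();
  const char* src_ptr = static_cast<const char*>(src.data_ptr());
  char* dst_ptr = static_cast<char*>(dst.data_ptr());
  // Guard on the *source* device, since src and dst may differ; the
  // copy below is ordered on that device's current stream.
  const at::cuda::OptionalCUDAGuard device_guard(src.device());
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  cudaMemcpyAsync(dst_ptr + dst_block * block_bytes,
                  src_ptr + src_block * block_bytes,
                  block_bytes, cudaMemcpyDefault, stream);
}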

csrc/layernorm_kernels.cu

@@ -1,5 +1,6 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "dispatch_utils.h"
#include "reduction_utils.cuh"
@@ -76,6 +77,7 @@ void rms_norm(
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
@@ -101,6 +103,7 @@ void fused_add_rms_norm(
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),

csrc/pos_encoding_kernels.cu

@@ -1,5 +1,6 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "dispatch_utils.h"
@@ -94,6 +95,7 @@ void rotary_embedding(
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
query.scalar_type(),

csrc/quantization/squeezellm/quant_cuda_kernel.cu

@@ -7,6 +7,7 @@
// half-tensor
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDATensorMethods.cuh>
#include <c10/cuda/CUDAGuard.h>
#define BLOCKWIDTH 128
#define BLOCKHEIGHT4 16
@@ -199,7 +200,7 @@ void squeezellm_gemm(
(width + BLOCKWIDTH - 1) / BLOCKWIDTH
);
dim3 threads(BLOCKWIDTH);
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
#ifndef USE_ROCM
(half2*) vec.data<at::Half>(),
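
This launcher is the one place the commit adds only the guard: the <<<blocks, threads>>> launch carries no explicit stream argument, so it lands on the guarded device's default stream, which is still the correct device. A sketch of the fully pinned variant on a dummy kernel (illustrative; the explicit stream argument is the only difference from the launch above):

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

__global__ void touch_kernel(int* flag) { *flag = 1; }

void touch(torch::Tensor& flag) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(flag));
  // Passing the current ATen stream keeps the launch ordered with
  // other work PyTorch has already queued on this device.
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  touch_kernel<<<1, 1, 0, stream>>>(flag.data_ptr<int>());
}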

tests/kernels/conftest.py

@@ -12,6 +12,7 @@ def create_kv_caches(
head_size: int,
dtype: torch.dtype,
seed: int,
device: str,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
@@ -23,7 +24,7 @@ def create_kv_caches(
for _ in range(num_layers):
key_cache = torch.empty(size=key_cache_shape,
dtype=dtype,
device='cuda')
device=device)
key_cache.uniform_(-scale, scale)
key_caches.append(key_cache)
@@ -32,7 +33,7 @@ def create_kv_caches(
for _ in range(num_layers):
value_cache = torch.empty(size=value_cache_shape,
dtype=dtype,
device='cuda')
device=device)
value_cache.uniform_(-scale, scale)
value_caches.append(value_cache)
return key_caches, value_caches

tests/kernels/test_activation.py

@@ -7,22 +7,26 @@ DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
D = [512, 4096, 5120, 13824] # Arbitrary values for testing
SEEDS = [0]
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_silu_and_mul(
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
device: int,
) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda")
gpu_id = f"cuda:{device}"
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device=gpu_id)
layer = SiluAndMul()
out = layer(x)
ref_out = layer._forward(x)
@@ -33,16 +37,19 @@ def test_silu_and_mul(
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_gelu_new(
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
device: int,
) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
gpu_id = f"cuda:{device}"
x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id)
layer = NewGELU()
out = layer(x)
ref_out = layer._forward(x)
@@ -53,15 +60,18 @@ def test_gelu_new(
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
def test_gelu_fast(
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
device: int,
) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
gpu_id = f"cuda:{device}"
x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id)
layer = FastGELU()
out = layer(x)
ref_out = layer._forward(x)

tests/kernels/test_attention.py

@@ -24,6 +24,7 @@ HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
SEEDS = [0]
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
def ref_masked_attention(
@@ -87,7 +88,7 @@ def ref_single_query_cached_kv_attention(
alibi_bias = None
if alibi_slopes is not None:
# Create the ALiBi bias used in the paged attention kernel.
position_ids = torch.arange(context_len, device="cuda").int()
position_ids = torch.arange(context_len, device=query.device).int()
alibi_bias = (position_ids - context_len + 1).float()
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
1, 1, -1)
@@ -105,6 +106,7 @@ def ref_single_query_cached_kv_attention(
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
def test_paged_attention(
kv_cache_factory,
version: str,
@@ -115,18 +117,19 @@ def test_paged_attention(
block_size: int,
dtype: torch.dtype,
seed: int,
device: int,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
gpu_id = f"cuda:{device}"
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
query = torch.empty(num_seqs,
num_query_heads,
head_size,
dtype=dtype,
device="cuda")
device=gpu_id)
query.uniform_(-scale, scale)
assert num_query_heads % num_kv_heads == 0
@@ -135,12 +138,12 @@ def test_paged_attention(
if use_alibi:
alibi_slopes = torch.randn(num_query_heads,
dtype=torch.float,
device="cuda")
device=gpu_id)
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
context_lens[-1] = MAX_SEQ_LEN
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
context_lens = torch.tensor(context_lens, dtype=torch.int, device=gpu_id)
# Create the block tables.
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
@@ -151,12 +154,12 @@ def test_paged_attention(
for _ in range(max_num_blocks_per_seq)
]
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
block_tables = torch.tensor(block_tables, dtype=torch.int, device=gpu_id)
# Create the KV caches.
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
num_kv_heads, head_size, dtype,
seed)
seed, gpu_id)
key_cache, value_cache = key_caches[0], value_caches[0]
# Call the paged attention kernel.
@@ -249,7 +252,7 @@ def ref_multi_query_kv_attention(
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
diagonal=1)
attn_mask = attn_mask * torch.finfo(dtype).min
attn_mask = attn_mask.to(dtype=dtype, device="cuda")
attn_mask = attn_mask.to(dtype=dtype, device=query.device)
ref_output = ref_masked_attention(
query[start_idx:end_idx],
@@ -269,6 +272,7 @@ def ref_multi_query_kv_attention(
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_multi_query_kv_attention(
num_seqs: int,
@@ -276,11 +280,12 @@ def test_multi_query_kv_attention(
head_size: int,
dtype: torch.dtype,
seed: int,
device: int,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
gpu_id = f"cuda:{device}"
# MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
# As the xformers library is already tested with its own tests, we can use
# a smaller MAX_SEQ_LEN here.
@@ -294,7 +299,7 @@ def test_multi_query_kv_attention(
num_query_heads + 2 * num_kv_heads,
head_size,
dtype=dtype,
device="cuda")
device=gpu_id)
qkv.uniform_(-scale, scale)
query, key, value = qkv.split(
[num_query_heads, num_kv_heads, num_kv_heads], dim=1)

tests/kernels/test_cache.py

@@ -14,6 +14,7 @@ BLOCK_SIZES = [8, 16, 32]
NUM_BLOCKS = [1024, 36000] # Arbitrary values for testing
NUM_MAPPINGS = [256] # Arbitrary values for testing
SEEDS = [0]
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@@ -24,6 +25,7 @@ SEEDS = [0]
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_copy_blocks(
kv_cache_factory,
@@ -35,11 +37,12 @@ def test_copy_blocks(
num_blocks: int,
dtype: torch.dtype,
seed: int,
device: int,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
gpu_id = f"cuda:{device}"
# Generate random block mappings where each source block is mapped to two
# destination blocks.
assert 2 * num_mappings <= num_blocks
@@ -56,7 +59,7 @@ def test_copy_blocks(
# Create the KV caches.
key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
num_layers, num_heads,
head_size, dtype, seed)
head_size, dtype, seed, gpu_id)
# Clone the KV caches.
cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
@@ -88,6 +91,7 @@ def test_copy_blocks(
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_reshape_and_cache(
kv_cache_factory,
@@ -98,28 +102,29 @@ def test_reshape_and_cache(
num_blocks: int,
dtype: torch.dtype,
seed: int,
device: int,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
gpu_id = f"cuda:{device}"
# Create a random slot mapping.
num_slots = block_size * num_blocks
slot_mapping = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device="cuda")
slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=gpu_id)
qkv = torch.randn(num_tokens,
3,
num_heads,
head_size,
dtype=dtype,
device="cuda")
device=gpu_id)
_, key, value = qkv.unbind(dim=1)
# Create the KV caches.
key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
num_heads, head_size, dtype,
seed)
seed, gpu_id)
key_cache, value_cache = key_caches[0], value_caches[0]
# Clone the KV caches.

tests/kernels/test_layernorm.py

@@ -8,6 +8,7 @@ NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing
HIDDEN_SIZES = [768, 5120, 8192] # Arbitrary values for testing
ADD_RESIDUAL = [False, True]
SEEDS = [0]
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -15,6 +16,7 @@ SEEDS = [0]
@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_rms_norm(
num_tokens: int,
@@ -22,14 +24,15 @@ def test_rms_norm(
add_residual: bool,
dtype: torch.dtype,
seed: int,
device: int,
) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
layer = RMSNorm(hidden_size).to(dtype).cuda()
gpu_id = f"cuda:{device}"
layer = RMSNorm(hidden_size).to(dtype=dtype, device=gpu_id)
layer.weight.data.normal_(mean=1.0, std=0.1)
scale = 1 / (2 * hidden_size)
x = torch.randn(num_tokens, hidden_size, dtype=dtype, device="cuda")
x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=gpu_id)
x *= scale
residual = torch.randn_like(x) * scale if add_residual else None

tests/kernels/test_pos_encoding.py

@@ -13,6 +13,7 @@ NUM_HEADS = [7, 17] # Arbitrary values for testing
BATCH_SIZES = [1, 5] # Arbitrary values for testing
SEQ_LENS = [11, 8192] # Arbitrary values for testing
SEEDS = [0]
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@@ -23,6 +24,7 @@ SEEDS = [0]
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_rotary_embedding(
is_neox_style: bool,
@@ -33,6 +35,7 @@ def test_rotary_embedding(
rotary_dim: Optional[int],
dtype: torch.dtype,
seed: int,
device: int,
max_position: int = 8192,
base: int = 10000,
) -> None:
@@ -40,20 +43,20 @@ def test_rotary_embedding(
rotary_dim = head_size
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
gpu_id = f"cuda:{device}"
if rotary_dim is None:
rotary_dim = head_size
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
rope = rope.to(dtype).cuda()
rope = rope.to(dtype=dtype, device=gpu_id)
positions = torch.randint(0,
max_position, (batch_size, seq_len),
device="cuda")
device=gpu_id)
query = torch.randn(batch_size,
seq_len,
num_heads * head_size,
dtype=dtype,
device="cuda")
device=gpu_id)
key = torch.randn_like(query)
# NOTE(woosuk): The reference implementation should be executed first