[FIX] Support non-zero CUDA devices in custom kernels (#1959)
parent 4934d49274
commit 77af974b40
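The change is the same in every CUDA source touched below: each launcher now scopes the current CUDA device to the device of its input tensor before querying the current stream and launching the kernel, so inputs on cuda:1, cuda:2, ... are no longer silently run through device 0's stream. The sketch below is only an illustration of that pattern, not code from the diff; example_kernel and example_op are hypothetical names, and only the guard/stream lines mirror the actual changes.

#include <algorithm>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

// Hypothetical element-wise kernel standing in for vLLM's real kernels.
__global__ void example_kernel(float* __restrict__ out,
                               const float* __restrict__ in,
                               int d) {
  const int token = blockIdx.x;
  for (int i = threadIdx.x; i < d; i += blockDim.x) {
    out[token * d + i] = in[token * d + i];
  }
}

void example_op(torch::Tensor& out, torch::Tensor& input) {
  const int d = static_cast<int>(input.size(-1));
  const int num_tokens = static_cast<int>(input.numel() / d);
  dim3 grid(num_tokens);
  dim3 block(std::min(d, 1024));
  // The two lines below are the pattern added throughout this commit: the
  // guard makes input's device current for the rest of the scope, so
  // getCurrentCUDAStream() returns that device's stream and the kernel
  // launches there instead of on the default device (cuda:0).
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  example_kernel<<<grid, block, 0, stream>>>(
    out.data_ptr<float>(), input.data_ptr<float>(), d);
}

The Python test changes apply the same idea from the other side: each kernel test is parameterized over a device index, allocates its tensors on f"cuda:{device}", and passes the device through to the KV-cache factory, so the kernels are exercised on a non-zero device whenever more than one GPU is available.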
@@ -1,5 +1,6 @@
-#include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>

 #include "cuda_compat.h"
 #include "dispatch_utils.h"
@@ -36,6 +37,7 @@ void silu_and_mul(

   dim3 grid(num_tokens);
   dim3 block(std::min(d, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     input.scalar_type(),
@@ -71,6 +73,7 @@ __global__ void activation_kernel(
   int64_t num_tokens = input.numel() / d;                               \
   dim3 grid(num_tokens);                                                \
   dim3 block(std::min(d, 1024));                                        \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));     \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();         \
   VLLM_DISPATCH_FLOATING_TYPES(                                         \
     input.scalar_type(),                                                \

@@ -21,6 +21,7 @@

 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>

 #include "attention_dtypes.h"
 #include "attention_utils.cuh"
@@ -616,6 +617,7 @@ void paged_attention_v1_launcher(

   dim3 grid(num_heads, num_seqs, 1);
   dim3 block(NUM_THREADS);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (head_size) {
     // NOTE(woosuk): To reduce the compilation time, we only compile for the
@@ -784,6 +786,7 @@ void paged_attention_v2_launcher(
   int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);

   dim3 block(NUM_THREADS);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (head_size) {
     // NOTE(woosuk): To reduce the compilation time, we only compile for the

@@ -1,5 +1,6 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>

 #include "cuda_compat.h"
 #include "dispatch_utils.h"
@@ -33,6 +34,7 @@ void swap_blocks(
   char *dst_ptr = static_cast<char*>(dst.data_ptr());

   const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
+  const at::cuda::OptionalCUDAGuard device_guard(src_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   // NOTE(woosuk): This can be slow if the number of blocks is large.
   for (const auto& pair : block_mapping) {
@@ -127,6 +129,7 @@ void copy_blocks(
   const int numel_per_block = key_caches[0][0].numel();
   dim3 grid(num_layers, num_pairs);
   dim3 block(std::min(1024, numel_per_block));
+  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
@@ -207,6 +210,7 @@ void reshape_and_cache(

   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     key.scalar_type(),
@@ -367,6 +371,7 @@ void gather_cached_kv(

   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     key.scalar_type(),

@@ -1,5 +1,6 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>

 #include "dispatch_utils.h"
 #include "reduction_utils.cuh"
@@ -76,6 +77,7 @@ void rms_norm(

   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     input.scalar_type(),
@@ -101,6 +103,7 @@ void fused_add_rms_norm(

   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     input.scalar_type(),

@@ -1,5 +1,6 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>

 #include "cuda_compat.h"
 #include "dispatch_utils.h"
@@ -94,6 +95,7 @@ void rotary_embedding(

   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     query.scalar_type(),

@@ -7,6 +7,7 @@
 // half-tensor
 #include <c10/cuda/CUDAStream.h>
 #include <ATen/cuda/CUDATensorMethods.cuh>
+#include <c10/cuda/CUDAGuard.h>

 #define BLOCKWIDTH 128
 #define BLOCKHEIGHT4 16
@@ -199,7 +200,7 @@ void squeezellm_gemm(
     (width + BLOCKWIDTH - 1) / BLOCKWIDTH
   );
   dim3 threads(BLOCKWIDTH);
-
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
   vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
 #ifndef USE_ROCM
     (half2*) vec.data<at::Half>(),

@@ -12,6 +12,7 @@ def create_kv_caches(
     head_size: int,
     dtype: torch.dtype,
     seed: int,
+    device: str,
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
@@ -23,7 +24,7 @@ def create_kv_caches(
     for _ in range(num_layers):
         key_cache = torch.empty(size=key_cache_shape,
                                 dtype=dtype,
-                                device='cuda')
+                                device=device)
         key_cache.uniform_(-scale, scale)
         key_caches.append(key_cache)
@@ -32,7 +33,7 @@ def create_kv_caches(
     for _ in range(num_layers):
         value_cache = torch.empty(size=value_cache_shape,
                                   dtype=dtype,
-                                  device='cuda')
+                                  device=device)
         value_cache.uniform_(-scale, scale)
         value_caches.append(value_cache)
     return key_caches, value_caches

@@ -7,22 +7,26 @@ DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
 D = [512, 4096, 5120, 13824]  # Arbitrary values for testing
 SEEDS = [0]
+DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]


 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_silu_and_mul(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda")
+    gpu_id = f"cuda:{device}"
+    x = torch.randn(num_tokens, 2 * d, dtype=dtype, device=gpu_id)
     layer = SiluAndMul()
     out = layer(x)
     ref_out = layer._forward(x)
@@ -33,16 +37,19 @@ def test_silu_and_mul(
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_gelu_new(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
+    gpu_id = f"cuda:{device}"
+    x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id)
     layer = NewGELU()
     out = layer(x)
     ref_out = layer._forward(x)
@@ -53,15 +60,18 @@ def test_gelu_new(
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 def test_gelu_fast(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
+    gpu_id = f"cuda:{device}"
+    x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id)
     layer = FastGELU()
     out = layer(x)
     ref_out = layer._forward(x)

@@ -24,6 +24,7 @@ HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 BLOCK_SIZES = [16, 32]
 USE_ALIBI = [False, True]
 SEEDS = [0]
+DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]


 def ref_masked_attention(
@@ -87,7 +88,7 @@ def ref_single_query_cached_kv_attention(
         alibi_bias = None
         if alibi_slopes is not None:
             # Create the ALiBi bias used in the paged attention kernel.
-            position_ids = torch.arange(context_len, device="cuda").int()
+            position_ids = torch.arange(context_len, device=query.device).int()
             alibi_bias = (position_ids - context_len + 1).float()
             alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
                 1, 1, -1)
@@ -105,6 +106,7 @@ def ref_single_query_cached_kv_attention(
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 def test_paged_attention(
     kv_cache_factory,
     version: str,
@@ -115,18 +117,19 @@ def test_paged_attention(
     block_size: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-
+    gpu_id = f"cuda:{device}"
     scale = float(1.0 / (head_size**0.5))
     num_query_heads, num_kv_heads = num_heads
     query = torch.empty(num_seqs,
                         num_query_heads,
                         head_size,
                         dtype=dtype,
-                        device="cuda")
+                        device=gpu_id)
     query.uniform_(-scale, scale)

     assert num_query_heads % num_kv_heads == 0
@@ -135,12 +138,12 @@ def test_paged_attention(
     if use_alibi:
         alibi_slopes = torch.randn(num_query_heads,
                                    dtype=torch.float,
-                                   device="cuda")
+                                   device=gpu_id)

     context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
     context_lens[-1] = MAX_SEQ_LEN
     max_context_len = max(context_lens)
-    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
+    context_lens = torch.tensor(context_lens, dtype=torch.int, device=gpu_id)

     # Create the block tables.
     max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
@@ -151,12 +154,12 @@ def test_paged_attention(
             for _ in range(max_num_blocks_per_seq)
         ]
         block_tables.append(block_table)
-    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
+    block_tables = torch.tensor(block_tables, dtype=torch.int, device=gpu_id)

     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                 num_kv_heads, head_size, dtype,
-                                                seed)
+                                                seed, gpu_id)
     key_cache, value_cache = key_caches[0], value_caches[0]

     # Call the paged attention kernel.
@@ -249,7 +252,7 @@ def ref_multi_query_kv_attention(
         attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                                diagonal=1)
         attn_mask = attn_mask * torch.finfo(dtype).min
-        attn_mask = attn_mask.to(dtype=dtype, device="cuda")
+        attn_mask = attn_mask.to(dtype=dtype, device=query.device)

         ref_output = ref_masked_attention(
             query[start_idx:end_idx],
@@ -269,6 +272,7 @@ def ref_multi_query_kv_attention(
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_multi_query_kv_attention(
     num_seqs: int,
@@ -276,11 +280,12 @@ def test_multi_query_kv_attention(
     head_size: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-
+    gpu_id = f"cuda:{device}"
     # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
     # As the xformers library is already tested with its own tests, we can use
     # a smaller MAX_SEQ_LEN here.
@@ -294,7 +299,7 @@ def test_multi_query_kv_attention(
                       num_query_heads + 2 * num_kv_heads,
                       head_size,
                       dtype=dtype,
-                      device="cuda")
+                      device=gpu_id)
     qkv.uniform_(-scale, scale)
     query, key, value = qkv.split(
         [num_query_heads, num_kv_heads, num_kv_heads], dim=1)

@@ -14,6 +14,7 @@ BLOCK_SIZES = [8, 16, 32]
 NUM_BLOCKS = [1024, 36000]  # Arbitrary values for testing
 NUM_MAPPINGS = [256]  # Arbitrary values for testing
 SEEDS = [0]
+DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]


 @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@@ -24,6 +25,7 @@ SEEDS = [0]
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_copy_blocks(
     kv_cache_factory,
@@ -35,11 +37,12 @@ def test_copy_blocks(
     num_blocks: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-
+    gpu_id = f"cuda:{device}"
     # Generate random block mappings where each source block is mapped to two
     # destination blocks.
     assert 2 * num_mappings <= num_blocks
@@ -56,7 +59,7 @@ def test_copy_blocks(
     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                 num_layers, num_heads,
-                                                head_size, dtype, seed)
+                                                head_size, dtype, seed, gpu_id)

     # Clone the KV caches.
     cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
@@ -88,6 +91,7 @@ def test_copy_blocks(
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_reshape_and_cache(
     kv_cache_factory,
@@ -98,28 +102,29 @@ def test_reshape_and_cache(
     num_blocks: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-
+    gpu_id = f"cuda:{device}"
     # Create a random slot mapping.
     num_slots = block_size * num_blocks
     slot_mapping = random.sample(range(num_slots), num_tokens)
-    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device="cuda")
+    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=gpu_id)

     qkv = torch.randn(num_tokens,
                       3,
                       num_heads,
                       head_size,
                       dtype=dtype,
-                      device="cuda")
+                      device=gpu_id)
     _, key, value = qkv.unbind(dim=1)

     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                 num_heads, head_size, dtype,
-                                                seed)
+                                                seed, gpu_id)
     key_cache, value_cache = key_caches[0], value_caches[0]

     # Clone the KV caches.

@@ -8,6 +8,7 @@ NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
 HIDDEN_SIZES = [768, 5120, 8192]  # Arbitrary values for testing
 ADD_RESIDUAL = [False, True]
 SEEDS = [0]
+DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]


 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -15,6 +16,7 @@ SEEDS = [0]
 @pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_rms_norm(
     num_tokens: int,
@@ -22,14 +24,15 @@ def test_rms_norm(
     add_residual: bool,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-
-    layer = RMSNorm(hidden_size).to(dtype).cuda()
+    gpu_id = f"cuda:{device}"
+    layer = RMSNorm(hidden_size).to(dtype=dtype, device=gpu_id)
     layer.weight.data.normal_(mean=1.0, std=0.1)
     scale = 1 / (2 * hidden_size)
-    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device="cuda")
+    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=gpu_id)
     x *= scale
     residual = torch.randn_like(x) * scale if add_residual else None

@@ -13,6 +13,7 @@ NUM_HEADS = [7, 17]  # Arbitrary values for testing
 BATCH_SIZES = [1, 5]  # Arbitrary values for testing
 SEQ_LENS = [11, 8192]  # Arbitrary values for testing
 SEEDS = [0]
+DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]


 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@@ -23,6 +24,7 @@ SEEDS = [0]
 @pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_rotary_embedding(
     is_neox_style: bool,
@@ -33,6 +35,7 @@ def test_rotary_embedding(
     rotary_dim: Optional[int],
     dtype: torch.dtype,
     seed: int,
+    device: int,
     max_position: int = 8192,
     base: int = 10000,
 ) -> None:
@@ -40,20 +43,20 @@ def test_rotary_embedding(
         rotary_dim = head_size
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-
+    gpu_id = f"cuda:{device}"
     if rotary_dim is None:
         rotary_dim = head_size
     rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
-    rope = rope.to(dtype).cuda()
+    rope = rope.to(dtype=dtype, device=gpu_id)

     positions = torch.randint(0,
                               max_position, (batch_size, seq_len),
-                              device="cuda")
+                              device=gpu_id)
     query = torch.randn(batch_size,
                         seq_len,
                         num_heads * head_size,
                         dtype=dtype,
-                        device="cuda")
+                        device=gpu_id)
     key = torch.randn_like(query)

     # NOTE(woosuk): The reference implementation should be executed first