[Rotary] Set device before launching Triton kernel to avoid error

This commit is contained in:
Tri Dao 2023-09-05 21:29:03 -07:00
parent 6d673cd961
commit 9795159082
2 changed files with 30 additions and 30 deletions

View File

@ -205,31 +205,34 @@ def apply_rotary(
grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa
BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
rotary_kernel[grid](
output, # data ptrs
x,
cos,
sin,
cu_seqlens,
seqlen_offsets,
seqlen, # shapes
nheads,
rotary_dim,
seqlen_ro,
seqlen // 128, # key for triton cache (limit number of compilations)
output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0
output.stride(-3), # seqlen_stride or total_seqlen_stride
output.stride(-2), # nheads_stride
output.stride(-1), # headdim_stride
x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0
x.stride(-3), # seqlen stride or total_seqlen_stride
x.stride(-2), # nheads stride
x.stride(-1), # headdim stride
BLOCK_K,
isinstance(seqlen_offsets, torch.Tensor),
is_varlen,
interleaved,
conjugate,
BLOCK_M,
)
# Need this, otherwise Triton tries to launch from cuda:0 and we get
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
with torch.cuda.device(x.device.index):
rotary_kernel[grid](
output, # data ptrs
x,
cos,
sin,
cu_seqlens,
seqlen_offsets,
seqlen, # shapes
nheads,
rotary_dim,
seqlen_ro,
seqlen // 128, # key for triton cache (limit number of compilations)
output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0
output.stride(-3), # seqlen_stride or total_seqlen_stride
output.stride(-2), # nheads_stride
output.stride(-1), # headdim_stride
x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0
x.stride(-3), # seqlen stride or total_seqlen_stride
x.stride(-2), # nheads stride
x.stride(-1), # headdim stride
BLOCK_K,
isinstance(seqlen_offsets, torch.Tensor),
is_varlen,
interleaved,
conjugate,
BLOCK_M,
)
return output

View File

@ -148,9 +148,6 @@ def test_baichuan_parallel_forward(model_name, world_size):
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()
# Need this, otherwise the Triton kernel seems to launched from the wrong device.
torch.cuda.set_device(device)
pretrained_state_dict = remap_state_dict_hf_baichuan(
state_dict_from_pretrained(model_name), config
)