[Rotary] Set device before launching Triton kernel to avoid error
commit 9795159082 (parent 6d673cd961)
@@ -205,31 +205,34 @@ def apply_rotary(
     grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads)  # noqa
     BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
 
-    rotary_kernel[grid](
-        output,  # data ptrs
-        x,
-        cos,
-        sin,
-        cu_seqlens,
-        seqlen_offsets,
-        seqlen,  # shapes
-        nheads,
-        rotary_dim,
-        seqlen_ro,
-        seqlen // 128,  # key for triton cache (limit number of compilations)
-        output.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
-        output.stride(-3),  # seqlen_stride or total_seqlen_stride
-        output.stride(-2),  # nheads_stride
-        output.stride(-1),  # headdim_stride
-        x.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
-        x.stride(-3),  # seqlen stride or total_seqlen_stride
-        x.stride(-2),  # nheads stride
-        x.stride(-1),  # headdim stride
-        BLOCK_K,
-        isinstance(seqlen_offsets, torch.Tensor),
-        is_varlen,
-        interleaved,
-        conjugate,
-        BLOCK_M,
-    )
+    # Need this, otherwise Triton tries to launch from cuda:0 and we get
+    # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
+    with torch.cuda.device(x.device.index):
+        rotary_kernel[grid](
+            output,  # data ptrs
+            x,
+            cos,
+            sin,
+            cu_seqlens,
+            seqlen_offsets,
+            seqlen,  # shapes
+            nheads,
+            rotary_dim,
+            seqlen_ro,
+            seqlen // 128,  # key for triton cache (limit number of compilations)
+            output.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
+            output.stride(-3),  # seqlen_stride or total_seqlen_stride
+            output.stride(-2),  # nheads_stride
+            output.stride(-1),  # headdim_stride
+            x.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
+            x.stride(-3),  # seqlen stride or total_seqlen_stride
+            x.stride(-2),  # nheads stride
+            x.stride(-1),  # headdim stride
+            BLOCK_K,
+            isinstance(seqlen_offsets, torch.Tensor),
+            is_varlen,
+            interleaved,
+            conjugate,
+            BLOCK_M,
+        )
     return output
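
For context: Triton launches kernels on the *current* CUDA device, so when the input tensors live on another GPU the pointer arguments cannot be resolved and the launch fails with the ValueError quoted above. Below is a minimal, self-contained sketch of the same guard around a toy kernel; the names add_one_kernel and add_one are illustrative only and are not part of this commit.

import torch
import triton
import triton.language as tl


@triton.jit
def add_one_kernel(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    # Each program handles one BLOCK-sized slice of the flattened tensor.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + 1, mask=mask)


def add_one(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    grid = lambda META: (triton.cdiv(n_elements, META["BLOCK"]),)
    # Same guard as the rotary fix: make x's GPU the current device for the
    # launch so Triton can resolve the pointer arguments, even if the caller
    # never called torch.cuda.set_device and x lives on, say, cuda:1.
    with torch.cuda.device(x.device.index):
        add_one_kernel[grid](x, out, n_elements, BLOCK=1024)
    return out

Using the torch.cuda.device context manager rather than torch.cuda.set_device also restores the caller's previously current device on exit, which is why the per-test workaround removed in the hunk below is no longer needed.
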
@@ -148,9 +148,6 @@ def test_baichuan_parallel_forward(model_name, world_size):
     rank = parallel_state.get_tensor_model_parallel_rank()
     process_group = parallel_state.get_tensor_model_parallel_group()
 
-    # Need this, otherwise the Triton kernel seems to launched from the wrong device.
-    torch.cuda.set_device(device)
-
     pretrained_state_dict = remap_state_dict_hf_baichuan(
         state_dict_from_pretrained(model_name), config
     )
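
A hedged usage sketch of the effect on callers: with the guard inside the library, code running on a non-default GPU no longer has to call torch.cuda.set_device first. The import path and the apply_rotary argument layout below are assumptions for illustration, not part of this diff.

import torch
from flash_attn.ops.triton.rotary import apply_rotary  # assumed module path

device = torch.device("cuda:1")  # any non-default GPU
batch, seqlen, nheads, headdim, rotary_dim = 2, 128, 4, 64, 64

x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=torch.float16)
# cos/sin of shape (seqlen, rotary_dim // 2), matching x's dtype
pos = torch.arange(seqlen, device=device, dtype=torch.float32)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2, device=device, dtype=torch.float32) / rotary_dim))
freqs = torch.outer(pos, inv_freq)
cos, sin = freqs.cos().half(), freqs.sin().half()

# Previously this could fail (or required torch.cuda.set_device(device) beforehand)
# because the Triton kernel was launched from cuda:0; with the device guard inside
# the library, the launch happens on cuda:1 directly.
out = apply_rotary(x, cos, sin)
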