From 9795159082f6e6c847db2bf4284fd17326c31fbd Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 5 Sep 2023 21:29:03 -0700 Subject: [PATCH] [Rotary] Set device before launching Triton kernel to avoid error --- flash_attn/ops/triton/rotary.py | 57 +++++++++++++++++---------------- tests/models/test_baichuan.py | 3 -- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/flash_attn/ops/triton/rotary.py b/flash_attn/ops/triton/rotary.py index ba846a0..0e9b566 100644 --- a/flash_attn/ops/triton/rotary.py +++ b/flash_attn/ops/triton/rotary.py @@ -205,31 +205,34 @@ def apply_rotary( grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4) - rotary_kernel[grid]( - output, # data ptrs - x, - cos, - sin, - cu_seqlens, - seqlen_offsets, - seqlen, # shapes - nheads, - rotary_dim, - seqlen_ro, - seqlen // 128, # key for triton cache (limit number of compilations) - output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 - output.stride(-3), # seqlen_stride or total_seqlen_stride - output.stride(-2), # nheads_stride - output.stride(-1), # headdim_stride - x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 - x.stride(-3), # seqlen stride or total_seqlen_stride - x.stride(-2), # nheads stride - x.stride(-1), # headdim stride - BLOCK_K, - isinstance(seqlen_offsets, torch.Tensor), - is_varlen, - interleaved, - conjugate, - BLOCK_M, - ) + # Need this, otherwise Triton tries to launch from cuda:0 and we get + # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
+ with torch.cuda.device(x.device.index): + rotary_kernel[grid]( + output, # data ptrs + x, + cos, + sin, + cu_seqlens, + seqlen_offsets, + seqlen, # shapes + nheads, + rotary_dim, + seqlen_ro, + seqlen // 128, # key for triton cache (limit number of compilations) + output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 + output.stride(-3), # seqlen_stride or total_seqlen_stride + output.stride(-2), # nheads_stride + output.stride(-1), # headdim_stride + x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 + x.stride(-3), # seqlen stride or total_seqlen_stride + x.stride(-2), # nheads stride + x.stride(-1), # headdim stride + BLOCK_K, + isinstance(seqlen_offsets, torch.Tensor), + is_varlen, + interleaved, + conjugate, + BLOCK_M, + ) return output diff --git a/tests/models/test_baichuan.py b/tests/models/test_baichuan.py index 1ff6ea7..6658c1c 100644 --- a/tests/models/test_baichuan.py +++ b/tests/models/test_baichuan.py @@ -148,9 +148,6 @@ def test_baichuan_parallel_forward(model_name, world_size): rank = parallel_state.get_tensor_model_parallel_rank() process_group = parallel_state.get_tensor_model_parallel_group() - # Need this, otherwise the Triton kernel seems to launched from the wrong device. - torch.cuda.set_device(device) - pretrained_state_dict = remap_state_dict_hf_baichuan( state_dict_from_pretrained(model_name), config )