[Rotary] Speed up rotary kernel when interleaved=True
This commit is contained in:
parent
26d7d92f3d
commit
1c523c1ce1
@ -13,7 +13,7 @@ import triton.language as tl
|
||||
# triton.Config({"BLOCK_M": 8}),
|
||||
# triton.Config({"BLOCK_M": 16}),
|
||||
# ],
|
||||
# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"]
|
||||
# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"],
|
||||
# )
|
||||
@triton.jit
|
||||
def rotary_kernel(
|
||||
@ -49,56 +49,88 @@ def rotary_kernel(
|
||||
pid_head = tl.program_id(axis=2)
|
||||
rotary_dim_half = rotary_dim // 2
|
||||
|
||||
X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
|
||||
OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
|
||||
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rk = tl.arange(0, BLOCK_K // 2)
|
||||
if not IS_SEQLEN_OFFSETS_TENSOR:
|
||||
rm_cs = rm + SEQLEN_OFFSETS
|
||||
else:
|
||||
rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
|
||||
rk = tl.arange(0, BLOCK_K)
|
||||
rk_half = tl.arange(0, BLOCK_K // 2)
|
||||
|
||||
X = X + (
|
||||
pid_batch * stride_x_batch
|
||||
+ rm[:, None] * stride_x_seqlen
|
||||
+ pid_head * stride_x_nheads
|
||||
+ rk[None, :] * stride_x_headdim * (2 if INTERLEAVED else 1)
|
||||
)
|
||||
COS = COS + (rm_cs[:, None] * rotary_dim_half + rk[None, :])
|
||||
SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk[None, :])
|
||||
|
||||
cos = tl.load(
|
||||
COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim_half), other=1.0
|
||||
).to(tl.float32)
|
||||
sin = tl.load(
|
||||
SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim_half), other=0.0
|
||||
).to(tl.float32)
|
||||
x0 = tl.load(X, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half), other=0.0).to(
|
||||
tl.float32
|
||||
)
|
||||
x1 = tl.load(
|
||||
X + stride_x_headdim * (1 if INTERLEAVED else rotary_dim_half),
|
||||
mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
if not CONJUGATE:
|
||||
o0 = x0 * cos - x1 * sin
|
||||
o1 = x0 * sin + x1 * cos
|
||||
if not INTERLEAVED:
|
||||
# Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
|
||||
X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)
|
||||
COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
|
||||
SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
|
||||
cos = tl.load(
|
||||
COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0
|
||||
).to(tl.float32)
|
||||
sin = tl.load(
|
||||
SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0
|
||||
).to(tl.float32)
|
||||
x0 = tl.load(
|
||||
X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0
|
||||
).to(tl.float32)
|
||||
x1 = tl.load(
|
||||
X + rotary_dim_half * stride_x_headdim,
|
||||
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
if not CONJUGATE:
|
||||
o0 = x0 * cos - x1 * sin
|
||||
o1 = x0 * sin + x1 * cos
|
||||
else:
|
||||
o0 = x0 * cos + x1 * sin
|
||||
o1 = -x0 * sin + x1 * cos
|
||||
# write back result
|
||||
OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
|
||||
tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
|
||||
tl.store(
|
||||
OUT + rotary_dim_half * stride_out_headdim,
|
||||
o1,
|
||||
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
|
||||
)
|
||||
else:
|
||||
o0 = x0 * cos + x1 * sin
|
||||
o1 = -x0 * sin + x1 * cos
|
||||
|
||||
# write back result
|
||||
OUT = OUT + (
|
||||
pid_batch * stride_out_batch
|
||||
+ rm[:, None] * stride_out_seqlen
|
||||
+ pid_head * stride_out_nheads
|
||||
+ rk[None, :] * stride_out_headdim * (2 if INTERLEAVED else 1)
|
||||
)
|
||||
tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half))
|
||||
tl.store(
|
||||
OUT + stride_out_headdim * (1 if INTERLEAVED else rotary_dim_half),
|
||||
o1,
|
||||
mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half),
|
||||
)
|
||||
# We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.
|
||||
# Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
|
||||
# Loading x0 will be fast but x1 will be slow.
|
||||
# Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].
|
||||
# Then we do the calculation and use tl.where to pick out the right outputs for the even
|
||||
# and for the odd indices.
|
||||
rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...
|
||||
rk_repeat = tl.arange(0, BLOCK_K) // 2
|
||||
X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
|
||||
X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)
|
||||
COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
|
||||
SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
|
||||
cos = tl.load(
|
||||
COS,
|
||||
mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
|
||||
other=1.0,
|
||||
).to(tl.float32)
|
||||
sin = tl.load(
|
||||
SIN,
|
||||
mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(
|
||||
tl.float32
|
||||
)
|
||||
x1 = tl.load(
|
||||
X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
|
||||
).to(tl.float32)
|
||||
if not CONJUGATE:
|
||||
o0 = x0 * cos - x1 * sin
|
||||
o1 = x1 * sin + x0 * cos
|
||||
else:
|
||||
o0 = x0 * cos + x1 * sin
|
||||
o1 = -x1 * sin + x0 * cos
|
||||
out = tl.where(rk[None, :] % 2 == 0, o0, o1)
|
||||
OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
|
||||
tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
|
||||
|
||||
|
||||
def apply_rotary(
|
||||
|
||||
@ -20,7 +20,7 @@ is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)
|
||||
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
|
||||
# @pytest.mark.parametrize('rotary_fraction', [1.0])
|
||||
@pytest.mark.parametrize("interleaved", [False, True])
|
||||
# @pytest.mark.parametrize('interleaved', [False])
|
||||
# @pytest.mark.parametrize('interleaved', [True])
|
||||
@pytest.mark.parametrize("inplace", [False, True])
|
||||
# @pytest.mark.parametrize('inplace', [False])
|
||||
def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user