[Rotary] Speed up rotary kernel when interleaved=True

This commit is contained in:
Tri Dao 2023-09-03 16:24:37 -07:00
parent 26d7d92f3d
commit 1c523c1ce1
2 changed files with 77 additions and 45 deletions

View File

@ -13,7 +13,7 @@ import triton.language as tl
# triton.Config({"BLOCK_M": 8}),
# triton.Config({"BLOCK_M": 16}),
# ],
# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"]
# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"],
# )
@triton.jit
def rotary_kernel(
@ -49,56 +49,88 @@ def rotary_kernel(
pid_head = tl.program_id(axis=2)
rotary_dim_half = rotary_dim // 2
X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rk = tl.arange(0, BLOCK_K // 2)
if not IS_SEQLEN_OFFSETS_TENSOR:
rm_cs = rm + SEQLEN_OFFSETS
else:
rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
rk = tl.arange(0, BLOCK_K)
rk_half = tl.arange(0, BLOCK_K // 2)
X = X + (
pid_batch * stride_x_batch
+ rm[:, None] * stride_x_seqlen
+ pid_head * stride_x_nheads
+ rk[None, :] * stride_x_headdim * (2 if INTERLEAVED else 1)
)
COS = COS + (rm_cs[:, None] * rotary_dim_half + rk[None, :])
SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk[None, :])
cos = tl.load(
COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim_half), other=1.0
).to(tl.float32)
sin = tl.load(
SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim_half), other=0.0
).to(tl.float32)
x0 = tl.load(X, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half), other=0.0).to(
tl.float32
)
x1 = tl.load(
X + stride_x_headdim * (1 if INTERLEAVED else rotary_dim_half),
mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half),
other=0.0,
).to(tl.float32)
if not CONJUGATE:
o0 = x0 * cos - x1 * sin
o1 = x0 * sin + x1 * cos
if not INTERLEAVED:
# Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)
COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
cos = tl.load(
COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0
).to(tl.float32)
sin = tl.load(
SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0
).to(tl.float32)
x0 = tl.load(
X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0
).to(tl.float32)
x1 = tl.load(
X + rotary_dim_half * stride_x_headdim,
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
other=0.0,
).to(tl.float32)
if not CONJUGATE:
o0 = x0 * cos - x1 * sin
o1 = x0 * sin + x1 * cos
else:
o0 = x0 * cos + x1 * sin
o1 = -x0 * sin + x1 * cos
# write back result
OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
tl.store(
OUT + rotary_dim_half * stride_out_headdim,
o1,
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
)
else:
o0 = x0 * cos + x1 * sin
o1 = -x0 * sin + x1 * cos
# write back result
OUT = OUT + (
pid_batch * stride_out_batch
+ rm[:, None] * stride_out_seqlen
+ pid_head * stride_out_nheads
+ rk[None, :] * stride_out_headdim * (2 if INTERLEAVED else 1)
)
tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half))
tl.store(
OUT + stride_out_headdim * (1 if INTERLEAVED else rotary_dim_half),
o1,
mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half),
)
# We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.
# Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
# Loading x0 will be fast but x1 will be slow.
# Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].
# Then we do the calculation and use tl.where to pick put the right outputs for the even
# and for the odd indices.
rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...
rk_repeat = tl.arange(0, BLOCK_K) // 2
X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)
COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
cos = tl.load(
COS,
mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
other=1.0,
).to(tl.float32)
sin = tl.load(
SIN,
mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
other=0.0,
).to(tl.float32)
x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(
tl.float32
)
x1 = tl.load(
X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
).to(tl.float32)
if not CONJUGATE:
o0 = x0 * cos - x1 * sin
o1 = x1 * sin + x0 * cos
else:
o0 = x0 * cos + x1 * sin
o1 = -x1 * sin + x0 * cos
out = tl.where(rk[None, :] % 2 == 0, o0, o1)
OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
def apply_rotary(

View File

@ -20,7 +20,7 @@ is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
# @pytest.mark.parametrize('interleaved', [True])
@pytest.mark.parametrize("inplace", [False, True])
# @pytest.mark.parametrize('inplace', [False])
def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):