[Rotary] Clean up rotary Triton implementation a bit

This commit is contained in:
Tri Dao 2023-09-03 16:41:17 -07:00
parent 1c523c1ce1
commit 861c82577d

View File

@ -79,12 +79,10 @@ def rotary_kernel(
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
other=0.0,
).to(tl.float32)
if not CONJUGATE:
o0 = x0 * cos - x1 * sin
o1 = x0 * sin + x1 * cos
else:
o0 = x0 * cos + x1 * sin
o1 = -x0 * sin + x1 * cos
if CONJUGATE:
sin = -sin
o0 = x0 * cos - x1 * sin
o1 = x0 * sin + x1 * cos
# write back result
OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
@ -122,13 +120,11 @@ def rotary_kernel(
x1 = tl.load(
X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
).to(tl.float32)
if not CONJUGATE:
o0 = x0 * cos - x1 * sin
o1 = x1 * sin + x0 * cos
else:
o0 = x0 * cos + x1 * sin
o1 = -x1 * sin + x0 * cos
out = tl.where(rk[None, :] % 2 == 0, o0, o1)
if CONJUGATE:
sin = -sin
x0_cos = x0 * cos
x1_sin = x1 * sin
out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))