[Rotary] Clean up rotary Triton implementation a bit

This commit is contained in:
Tri Dao 2023-09-03 16:41:17 -07:00
parent 1c523c1ce1
commit 861c82577d

View File

@ -79,12 +79,10 @@ def rotary_kernel(
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
other=0.0, other=0.0,
).to(tl.float32) ).to(tl.float32)
if not CONJUGATE: if CONJUGATE:
o0 = x0 * cos - x1 * sin sin = -sin
o1 = x0 * sin + x1 * cos o0 = x0 * cos - x1 * sin
else: o1 = x0 * sin + x1 * cos
o0 = x0 * cos + x1 * sin
o1 = -x0 * sin + x1 * cos
# write back result # write back result
OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim) OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half)) tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
@ -122,13 +120,11 @@ def rotary_kernel(
x1 = tl.load( x1 = tl.load(
X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0 X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
).to(tl.float32) ).to(tl.float32)
if not CONJUGATE: if CONJUGATE:
o0 = x0 * cos - x1 * sin sin = -sin
o1 = x1 * sin + x0 * cos x0_cos = x0 * cos
else: x1_sin = x1 * sin
o0 = x0 * cos + x1 * sin out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
o1 = -x1 * sin + x0 * cos
out = tl.where(rk[None, :] % 2 == 0, o0, o1)
OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim) OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim)) tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))