[Rotary] Clean up rotary Triton implementation a bit
This commit is contained in:
parent
1c523c1ce1
commit
861c82577d
@ -79,12 +79,10 @@ def rotary_kernel(
|
|||||||
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
|
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
|
||||||
other=0.0,
|
other=0.0,
|
||||||
).to(tl.float32)
|
).to(tl.float32)
|
||||||
if not CONJUGATE:
|
if CONJUGATE:
|
||||||
o0 = x0 * cos - x1 * sin
|
sin = -sin
|
||||||
o1 = x0 * sin + x1 * cos
|
o0 = x0 * cos - x1 * sin
|
||||||
else:
|
o1 = x0 * sin + x1 * cos
|
||||||
o0 = x0 * cos + x1 * sin
|
|
||||||
o1 = -x0 * sin + x1 * cos
|
|
||||||
# write back result
|
# write back result
|
||||||
OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
|
OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
|
||||||
tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
|
tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
|
||||||
@ -122,13 +120,11 @@ def rotary_kernel(
|
|||||||
x1 = tl.load(
|
x1 = tl.load(
|
||||||
X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
|
X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
|
||||||
).to(tl.float32)
|
).to(tl.float32)
|
||||||
if not CONJUGATE:
|
if CONJUGATE:
|
||||||
o0 = x0 * cos - x1 * sin
|
sin = -sin
|
||||||
o1 = x1 * sin + x0 * cos
|
x0_cos = x0 * cos
|
||||||
else:
|
x1_sin = x1 * sin
|
||||||
o0 = x0 * cos + x1 * sin
|
out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
|
||||||
o1 = -x1 * sin + x0 * cos
|
|
||||||
out = tl.where(rk[None, :] % 2 == 0, o0, o1)
|
|
||||||
OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
|
OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
|
||||||
tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
|
tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user