[Rotary] Clean up rotary Triton implementation a bit
This commit is contained in:
parent
1c523c1ce1
commit
861c82577d
@ -79,12 +79,10 @@ def rotary_kernel(
|
||||
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
if not CONJUGATE:
|
||||
o0 = x0 * cos - x1 * sin
|
||||
o1 = x0 * sin + x1 * cos
|
||||
else:
|
||||
o0 = x0 * cos + x1 * sin
|
||||
o1 = -x0 * sin + x1 * cos
|
||||
if CONJUGATE:
|
||||
sin = -sin
|
||||
o0 = x0 * cos - x1 * sin
|
||||
o1 = x0 * sin + x1 * cos
|
||||
# write back result
|
||||
OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
|
||||
tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
|
||||
@ -122,13 +120,11 @@ def rotary_kernel(
|
||||
x1 = tl.load(
|
||||
X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
|
||||
).to(tl.float32)
|
||||
if not CONJUGATE:
|
||||
o0 = x0 * cos - x1 * sin
|
||||
o1 = x1 * sin + x0 * cos
|
||||
else:
|
||||
o0 = x0 * cos + x1 * sin
|
||||
o1 = -x1 * sin + x0 * cos
|
||||
out = tl.where(rk[None, :] % 2 == 0, o0, o1)
|
||||
if CONJUGATE:
|
||||
sin = -sin
|
||||
x0_cos = x0 * cos
|
||||
x1_sin = x1 * sin
|
||||
out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
|
||||
OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
|
||||
tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user