From 861c82577dc38ffd8cfc7c6c8e68621210df37d7 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 3 Sep 2023 16:41:17 -0700 Subject: [PATCH] [Rotary] Clean up rotary Triton implementation a bit --- flash_attn/ops/triton/rotary.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/flash_attn/ops/triton/rotary.py b/flash_attn/ops/triton/rotary.py index 68fc296..b526981 100644 --- a/flash_attn/ops/triton/rotary.py +++ b/flash_attn/ops/triton/rotary.py @@ -79,12 +79,10 @@ def rotary_kernel( mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0, ).to(tl.float32) - if not CONJUGATE: - o0 = x0 * cos - x1 * sin - o1 = x0 * sin + x1 * cos - else: - o0 = x0 * cos + x1 * sin - o1 = -x0 * sin + x1 * cos + if CONJUGATE: + sin = -sin + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos # write back result OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim) tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half)) @@ -122,13 +120,11 @@ def rotary_kernel( x1 = tl.load( X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0 ).to(tl.float32) - if not CONJUGATE: - o0 = x0 * cos - x1 * sin - o1 = x1 * sin + x0 * cos - else: - o0 = x0 * cos + x1 * sin - o1 = -x1 * sin + x0 * cos - out = tl.where(rk[None, :] % 2 == 0, o0, o1) + if CONJUGATE: + sin = -sin + x0_cos = x0 * cos + x1_sin = x1 * sin + out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin) OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim) tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))