diff --git a/flash_attn/layers/rotary.py b/flash_attn/layers/rotary.py index bd05258..4ec049e 100644 --- a/flash_attn/layers/rotary.py +++ b/flash_attn/layers/rotary.py @@ -193,16 +193,16 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function): sin_k = sin if sin_k is None else sin_k dq, dk = dqkv[:, :, 0], dqkv[:, :, 1] apply_rotary( - dq, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True, conjugate=True + dq, cos, sin, seqlen_offsets, interleaved=ctx.interleaved, inplace=True, conjugate=True ) apply_rotary( dk, cos_k, sin_k, seqlen_offsets, - interleaved=interleaved, + interleaved=ctx.interleaved, inplace=True, - conjudate=True, + conjugate=True, ) return dqkv, None, None, None, None, None, None