diff --git a/flash_attn/layers/rotary.py b/flash_attn/layers/rotary.py index da1b049..3c2e307 100644 --- a/flash_attn/layers/rotary.py +++ b/flash_attn/layers/rotary.py @@ -43,8 +43,6 @@ class ApplyRotaryEmb(torch.autograd.Function): rotary_dim *= 2 assert rotary_dim <= headdim assert seqlen <= rotary_seqlen - assert cos.shape == (rotary_seqlen, rotary_dim // 2) - assert sin.shape == (rotary_seqlen, rotary_dim // 2) x1, x2 = x[..., :rotary_dim].chunk(2, dim=-1) out = torch.empty_like(x) if not inplace else x o1, o2 = out[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (x1, x2)