From 96656b93237423a86a7a95f6270d8b005d3dbbc5 Mon Sep 17 00:00:00 2001 From: Alexander Ploshkin Date: Thu, 15 Dec 2022 18:13:21 +0400 Subject: [PATCH] Remove redundant shape asserts in rotary embeddings --- flash_attn/layers/rotary.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/flash_attn/layers/rotary.py b/flash_attn/layers/rotary.py index da1b049..3c2e307 100644 --- a/flash_attn/layers/rotary.py +++ b/flash_attn/layers/rotary.py @@ -43,8 +43,6 @@ class ApplyRotaryEmb(torch.autograd.Function): rotary_dim *= 2 assert rotary_dim <= headdim assert seqlen <= rotary_seqlen - assert cos.shape == (rotary_seqlen, rotary_dim // 2) - assert sin.shape == (rotary_seqlen, rotary_dim // 2) x1, x2 = x[..., :rotary_dim].chunk(2, dim=-1) out = torch.empty_like(x) if not inplace else x o1, o2 = out[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (x1, x2)