From ee8984d2bed01f6dbf52db70d0f3cb815c20d505 Mon Sep 17 00:00:00 2001 From: Alexander Ploshkin Date: Sat, 17 Dec 2022 13:34:57 +0400 Subject: [PATCH] add asserts for sin shape --- flash_attn/layers/rotary.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flash_attn/layers/rotary.py b/flash_attn/layers/rotary.py index f9af93c..1605158 100644 --- a/flash_attn/layers/rotary.py +++ b/flash_attn/layers/rotary.py @@ -43,6 +43,7 @@ class ApplyRotaryEmb(torch.autograd.Function): rotary_dim *= 2 assert rotary_dim <= headdim assert seqlen <= rotary_seqlen + assert sin.shape == (rotary_seqlen, rotary_dim // 2) x1, x2 = x[..., :rotary_dim].chunk(2, dim=-1) out = torch.empty_like(x) if not inplace else x o1, o2 = out[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (x1, x2) @@ -90,6 +91,7 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function): rotary_dim *= 2 assert rotary_dim <= headdim assert seqlen <= rotary_seqlen + assert sin.shape == (rotary_seqlen, rotary_dim // 2) q1, q2 = qkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1) rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'), rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)