From 96656b93237423a86a7a95f6270d8b005d3dbbc5 Mon Sep 17 00:00:00 2001 From: Alexander Ploshkin Date: Thu, 15 Dec 2022 18:13:21 +0400 Subject: [PATCH 1/3] Remove redundant shape asserts in rotary embeddings --- flash_attn/layers/rotary.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/flash_attn/layers/rotary.py b/flash_attn/layers/rotary.py index da1b049..3c2e307 100644 --- a/flash_attn/layers/rotary.py +++ b/flash_attn/layers/rotary.py @@ -43,8 +43,6 @@ class ApplyRotaryEmb(torch.autograd.Function): rotary_dim *= 2 assert rotary_dim <= headdim assert seqlen <= rotary_seqlen - assert cos.shape == (rotary_seqlen, rotary_dim // 2) - assert sin.shape == (rotary_seqlen, rotary_dim // 2) x1, x2 = x[..., :rotary_dim].chunk(2, dim=-1) out = torch.empty_like(x) if not inplace else x o1, o2 = out[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (x1, x2) From c7c66976cce3b1fa9281ddca02970bd383de9ebc Mon Sep 17 00:00:00 2001 From: Alexander Ploshkin Date: Fri, 16 Dec 2022 15:39:06 +0400 Subject: [PATCH 2/3] fix slicing dimensions --- flash_attn/layers/rotary.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/flash_attn/layers/rotary.py b/flash_attn/layers/rotary.py index 3c2e307..f9af93c 100644 --- a/flash_attn/layers/rotary.py +++ b/flash_attn/layers/rotary.py @@ -46,8 +46,8 @@ class ApplyRotaryEmb(torch.autograd.Function): x1, x2 = x[..., :rotary_dim].chunk(2, dim=-1) out = torch.empty_like(x) if not inplace else x o1, o2 = out[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (x1, x2) - rotary_emb.apply_rotary(x1, x2, rearrange(cos[:, :seqlen], 's d -> s 1 d'), - rearrange(sin[:, :seqlen], 's d -> s 1 d'), o1, o2, False) + rotary_emb.apply_rotary(x1, x2, rearrange(cos[:seqlen], 's d -> s 1 d'), + rearrange(sin[:seqlen], 's d -> s 1 d'), o1, o2, False) if not inplace and rotary_dim < headdim: out[..., rotary_dim:].copy_(x[..., rotary_dim:]) ctx.save_for_backward(cos, sin) @@ -64,8 +64,8 @@ class ApplyRotaryEmb(torch.autograd.Function): do1, do2 = do[..., :rotary_dim].chunk(2, dim=-1) dx = torch.empty_like(do) if not inplace else do dx1, dx2 = dx[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (do1, do2) - rotary_emb.apply_rotary(do1, do2, rearrange(cos[:, :seqlen], 's d -> s 1 d'), - rearrange(sin[:, :seqlen], 's d -> s 1 d'), dx1, dx2, True) + rotary_emb.apply_rotary(do1, do2, rearrange(cos[:seqlen], 's d -> s 1 d'), + rearrange(sin[:seqlen], 's d -> s 1 d'), dx1, dx2, True) if not inplace and rotary_dim < headdim: dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) return dx, None, None, None @@ -90,14 +90,12 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function): rotary_dim *= 2 assert rotary_dim <= headdim assert seqlen <= rotary_seqlen - assert cos.shape == (seqlen, rotary_dim // 2) - assert sin.shape == (seqlen, rotary_dim // 2) q1, q2 = qkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1) - rotary_emb.apply_rotary(q1, q2, rearrange(cos[:, :seqlen], 's d -> s 1 d'), - rearrange(sin[:, :seqlen], 's d -> s 1 d'), q1, q2, False) + rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'), + rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False) k1, k2 = qkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1) - rotary_emb.apply_rotary(k1, k2, rearrange(cos[:, :seqlen], 's d -> s 1 d'), - rearrange(sin[:, :seqlen], 's d -> s 1 d'), k1, k2, False) + rotary_emb.apply_rotary(k1, k2, rearrange(cos[:seqlen], 's d -> s 1 d'), + rearrange(sin[:seqlen], 's d -> s 1 d'), k1, k2, False) ctx.save_for_backward(cos, sin) return qkv @@ -108,11 +106,11 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function): rotary_dim = cos.shape[-1] rotary_dim *= 2 dq1, dq2 = dqkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1) - rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:, :seqlen], 's d -> s 1 d'), - rearrange(sin[:, :seqlen], 's d -> s 1 d'), dq1, dq2, True) + rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'), + rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True) dk1, dk2 = dqkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1) - rotary_emb.apply_rotary(dk1, dk2, rearrange(cos[:, :seqlen], 's d -> s 1 d'), - rearrange(sin[:, :seqlen], 's d -> s 1 d'), dk1, dk2, True) + rotary_emb.apply_rotary(dk1, dk2, rearrange(cos[:seqlen], 's d -> s 1 d'), + rearrange(sin[:seqlen], 's d -> s 1 d'), dk1, dk2, True) return dqkv, None, None From ee8984d2bed01f6dbf52db70d0f3cb815c20d505 Mon Sep 17 00:00:00 2001 From: Alexander Ploshkin Date: Sat, 17 Dec 2022 13:34:57 +0400 Subject: [PATCH 3/3] add asserts for sin shape --- flash_attn/layers/rotary.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flash_attn/layers/rotary.py b/flash_attn/layers/rotary.py index f9af93c..1605158 100644 --- a/flash_attn/layers/rotary.py +++ b/flash_attn/layers/rotary.py @@ -43,6 +43,7 @@ class ApplyRotaryEmb(torch.autograd.Function): rotary_dim *= 2 assert rotary_dim <= headdim assert seqlen <= rotary_seqlen + assert sin.shape == (rotary_seqlen, rotary_dim // 2) x1, x2 = x[..., :rotary_dim].chunk(2, dim=-1) out = torch.empty_like(x) if not inplace else x o1, o2 = out[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (x1, x2) @@ -90,6 +91,7 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function): rotary_dim *= 2 assert rotary_dim <= headdim assert seqlen <= rotary_seqlen + assert sin.shape == (rotary_seqlen, rotary_dim // 2) q1, q2 = qkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1) rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'), rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)