Remove redundant shape asserts in rotary embeddings

This commit is contained in:
Alexander Ploshkin 2022-12-15 18:13:21 +04:00
parent 04c4c6106e
commit 96656b9323

View File

@ -43,8 +43,6 @@ class ApplyRotaryEmb(torch.autograd.Function):
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
assert cos.shape == (rotary_seqlen, rotary_dim // 2)
assert sin.shape == (rotary_seqlen, rotary_dim // 2)
x1, x2 = x[..., :rotary_dim].chunk(2, dim=-1)
out = torch.empty_like(x) if not inplace else x
o1, o2 = out[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (x1, x2)