Remove redundant shape asserts in rotary embeddings
This commit is contained in:
parent
04c4c6106e
commit
96656b9323
@ -43,8 +43,6 @@ class ApplyRotaryEmb(torch.autograd.Function):
|
||||
rotary_dim *= 2
|
||||
assert rotary_dim <= headdim
|
||||
assert seqlen <= rotary_seqlen
|
||||
assert cos.shape == (rotary_seqlen, rotary_dim // 2)
|
||||
assert sin.shape == (rotary_seqlen, rotary_dim // 2)
|
||||
x1, x2 = x[..., :rotary_dim].chunk(2, dim=-1)
|
||||
out = torch.empty_like(x) if not inplace else x
|
||||
o1, o2 = out[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (x1, x2)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user