add asserts for sin shape
This commit is contained in:
parent
c7c66976cc
commit
ee8984d2be
@ -43,6 +43,7 @@ class ApplyRotaryEmb(torch.autograd.Function):
|
||||
rotary_dim *= 2
|
||||
assert rotary_dim <= headdim
|
||||
assert seqlen <= rotary_seqlen
|
||||
assert sin.shape == (rotary_seqlen, rotary_dim // 2)
|
||||
x1, x2 = x[..., :rotary_dim].chunk(2, dim=-1)
|
||||
out = torch.empty_like(x) if not inplace else x
|
||||
o1, o2 = out[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (x1, x2)
|
||||
@ -90,6 +91,7 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
|
||||
rotary_dim *= 2
|
||||
assert rotary_dim <= headdim
|
||||
assert seqlen <= rotary_seqlen
|
||||
assert sin.shape == (rotary_seqlen, rotary_dim // 2)
|
||||
q1, q2 = qkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
|
||||
rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
|
||||
rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user