fix slicing dimensions
This commit is contained in:
parent
96656b9323
commit
c7c66976cc
@ -46,8 +46,8 @@ class ApplyRotaryEmb(torch.autograd.Function):
|
|||||||
x1, x2 = x[..., :rotary_dim].chunk(2, dim=-1)
|
x1, x2 = x[..., :rotary_dim].chunk(2, dim=-1)
|
||||||
out = torch.empty_like(x) if not inplace else x
|
out = torch.empty_like(x) if not inplace else x
|
||||||
o1, o2 = out[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (x1, x2)
|
o1, o2 = out[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (x1, x2)
|
||||||
rotary_emb.apply_rotary(x1, x2, rearrange(cos[:, :seqlen], 's d -> s 1 d'),
|
rotary_emb.apply_rotary(x1, x2, rearrange(cos[:seqlen], 's d -> s 1 d'),
|
||||||
rearrange(sin[:, :seqlen], 's d -> s 1 d'), o1, o2, False)
|
rearrange(sin[:seqlen], 's d -> s 1 d'), o1, o2, False)
|
||||||
if not inplace and rotary_dim < headdim:
|
if not inplace and rotary_dim < headdim:
|
||||||
out[..., rotary_dim:].copy_(x[..., rotary_dim:])
|
out[..., rotary_dim:].copy_(x[..., rotary_dim:])
|
||||||
ctx.save_for_backward(cos, sin)
|
ctx.save_for_backward(cos, sin)
|
||||||
@ -64,8 +64,8 @@ class ApplyRotaryEmb(torch.autograd.Function):
|
|||||||
do1, do2 = do[..., :rotary_dim].chunk(2, dim=-1)
|
do1, do2 = do[..., :rotary_dim].chunk(2, dim=-1)
|
||||||
dx = torch.empty_like(do) if not inplace else do
|
dx = torch.empty_like(do) if not inplace else do
|
||||||
dx1, dx2 = dx[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (do1, do2)
|
dx1, dx2 = dx[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (do1, do2)
|
||||||
rotary_emb.apply_rotary(do1, do2, rearrange(cos[:, :seqlen], 's d -> s 1 d'),
|
rotary_emb.apply_rotary(do1, do2, rearrange(cos[:seqlen], 's d -> s 1 d'),
|
||||||
rearrange(sin[:, :seqlen], 's d -> s 1 d'), dx1, dx2, True)
|
rearrange(sin[:seqlen], 's d -> s 1 d'), dx1, dx2, True)
|
||||||
if not inplace and rotary_dim < headdim:
|
if not inplace and rotary_dim < headdim:
|
||||||
dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
|
dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
|
||||||
return dx, None, None, None
|
return dx, None, None, None
|
||||||
@ -90,14 +90,12 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
|
|||||||
rotary_dim *= 2
|
rotary_dim *= 2
|
||||||
assert rotary_dim <= headdim
|
assert rotary_dim <= headdim
|
||||||
assert seqlen <= rotary_seqlen
|
assert seqlen <= rotary_seqlen
|
||||||
assert cos.shape == (seqlen, rotary_dim // 2)
|
|
||||||
assert sin.shape == (seqlen, rotary_dim // 2)
|
|
||||||
q1, q2 = qkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
|
q1, q2 = qkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
|
||||||
rotary_emb.apply_rotary(q1, q2, rearrange(cos[:, :seqlen], 's d -> s 1 d'),
|
rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
|
||||||
rearrange(sin[:, :seqlen], 's d -> s 1 d'), q1, q2, False)
|
rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
|
||||||
k1, k2 = qkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1)
|
k1, k2 = qkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1)
|
||||||
rotary_emb.apply_rotary(k1, k2, rearrange(cos[:, :seqlen], 's d -> s 1 d'),
|
rotary_emb.apply_rotary(k1, k2, rearrange(cos[:seqlen], 's d -> s 1 d'),
|
||||||
rearrange(sin[:, :seqlen], 's d -> s 1 d'), k1, k2, False)
|
rearrange(sin[:seqlen], 's d -> s 1 d'), k1, k2, False)
|
||||||
ctx.save_for_backward(cos, sin)
|
ctx.save_for_backward(cos, sin)
|
||||||
return qkv
|
return qkv
|
||||||
|
|
||||||
@ -108,11 +106,11 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
|
|||||||
rotary_dim = cos.shape[-1]
|
rotary_dim = cos.shape[-1]
|
||||||
rotary_dim *= 2
|
rotary_dim *= 2
|
||||||
dq1, dq2 = dqkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
|
dq1, dq2 = dqkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
|
||||||
rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:, :seqlen], 's d -> s 1 d'),
|
rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'),
|
||||||
rearrange(sin[:, :seqlen], 's d -> s 1 d'), dq1, dq2, True)
|
rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True)
|
||||||
dk1, dk2 = dqkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1)
|
dk1, dk2 = dqkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1)
|
||||||
rotary_emb.apply_rotary(dk1, dk2, rearrange(cos[:, :seqlen], 's d -> s 1 d'),
|
rotary_emb.apply_rotary(dk1, dk2, rearrange(cos[:seqlen], 's d -> s 1 d'),
|
||||||
rearrange(sin[:, :seqlen], 's d -> s 1 d'), dk1, dk2, True)
|
rearrange(sin[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
|
||||||
return dqkv, None, None
|
return dqkv, None, None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user