[BugFix] tensor.get_device() -> tensor.device (#3604)
This commit is contained in:
parent
837e185142
commit
6d93d35308
@ -108,7 +108,7 @@ class RotaryEmbedding(nn.Module):
|
|||||||
query_pass = query[..., self.rotary_dim:]
|
query_pass = query[..., self.rotary_dim:]
|
||||||
key_pass = key[..., self.rotary_dim:]
|
key_pass = key[..., self.rotary_dim:]
|
||||||
|
|
||||||
self.cos_sin_cache = self.cos_sin_cache.to(positions.get_device())
|
self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
|
||||||
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
|
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
|
||||||
if offsets is not None else positions]
|
if offsets is not None else positions]
|
||||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||||
@ -142,7 +142,7 @@ class RotaryEmbedding(nn.Module):
|
|||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
offsets: Optional[torch.Tensor] = None,
|
offsets: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
self.cos_sin_cache = self.cos_sin_cache.to(positions.get_device())
|
self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
|
||||||
# ops.rotary_embedding()/batched_rotary_embedding()
|
# ops.rotary_embedding()/batched_rotary_embedding()
|
||||||
# are in-place operations that update the query and key tensors.
|
# are in-place operations that update the query and key tensors.
|
||||||
if offsets is not None:
|
if offsets is not None:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user