[BugFix] tensor.get_device() -> tensor.device (#3604)

This commit is contained in:
Kunshang Ji 2024-03-25 10:01:13 +08:00 committed by GitHub
parent 837e185142
commit 6d93d35308
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -108,7 +108,7 @@ class RotaryEmbedding(nn.Module):
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]
self.cos_sin_cache = self.cos_sin_cache.to(positions.get_device())
self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
if offsets is not None else positions]
cos, sin = cos_sin.chunk(2, dim=-1)
@ -142,7 +142,7 @@ class RotaryEmbedding(nn.Module):
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
self.cos_sin_cache = self.cos_sin_cache.to(positions.get_device())
self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
if offsets is not None: