[BugFix] tensor.get_device() -> tensor.device (#3604)
This commit is contained in:
parent
837e185142
commit
6d93d35308
@ -108,7 +108,7 @@ class RotaryEmbedding(nn.Module):
|
||||
query_pass = query[..., self.rotary_dim:]
|
||||
key_pass = key[..., self.rotary_dim:]
|
||||
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(positions.get_device())
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
|
||||
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
|
||||
if offsets is not None else positions]
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
@ -142,7 +142,7 @@ class RotaryEmbedding(nn.Module):
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(positions.get_device())
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
|
||||
# ops.rotary_embedding()/batched_rotary_embedding()
|
||||
# are in-place operations that update the query and key tensors.
|
||||
if offsets is not None:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user