[TPU] Optimize RoPE forward_native2 (#7636)
This commit is contained in:
parent
0c2fa50b84
commit
ab7165f2c7
@ -46,15 +46,23 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
|
|||||||
|
|
||||||
def _apply_rotary_emb(
|
def _apply_rotary_emb(
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
freqs_cis: torch.Tensor,
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
x_ = torch.view_as_complex(
|
"""
|
||||||
torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
|
Args:
|
||||||
x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
|
x: [num_tokens, num_heads, head_size]
|
||||||
x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
|
cos: [num_tokens, head_size // 2]
|
||||||
x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
|
sin: [num_tokens, head_size // 2]
|
||||||
-1).transpose(1, 2)
|
"""
|
||||||
return x_out
|
orig_dtype = x.dtype
|
||||||
|
x = x.float()
|
||||||
|
x1, x2 = torch.chunk(x, 2, dim=-1)
|
||||||
|
cos = cos.unsqueeze(-2)
|
||||||
|
sin = sin.unsqueeze(-2)
|
||||||
|
o1 = x1 * cos - x2 * sin
|
||||||
|
o2 = x2 * cos + x1 * sin
|
||||||
|
return torch.cat((o1, o2), dim=-1).to(orig_dtype)
|
||||||
|
|
||||||
|
|
||||||
class RotaryEmbedding(CustomOp):
|
class RotaryEmbedding(CustomOp):
|
||||||
@ -78,14 +86,10 @@ class RotaryEmbedding(CustomOp):
|
|||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|
||||||
cache = self._compute_cos_sin_cache()
|
cache = self._compute_cos_sin_cache()
|
||||||
|
cache = cache.to(dtype)
|
||||||
|
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||||
|
|
||||||
self.use_native2 = current_platform.is_tpu() and is_neox_style
|
self.use_native2 = current_platform.is_tpu() and is_neox_style
|
||||||
if not self.use_native2:
|
|
||||||
cache = cache.to(dtype)
|
|
||||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
|
||||||
else:
|
|
||||||
cos, sin = cache.chunk(2, dim=-1)
|
|
||||||
freqs_cis = cos + 1j * sin
|
|
||||||
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
|
|
||||||
|
|
||||||
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
||||||
"""Compute the inverse frequency."""
|
"""Compute the inverse frequency."""
|
||||||
@ -173,28 +177,25 @@ class RotaryEmbedding(CustomOp):
|
|||||||
|
|
||||||
This method might perform better than `forward_native()` when compiled.
|
This method might perform better than `forward_native()` when compiled.
|
||||||
"""
|
"""
|
||||||
if positions.dim() == 1:
|
|
||||||
batch_size = 1
|
|
||||||
seq_len = positions.shape[0]
|
|
||||||
else:
|
|
||||||
batch_size, seq_len = positions.shape
|
|
||||||
if offsets is not None:
|
if offsets is not None:
|
||||||
positions = positions + offsets
|
positions = positions + offsets
|
||||||
freqs_cis = self.freqs_cis.index_select(0, positions.flatten())
|
positions = positions.flatten()
|
||||||
freqs_cis = freqs_cis.view(batch_size, 1, seq_len, -1)
|
num_tokens = positions.shape[0]
|
||||||
|
cos_sin = self.cos_sin_cache.index_select(0, positions)
|
||||||
|
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||||
|
|
||||||
query_shape = query.shape
|
query_shape = query.shape
|
||||||
query = query.view(batch_size, seq_len, -1, self.head_size)
|
query = query.view(num_tokens, -1, self.head_size)
|
||||||
query_rot = query[..., :self.rotary_dim]
|
query_rot = query[..., :self.rotary_dim]
|
||||||
query_pass = query[..., self.rotary_dim:]
|
query_pass = query[..., self.rotary_dim:]
|
||||||
query_rot = _apply_rotary_emb(query_rot, freqs_cis)
|
query_rot = _apply_rotary_emb(query_rot, cos, sin)
|
||||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||||
|
|
||||||
key_shape = key.shape
|
key_shape = key.shape
|
||||||
key = key.view(batch_size, seq_len, -1, self.head_size)
|
key = key.view(num_tokens, -1, self.head_size)
|
||||||
key_rot = key[..., :self.rotary_dim]
|
key_rot = key[..., :self.rotary_dim]
|
||||||
key_pass = key[..., self.rotary_dim:]
|
key_pass = key[..., self.rotary_dim:]
|
||||||
key_rot = _apply_rotary_emb(key_rot, freqs_cis)
|
key_rot = _apply_rotary_emb(key_rot, cos, sin)
|
||||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user