[TP] Put parallel embeddings in separate modules
This commit is contained in:
parent
1ec09ebd90
commit
4cab4de5ea
@ -2,6 +2,7 @@
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
@ -81,6 +82,51 @@ class BertEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
class VocabParallelEmbedding(nn.Embedding):
|
||||
|
||||
def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
|
||||
self.process_group = process_group
|
||||
if process_group is not None:
|
||||
world_size = torch.distributed.get_world_size(process_group)
|
||||
if num_embeddings % world_size != 0:
|
||||
raise ValueError(f'num_embeddings ({num_embeddings}) must be divisible by '
|
||||
f'world_size ({world_size})')
|
||||
if world_size > 1 and padding_idx is not None:
|
||||
raise RuntimeError('ParallelEmbedding does not support padding_idx')
|
||||
else:
|
||||
world_size = 1
|
||||
super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
if self.process_group is None:
|
||||
return super().forward(input)
|
||||
else:
|
||||
rank = torch.distributed.get_rank(self.process_group)
|
||||
vocab_size = self.num_embeddings
|
||||
vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
|
||||
# Create a mask of valid vocab ids (1 means it needs to be masked).
|
||||
input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
|
||||
input = input - vocab_start_index
|
||||
input[input_ids_mask] = 0
|
||||
embeddings = super().forward(input)
|
||||
embeddings[input_ids_mask] = 0.0
|
||||
return embeddings
|
||||
|
||||
|
||||
class ColumnParallelEmbedding(nn.Embedding):
|
||||
|
||||
def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
|
||||
self.process_group = process_group
|
||||
if process_group is not None:
|
||||
world_size = torch.distributed.get_world_size(process_group)
|
||||
if embedding_dim % world_size != 0:
|
||||
raise ValueError(f'embedding_dim ({embedding_dim}) must be divisible by '
|
||||
f'world_size ({world_size})')
|
||||
else:
|
||||
world_size = 1
|
||||
super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
|
||||
|
||||
|
||||
class ParallelGPT2Embeddings(nn.Module):
|
||||
|
||||
def __init__(self, embed_dim, vocab_size, max_position_embeddings, process_group,
|
||||
@ -88,22 +134,17 @@ class ParallelGPT2Embeddings(nn.Module):
|
||||
"""
|
||||
If max_position_embeddings <= 0, there's no position embeddings
|
||||
"""
|
||||
world_size = torch.distributed.get_world_size(process_group)
|
||||
if vocab_size % world_size != 0:
|
||||
raise ValueError(f'vocab_size ({vocab_size}) must be divisible by '
|
||||
f'world_size ({world_size})')
|
||||
if embed_dim % world_size != 0:
|
||||
raise ValueError(f'embed_dim ({embed_dim}) must be divisible by '
|
||||
f'world_size ({world_size})')
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__()
|
||||
self.process_group = process_group
|
||||
self.word_embeddings = nn.Embedding(vocab_size // world_size, embed_dim,
|
||||
padding_idx=padding_idx, **factory_kwargs)
|
||||
self.word_embeddings = VocabParallelEmbedding(
|
||||
vocab_size, embed_dim, padding_idx=padding_idx, process_group=process_group,
|
||||
**factory_kwargs
|
||||
)
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
if self.max_position_embeddings > 0:
|
||||
self.position_embeddings = nn.Embedding(
|
||||
max_position_embeddings, embed_dim // world_size, **factory_kwargs
|
||||
self.position_embeddings = ColumnParallelEmbedding(
|
||||
max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs
|
||||
)
|
||||
|
||||
def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
|
||||
@ -113,32 +154,17 @@ class ParallelGPT2Embeddings(nn.Module):
|
||||
"""
|
||||
batch_size, seqlen = input_ids.shape
|
||||
world_size = torch.distributed.get_world_size(self.process_group)
|
||||
if world_size <= 1:
|
||||
embeddings = self.word_embeddings(input_ids)
|
||||
if self.max_position_embeddings > 0:
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
embeddings = self.word_embeddings(input_ids)
|
||||
if self.max_position_embeddings > 0:
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
if world_size <= 1:
|
||||
embeddings = embeddings + position_embeddings
|
||||
if combine_batch_seqlen_dim:
|
||||
embeddings = rearrange(embeddings, 'b s d -> (b s) d')
|
||||
return embeddings
|
||||
else:
|
||||
rank = torch.distributed.get_rank(self.process_group)
|
||||
vocab_size = self.word_embeddings.num_embeddings
|
||||
vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
|
||||
# Create a mask of valid vocab ids (1 means it needs to be masked).
|
||||
input_ids_mask = (input_ids < vocab_start_index) | (input_ids >= vocab_end_index)
|
||||
input_ids = input_ids - vocab_start_index
|
||||
input_ids[input_ids_mask] = 0
|
||||
embeddings = self.word_embeddings(input_ids)
|
||||
embeddings[input_ids_mask] = 0.0
|
||||
if self.max_position_embeddings > 0:
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
else:
|
||||
partition_dim = self.position_embeddings.embedding_dim
|
||||
rank = torch.distributed.get_rank(self.process_group)
|
||||
embeddings[..., rank * partition_dim:(rank + 1) * partition_dim] += position_embeddings
|
||||
if combine_batch_seqlen_dim:
|
||||
embeddings = rearrange(embeddings, 'b s d -> (b s) d')
|
||||
return reduce_scatter(embeddings, self.process_group)
|
||||
if combine_batch_seqlen_dim:
|
||||
embeddings = rearrange(embeddings, 'b s d -> (b s) d')
|
||||
return embeddings if world_size <= 1 else reduce_scatter(embeddings, self.process_group)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user