[TP] Put parallel embeddings in separate modules

This commit is contained in:
Tri Dao 2023-01-02 08:47:48 -08:00
parent 1ec09ebd90
commit 4cab4de5ea

View File

@ -2,6 +2,7 @@
import torch
import torch.nn as nn
from torch import Tensor
from einops import rearrange
@ -81,6 +82,51 @@ class BertEmbeddings(nn.Module):
return embeddings
class VocabParallelEmbedding(nn.Embedding):
def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
self.process_group = process_group
if process_group is not None:
world_size = torch.distributed.get_world_size(process_group)
if num_embeddings % world_size != 0:
raise ValueError(f'num_embeddings ({num_embeddings}) must be divisible by '
f'world_size ({world_size})')
if world_size > 1 and padding_idx is not None:
raise RuntimeError('ParallelEmbedding does not support padding_idx')
else:
world_size = 1
super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)
def forward(self, input: Tensor) -> Tensor:
if self.process_group is None:
return super().forward(input)
else:
rank = torch.distributed.get_rank(self.process_group)
vocab_size = self.num_embeddings
vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
# Create a mask of valid vocab ids (1 means it needs to be masked).
input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
input = input - vocab_start_index
input[input_ids_mask] = 0
embeddings = super().forward(input)
embeddings[input_ids_mask] = 0.0
return embeddings
class ColumnParallelEmbedding(nn.Embedding):
def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
self.process_group = process_group
if process_group is not None:
world_size = torch.distributed.get_world_size(process_group)
if embedding_dim % world_size != 0:
raise ValueError(f'embedding_dim ({embedding_dim}) must be divisible by '
f'world_size ({world_size})')
else:
world_size = 1
super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
class ParallelGPT2Embeddings(nn.Module):
def __init__(self, embed_dim, vocab_size, max_position_embeddings, process_group,
@ -88,22 +134,17 @@ class ParallelGPT2Embeddings(nn.Module):
"""
If max_position_embeddings <= 0, there's no position embeddings
"""
world_size = torch.distributed.get_world_size(process_group)
if vocab_size % world_size != 0:
raise ValueError(f'vocab_size ({vocab_size}) must be divisible by '
f'world_size ({world_size})')
if embed_dim % world_size != 0:
raise ValueError(f'embed_dim ({embed_dim}) must be divisible by '
f'world_size ({world_size})')
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.process_group = process_group
self.word_embeddings = nn.Embedding(vocab_size // world_size, embed_dim,
padding_idx=padding_idx, **factory_kwargs)
self.word_embeddings = VocabParallelEmbedding(
vocab_size, embed_dim, padding_idx=padding_idx, process_group=process_group,
**factory_kwargs
)
self.max_position_embeddings = max_position_embeddings
if self.max_position_embeddings > 0:
self.position_embeddings = nn.Embedding(
max_position_embeddings, embed_dim // world_size, **factory_kwargs
self.position_embeddings = ColumnParallelEmbedding(
max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs
)
def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
@ -113,32 +154,17 @@ class ParallelGPT2Embeddings(nn.Module):
"""
batch_size, seqlen = input_ids.shape
world_size = torch.distributed.get_world_size(self.process_group)
if world_size <= 1:
embeddings = self.word_embeddings(input_ids)
if self.max_position_embeddings > 0:
if position_ids is None:
position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
position_embeddings = self.position_embeddings(position_ids)
embeddings = self.word_embeddings(input_ids)
if self.max_position_embeddings > 0:
if position_ids is None:
position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
position_embeddings = self.position_embeddings(position_ids)
if world_size <= 1:
embeddings = embeddings + position_embeddings
if combine_batch_seqlen_dim:
embeddings = rearrange(embeddings, 'b s d -> (b s) d')
return embeddings
else:
rank = torch.distributed.get_rank(self.process_group)
vocab_size = self.word_embeddings.num_embeddings
vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
# Create a mask of valid vocab ids (1 means it needs to be masked).
input_ids_mask = (input_ids < vocab_start_index) | (input_ids >= vocab_end_index)
input_ids = input_ids - vocab_start_index
input_ids[input_ids_mask] = 0
embeddings = self.word_embeddings(input_ids)
embeddings[input_ids_mask] = 0.0
if self.max_position_embeddings > 0:
if position_ids is None:
position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
position_embeddings = self.position_embeddings(position_ids)
else:
partition_dim = self.position_embeddings.embedding_dim
rank = torch.distributed.get_rank(self.process_group)
embeddings[..., rank * partition_dim:(rank + 1) * partition_dim] += position_embeddings
if combine_batch_seqlen_dim:
embeddings = rearrange(embeddings, 'b s d -> (b s) d')
return reduce_scatter(embeddings, self.process_group)
if combine_batch_seqlen_dim:
embeddings = rearrange(embeddings, 'b s d -> (b s) d')
return embeddings if world_size <= 1 else reduce_scatter(embeddings, self.process_group)