diff --git a/flash_attn/modules/embedding.py b/flash_attn/modules/embedding.py index 5dca08b..0db86c0 100644 --- a/flash_attn/modules/embedding.py +++ b/flash_attn/modules/embedding.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn +from torch import Tensor from einops import rearrange @@ -81,6 +82,51 @@ class BertEmbeddings(nn.Module): return embeddings +class VocabParallelEmbedding(nn.Embedding): + + def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs): + self.process_group = process_group + if process_group is not None: + world_size = torch.distributed.get_world_size(process_group) + if num_embeddings % world_size != 0: + raise ValueError(f'num_embeddings ({num_embeddings}) must be divisible by ' + f'world_size ({world_size})') + if world_size > 1 and padding_idx is not None: + raise RuntimeError('ParallelEmbedding does not support padding_idx') + else: + world_size = 1 + super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs) + + def forward(self, input: Tensor) -> Tensor: + if self.process_group is None: + return super().forward(input) + else: + rank = torch.distributed.get_rank(self.process_group) + vocab_size = self.num_embeddings + vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size + # Create a mask of valid vocab ids (1 means it needs to be masked). + input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index) + input = input - vocab_start_index + input[input_ids_mask] = 0 + embeddings = super().forward(input) + embeddings[input_ids_mask] = 0.0 + return embeddings + + +class ColumnParallelEmbedding(nn.Embedding): + + def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs): + self.process_group = process_group + if process_group is not None: + world_size = torch.distributed.get_world_size(process_group) + if embedding_dim % world_size != 0: + raise ValueError(f'embedding_dim ({embedding_dim}) must be divisible by ' + f'world_size ({world_size})') + else: + world_size = 1 + super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs) + + class ParallelGPT2Embeddings(nn.Module): def __init__(self, embed_dim, vocab_size, max_position_embeddings, process_group, @@ -88,22 +134,17 @@ class ParallelGPT2Embeddings(nn.Module): """ If max_position_embeddings <= 0, there's no position embeddings """ - world_size = torch.distributed.get_world_size(process_group) - if vocab_size % world_size != 0: - raise ValueError(f'vocab_size ({vocab_size}) must be divisible by ' - f'world_size ({world_size})') - if embed_dim % world_size != 0: - raise ValueError(f'embed_dim ({embed_dim}) must be divisible by ' - f'world_size ({world_size})') factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() self.process_group = process_group - self.word_embeddings = nn.Embedding(vocab_size // world_size, embed_dim, - padding_idx=padding_idx, **factory_kwargs) + self.word_embeddings = VocabParallelEmbedding( + vocab_size, embed_dim, padding_idx=padding_idx, process_group=process_group, + **factory_kwargs + ) self.max_position_embeddings = max_position_embeddings if self.max_position_embeddings > 0: - self.position_embeddings = nn.Embedding( - max_position_embeddings, embed_dim // world_size, **factory_kwargs + self.position_embeddings = ColumnParallelEmbedding( + max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs ) def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False): @@ -113,32 +154,17 @@ class ParallelGPT2Embeddings(nn.Module): """ batch_size, seqlen = input_ids.shape world_size = torch.distributed.get_world_size(self.process_group) - if world_size <= 1: - embeddings = self.word_embeddings(input_ids) - if self.max_position_embeddings > 0: - if position_ids is None: - position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device) - position_embeddings = self.position_embeddings(position_ids) + embeddings = self.word_embeddings(input_ids) + if self.max_position_embeddings > 0: + if position_ids is None: + position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device) + position_embeddings = self.position_embeddings(position_ids) + if world_size <= 1: embeddings = embeddings + position_embeddings - if combine_batch_seqlen_dim: - embeddings = rearrange(embeddings, 'b s d -> (b s) d') - return embeddings - else: - rank = torch.distributed.get_rank(self.process_group) - vocab_size = self.word_embeddings.num_embeddings - vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size - # Create a mask of valid vocab ids (1 means it needs to be masked). - input_ids_mask = (input_ids < vocab_start_index) | (input_ids >= vocab_end_index) - input_ids = input_ids - vocab_start_index - input_ids[input_ids_mask] = 0 - embeddings = self.word_embeddings(input_ids) - embeddings[input_ids_mask] = 0.0 - if self.max_position_embeddings > 0: - if position_ids is None: - position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device) - position_embeddings = self.position_embeddings(position_ids) + else: partition_dim = self.position_embeddings.embedding_dim + rank = torch.distributed.get_rank(self.process_group) embeddings[..., rank * partition_dim:(rank + 1) * partition_dim] += position_embeddings - if combine_batch_seqlen_dim: - embeddings = rearrange(embeddings, 'b s d -> (b s) d') - return reduce_scatter(embeddings, self.process_group) + if combine_batch_seqlen_dim: + embeddings = rearrange(embeddings, 'b s d -> (b s) d') + return embeddings if world_size <= 1 else reduce_scatter(embeddings, self.process_group)