From 78225c5366dd4c1c743d9e2ff9d9b6e1ffcb03e7 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sun, 25 Dec 2022 14:29:53 -0800
Subject: [PATCH] Implement Tensor Parallel for GPT2Embeddings

---
 flash_attn/modules/embedding.py          | 89 ++++++++++++++++++++++--
 tests/modules/test_embedding_parallel.py | 84 ++++++++++++++++++++++
 2 files changed, 166 insertions(+), 7 deletions(-)
 create mode 100644 tests/modules/test_embedding_parallel.py

diff --git a/flash_attn/modules/embedding.py b/flash_attn/modules/embedding.py
index da21ad3..5dca08b 100644
--- a/flash_attn/modules/embedding.py
+++ b/flash_attn/modules/embedding.py
@@ -3,18 +3,26 @@
 import torch
 import torch.nn as nn
+from einops import rearrange
+
+from flash_attn.utils.distributed import reduce_scatter
+
 
 
 class GPT2Embeddings(nn.Module):
 
-    def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None):
+    def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None,
+                 device=None, dtype=None):
         """
             If max_position_embeddings <= 0, there's no position embeddings
         """
+        factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
-        self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
+        self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
+                                            **factory_kwargs)
         self.max_position_embeddings = max_position_embeddings
         if self.max_position_embeddings > 0:
-            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim)
+            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
+                                                    **factory_kwargs)
 
     def forward(self, input_ids, position_ids=None):
         """
@@ -34,19 +42,23 @@ class GPT2Embeddings(nn.Module):
 
 class BertEmbeddings(nn.Module):
     def __init__(self, embed_dim, vocab_size, max_position_embeddings, type_vocab_size,
-                 padding_idx=None):
+                 padding_idx=None, device=None, dtype=None):
         """
             If max_position_embeddings <= 0, there's no position embeddings
             If type_vocab_size <= 0, there's no token type embeddings
         """
+        factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
-        self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
+        self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
+                                            **factory_kwargs)
         self.max_position_embeddings = max_position_embeddings
         self.type_vocab_size = type_vocab_size
         if self.max_position_embeddings > 0:
-            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim)
+            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
+                                                    **factory_kwargs)
         if self.type_vocab_size > 0:
-            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim)
+            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim,
+                                                      **factory_kwargs)
 
     def forward(self, input_ids, position_ids=None, token_type_ids=None):
         """
@@ -67,3 +79,66 @@ class BertEmbeddings(nn.Module):
             token_type_embeddings = self.token_type_embeddings(token_type_ids)
             embeddings = embeddings + token_type_embeddings
         return embeddings
+
+
+class ParallelGPT2Embeddings(nn.Module):
+
+    def __init__(self, embed_dim, vocab_size, max_position_embeddings, process_group,
+                 padding_idx=None, device=None, dtype=None):
+        """
+            If max_position_embeddings <= 0, there's no position embeddings
+        """
+        world_size = torch.distributed.get_world_size(process_group)
+        if vocab_size % world_size != 0:
+            raise ValueError(f'vocab_size ({vocab_size}) must be divisible by '
+                             f'world_size ({world_size})')
+        if embed_dim % world_size != 0:
+            raise ValueError(f'embed_dim ({embed_dim}) must be divisible by '
+                             f'world_size ({world_size})')
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        self.process_group = process_group
+        self.word_embeddings = nn.Embedding(vocab_size // world_size, embed_dim,
+                                            padding_idx=padding_idx, **factory_kwargs)
+        self.max_position_embeddings = max_position_embeddings
+        if self.max_position_embeddings > 0:
+            self.position_embeddings = nn.Embedding(
+                max_position_embeddings, embed_dim // world_size, **factory_kwargs
+            )
+
+    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
+        """
+            input_ids: (batch, seqlen)
+            position_ids: (batch, seqlen)
+        """
+        batch_size, seqlen = input_ids.shape
+        world_size = torch.distributed.get_world_size(self.process_group)
+        if world_size <= 1:
+            embeddings = self.word_embeddings(input_ids)
+            if self.max_position_embeddings > 0:
+                if position_ids is None:
+                    position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
+                position_embeddings = self.position_embeddings(position_ids)
+                embeddings = embeddings + position_embeddings
+            if combine_batch_seqlen_dim:
+                embeddings = rearrange(embeddings, 'b s d -> (b s) d')
+            return embeddings
+        else:
+            rank = torch.distributed.get_rank(self.process_group)
+            vocab_size = self.word_embeddings.num_embeddings
+            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
+            # Mask out token ids that fall outside this rank's vocab partition (True means masked).
+            input_ids_mask = (input_ids < vocab_start_index) | (input_ids >= vocab_end_index)
+            input_ids = input_ids - vocab_start_index
+            input_ids[input_ids_mask] = 0
+            embeddings = self.word_embeddings(input_ids)
+            embeddings[input_ids_mask] = 0.0
+            if self.max_position_embeddings > 0:
+                if position_ids is None:
+                    position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
+                position_embeddings = self.position_embeddings(position_ids)
+                partition_dim = self.position_embeddings.embedding_dim
+                embeddings[..., rank * partition_dim:(rank + 1) * partition_dim] += position_embeddings
+            if combine_batch_seqlen_dim:
+                embeddings = rearrange(embeddings, 'b s d -> (b s) d')
+            return reduce_scatter(embeddings, self.process_group)
diff --git a/tests/modules/test_embedding_parallel.py b/tests/modules/test_embedding_parallel.py
new file mode 100644
index 0000000..d2de870
--- /dev/null
+++ b/tests/modules/test_embedding_parallel.py
@@ -0,0 +1,84 @@
+# Run test with:
+# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_embedding_parallel.py
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import pytest
+
+from einops import rearrange
+
+from apex.transformer import parallel_state
+
+from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
+
+is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
+
+
+@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
+# @pytest.mark.parametrize('dtype', [torch.bfloat16])
+@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
+# @pytest.mark.parametrize('world_size', [2])
+@pytest.mark.parametrize('has_pos_emb', [True, False])
+# @pytest.mark.parametrize('has_pos_emb', [True])
+@pytest.mark.parametrize('dim', [1024])
+def test_embedding_parallel(dim, world_size, has_pos_emb, dtype):
+    vocab_size = 50264
+    seqlen = 2048
+    assert vocab_size % world_size == 0
+    assert dim % world_size == 0
+    rtol, atol = (3e-3, 5e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
+    if not torch.distributed.is_initialized():
+        torch.distributed.init_process_group(backend='nccl', init_method='env://')
+    device = f'cuda:{torch.distributed.get_rank()}'
+    assert world_size <= torch.distributed.get_world_size()
+    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
+    rank = parallel_state.get_tensor_model_parallel_rank()
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 8
+    seqlen = 1024
+    assert (batch_size * seqlen) % world_size == 0
+    input_ids_pt = torch.randint(0, vocab_size, (batch_size, seqlen), device=device)
+    input_ids = input_ids_pt.detach().clone()
+
+    model_pt = GPT2Embeddings(dim, vocab_size, seqlen if has_pos_emb else 0,
+                              device=device, dtype=dtype)
+    model = ParallelGPT2Embeddings(dim, vocab_size, seqlen if has_pos_emb else 0,
+                                   parallel_state.get_tensor_model_parallel_group(),
+                                   device=device, dtype=dtype)
+    partition_vocab_size = vocab_size // world_size
+    partition_dim = dim // world_size
+    with torch.no_grad():
+        model.word_embeddings.weight.copy_(
+            model_pt.word_embeddings.weight[rank * partition_vocab_size:(rank + 1) * partition_vocab_size]
+        )
+        if has_pos_emb:
+            model.position_embeddings.weight.copy_(
+                model_pt.position_embeddings.weight[:, rank * partition_dim:(rank + 1) * partition_dim]
+            )
+
+    out = model(input_ids, combine_batch_seqlen_dim=True)
+    out_pt = rearrange(model_pt(input_ids), 'b s d -> (b s) d')
+    partition_batch_dim = batch_size * seqlen // world_size
+    assert torch.allclose(
+        out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
+        rtol=rtol, atol=atol
+    )
+
+    g = torch.randn_like(out_pt)
+    out_pt.backward(g)
+    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim])
+    parallel_state.destroy_model_parallel()
+
+    assert torch.allclose(
+        model.word_embeddings.weight.grad,
+        model_pt.word_embeddings.weight.grad[rank * partition_vocab_size:(rank + 1) * partition_vocab_size],
+        rtol=rtol, atol=atol
+    )
+    if has_pos_emb:
+        assert torch.allclose(
+            model.position_embeddings.weight.grad,
+            model_pt.position_embeddings.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim],
+            rtol=rtol, atol=atol
+        )
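
The vocab-parallel lookup in ParallelGPT2Embeddings.forward works because each rank stores only vocab_size // world_size rows of the word-embedding table: it looks up the token ids it owns, zeroes the embeddings of the ids it does not own, and lets the cross-rank reduction reassemble the full result. Below is a minimal single-process sketch of that idea (illustrative only, not part of the patch): the loop over world_size emulates the ranks and the running sum stands in for reduce_scatter, and the names full, shard, and local_ids are made up for the example rather than taken from the flash-attn API.

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    vocab_size, embed_dim, world_size = 16, 8, 4
    batch, seqlen = 2, 5

    full = nn.Embedding(vocab_size, embed_dim)           # reference, unsharded table
    input_ids = torch.randint(0, vocab_size, (batch, seqlen))

    partition = vocab_size // world_size
    out = torch.zeros(batch, seqlen, embed_dim)
    for rank in range(world_size):                       # emulate the ranks sequentially
        start, end = rank * partition, (rank + 1) * partition
        shard = nn.Embedding(partition, embed_dim)       # this rank's slice of the table
        with torch.no_grad():
            shard.weight.copy_(full.weight[start:end])
        mask = (input_ids < start) | (input_ids >= end)  # tokens this rank does not own
        local_ids = (input_ids - start).masked_fill(mask, 0)
        local_emb = shard(local_ids).masked_fill(mask.unsqueeze(-1), 0.0)
        out = out + local_emb                            # stands in for the cross-rank sum

    assert torch.allclose(out, full(input_ids))

Because every token id is owned by exactly one partition, the sum of the masked per-rank lookups reproduces the full lookup; in the patch, reduce_scatter performs that sum and at the same time splits the flattened (batch * seqlen) dimension across ranks, which is why the test compares each rank's output against the matching slice of out_pt.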