Implement Tensor Parallel for GPT2Embeddings
This commit is contained in:
parent
a8cfe51551
commit
78225c5366
@ -3,18 +3,26 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from flash_attn.utils.distributed import reduce_scatter
|
||||
|
||||
|
||||
class GPT2Embeddings(nn.Module):
|
||||
|
||||
def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None):
|
||||
def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None,
|
||||
device=None, dtype=None):
|
||||
"""
|
||||
If max_position_embeddings <= 0, there's no position embeddings
|
||||
"""
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__()
|
||||
self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
|
||||
self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
|
||||
**factory_kwargs)
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
if self.max_position_embeddings > 0:
|
||||
self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim)
|
||||
self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
|
||||
**factory_kwargs)
|
||||
|
||||
def forward(self, input_ids, position_ids=None):
|
||||
"""
|
||||
@ -34,19 +42,23 @@ class GPT2Embeddings(nn.Module):
|
||||
class BertEmbeddings(nn.Module):
|
||||
|
||||
def __init__(self, embed_dim, vocab_size, max_position_embeddings, type_vocab_size,
|
||||
padding_idx=None):
|
||||
padding_idx=None, device=None, dtype=None):
|
||||
"""
|
||||
If max_position_embeddings <= 0, there's no position embeddings
|
||||
If type_vocab_size <= 0, there's no token type embeddings
|
||||
"""
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__()
|
||||
self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
|
||||
self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
|
||||
**factory_kwargs)
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
if self.max_position_embeddings > 0:
|
||||
self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim)
|
||||
self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
|
||||
**factory_kwargs)
|
||||
if self.type_vocab_size > 0:
|
||||
self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim)
|
||||
self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim,
|
||||
**factory_kwargs)
|
||||
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None):
|
||||
"""
|
||||
@ -67,3 +79,66 @@ class BertEmbeddings(nn.Module):
|
||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||
embeddings = embeddings + token_type_embeddings
|
||||
return embeddings
|
||||
|
||||
|
||||
class ParallelGPT2Embeddings(nn.Module):
|
||||
|
||||
def __init__(self, embed_dim, vocab_size, max_position_embeddings, process_group,
|
||||
padding_idx=None, device=None, dtype=None):
|
||||
"""
|
||||
If max_position_embeddings <= 0, there's no position embeddings
|
||||
"""
|
||||
world_size = torch.distributed.get_world_size(process_group)
|
||||
if vocab_size % world_size != 0:
|
||||
raise ValueError(f'vocab_size ({vocab_size}) must be divisible by '
|
||||
f'world_size ({world_size})')
|
||||
if embed_dim % world_size != 0:
|
||||
raise ValueError(f'embed_dim ({embed_dim}) must be divisible by '
|
||||
f'world_size ({world_size})')
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__()
|
||||
self.process_group = process_group
|
||||
self.word_embeddings = nn.Embedding(vocab_size // world_size, embed_dim,
|
||||
padding_idx=padding_idx, **factory_kwargs)
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
if self.max_position_embeddings > 0:
|
||||
self.position_embeddings = nn.Embedding(
|
||||
max_position_embeddings, embed_dim // world_size, **factory_kwargs
|
||||
)
|
||||
|
||||
def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
|
||||
"""
|
||||
input_ids: (batch, seqlen)
|
||||
position_ids: (batch, seqlen)
|
||||
"""
|
||||
batch_size, seqlen = input_ids.shape
|
||||
world_size = torch.distributed.get_world_size(self.process_group)
|
||||
if world_size <= 1:
|
||||
embeddings = self.word_embeddings(input_ids)
|
||||
if self.max_position_embeddings > 0:
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
embeddings = embeddings + position_embeddings
|
||||
if combine_batch_seqlen_dim:
|
||||
embeddings = rearrange(embeddings, 'b s d -> (b s) d')
|
||||
return embeddings
|
||||
else:
|
||||
rank = torch.distributed.get_rank(self.process_group)
|
||||
vocab_size = self.word_embeddings.num_embeddings
|
||||
vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
|
||||
# Create a mask of valid vocab ids (1 means it needs to be masked).
|
||||
input_ids_mask = (input_ids < vocab_start_index) | (input_ids >= vocab_end_index)
|
||||
input_ids = input_ids - vocab_start_index
|
||||
input_ids[input_ids_mask] = 0
|
||||
embeddings = self.word_embeddings(input_ids)
|
||||
embeddings[input_ids_mask] = 0.0
|
||||
if self.max_position_embeddings > 0:
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
partition_dim = self.position_embeddings.embedding_dim
|
||||
embeddings[..., rank * partition_dim:(rank + 1) * partition_dim] += position_embeddings
|
||||
if combine_batch_seqlen_dim:
|
||||
embeddings = rearrange(embeddings, 'b s d -> (b s) d')
|
||||
return reduce_scatter(embeddings, self.process_group)
|
||||
|
||||
84
tests/modules/test_embedding_parallel.py
Normal file
84
tests/modules/test_embedding_parallel.py
Normal file
@ -0,0 +1,84 @@
|
||||
# Run test with:
|
||||
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_embedding_parallel.py
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import pytest
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from apex.transformer import parallel_state
|
||||
|
||||
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
|
||||
|
||||
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
|
||||
|
||||
|
||||
@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
|
||||
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
|
||||
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
|
||||
# @pytest.mark.parametrize('world_size', [2])
|
||||
@pytest.mark.parametrize('has_pos_emb', [True, False])
|
||||
# @pytest.mark.parametrize('has_pos_emb', [True])
|
||||
@pytest.mark.parametrize('dim', [1024])
|
||||
def test_embedding_parallel(dim, world_size, has_pos_emb, dtype):
|
||||
vocab_size = 50264
|
||||
seqlen = 2048
|
||||
assert vocab_size % world_size == 0
|
||||
assert dim % world_size == 0
|
||||
rtol, atol = (3e-3, 5e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group(backend='nccl', init_method='env://')
|
||||
device = f'cuda:{torch.distributed.get_rank()}'
|
||||
assert world_size <= torch.distributed.get_world_size()
|
||||
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
|
||||
rank = parallel_state.get_tensor_model_parallel_rank()
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch_size = 8
|
||||
seqlen = 1024
|
||||
assert (batch_size * seqlen) % world_size == 0
|
||||
input_ids_pt = torch.randint(0, vocab_size, (batch_size, seqlen), device=device)
|
||||
input_ids = input_ids_pt.detach().clone()
|
||||
|
||||
model_pt = GPT2Embeddings(dim, vocab_size, seqlen if has_pos_emb else 0,
|
||||
device=device, dtype=dtype)
|
||||
model = ParallelGPT2Embeddings(dim, vocab_size, seqlen if has_pos_emb else 0,
|
||||
parallel_state.get_tensor_model_parallel_group(),
|
||||
device=device, dtype=dtype)
|
||||
partition_vocab_size = vocab_size // world_size
|
||||
partition_dim = dim // world_size
|
||||
with torch.no_grad():
|
||||
model.word_embeddings.weight.copy_(
|
||||
model_pt.word_embeddings.weight[rank * partition_vocab_size:(rank + 1) * partition_vocab_size]
|
||||
)
|
||||
if has_pos_emb:
|
||||
model.position_embeddings.weight.copy_(
|
||||
model_pt.position_embeddings.weight[:, rank * partition_dim:(rank + 1) * partition_dim]
|
||||
)
|
||||
|
||||
out = model(input_ids, combine_batch_seqlen_dim=True)
|
||||
out_pt = rearrange(model_pt(input_ids), 'b s d -> (b s) d')
|
||||
partition_batch_dim = batch_size * seqlen // world_size
|
||||
assert torch.allclose(
|
||||
out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
|
||||
rtol=rtol, atol=atol
|
||||
)
|
||||
|
||||
g = torch.randn_like(out_pt)
|
||||
out_pt.backward(g)
|
||||
out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim])
|
||||
parallel_state.destroy_model_parallel()
|
||||
|
||||
assert torch.allclose(
|
||||
model.word_embeddings.weight.grad,
|
||||
model_pt.word_embeddings.weight.grad[rank * partition_vocab_size:(rank + 1) * partition_vocab_size],
|
||||
rtol=rtol, atol=atol
|
||||
)
|
||||
if has_pos_emb:
|
||||
assert torch.allclose(
|
||||
model.position_embeddings.weight.grad,
|
||||
model_pt.position_embeddings.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim],
|
||||
rtol=rtol, atol=atol
|
||||
)
|
||||
Loading…
Reference in New Issue
Block a user