[GPT] Refactor function to shard state_dict for TensorParallel

Tri Dao 2023-01-01 00:09:33 -08:00
parent 65b4064b2a
commit ef1ba918c6
2 changed files with 86 additions and 62 deletions


@@ -14,6 +14,8 @@ import torch.nn.functional as F
from transformers import GPT2Config
from einops import rearrange
from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense, ParallelFusedDenseGeluDense
from flash_attn.modules.block import Block
@@ -338,3 +340,51 @@ def remap_state_dict_gpt2(state_dict, config):
    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    return state_dict


def shard_state_dict_tp(state_dict, config, world_size, rank):
    """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
    with tensor parallel.
    """
    vocab_size = config.vocab_size
    if config.vocab_size % config.pad_vocab_size_multiple != 0:
        vocab_size += (config.pad_vocab_size_multiple
                       - (config.vocab_size % config.pad_vocab_size_multiple))
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0

    def shard_first_dim(state_dict, key):
        x = state_dict[key]
        dim = x.shape[0] // world_size
        state_dict[key] = x[rank * dim:(rank + 1) * dim]

    def shard_last_dim(state_dict, key):
        x = state_dict[key]
        dim = x.shape[-1] // world_size
        state_dict[key] = x[..., rank * dim:(rank + 1) * dim]

    def shard_qkv_headdim(state_dict, key):
        x = rearrange(state_dict[key], '(three d) ... -> three d ...', three=3)
        dim = x.shape[1] // world_size
        state_dict[key] = rearrange(x[:, rank * dim:(rank + 1) * dim],
                                    'three d ... -> (three d) ...')

    shard_first_dim(state_dict, 'transformer.embeddings.word_embeddings.weight')
    if 'lm_head.weight' in state_dict:
        shard_first_dim(state_dict, 'lm_head.weight')
    if 'transformer.embeddings.position_embeddings.weight' in state_dict:
        shard_last_dim(state_dict, 'transformer.embeddings.position_embeddings.weight')
    for i in range(config.num_hidden_layers):
        shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight')
        shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
        shard_last_dim(state_dict, f'transformer.layers.{i}.mixer.out_proj.weight')
        if rank != 0:
            state_dict.pop(f'transformer.layers.{i}.mixer.out_proj.bias')
        shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
        shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias')
        shard_last_dim(state_dict, f'transformer.layers.{i}.mlp.fc2.weight')
        if rank != 0:
            state_dict.pop(f'transformer.layers.{i}.mlp.fc2.bias')
    return state_dict
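For reference, the new helper is consumed the same way the updated test below consumes it: take the state_dict of a non-parallel reference model, shard it for the local tensor-parallel rank, load it into the parallel model, and re-tie the embeddings (lm_head.weight is tied to the word embedding weight, so it is never copied separately). A minimal usage sketch, assuming a reference model model_pt, a tensor-parallel model, its GPT2Config config, and the tensor-parallel world_size and rank are already set up as in the test:

# Usage sketch only, not part of the commit; model_pt, model, config,
# world_size and rank are assumed to be set up as in the test below.
from flash_attn.models.gpt import shard_state_dict_tp

sharded = shard_state_dict_tp(model_pt.state_dict(), config, world_size, rank)
model.load_state_dict(sharded)
model.tie_weights()  # restore the word-embedding / lm_head tie after loading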


@@ -12,7 +12,7 @@ from transformers import GPT2Config
from apex.transformer import parallel_state
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt import GPTLMHeadModel, shard_state_dict_tp
from flash_attn.losses.cross_entropy import CrossEntropyLoss
from flash_attn.utils.distributed import allreduce_sequence_parallel_grad
@@ -22,11 +22,11 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [1])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('has_pos_emb', [True, False])
# @pytest.mark.parametrize('has_pos_emb', [True])
@pytest.mark.parametrize('dim', [1024])
def test_block_parallel(dim, has_pos_emb, world_size, dtype):
def test_gpt_parallel(dim, has_pos_emb, world_size, dtype):
    head_dim = 64
    assert dim % head_dim == 0
    num_heads = dim // head_dim
@@ -91,45 +91,8 @@ def test_block_parallel(dim, has_pos_emb, world_size, dtype):
    partition_dim = dim // world_size
    partition_hidden_dim = 4 * dim // world_size
    with torch.no_grad():
        model.transformer.embeddings.word_embeddings.weight.copy_(
            model_pt.transformer.embeddings.word_embeddings.weight[rank * partition_vocab_size:(rank + 1) * partition_vocab_size]
        )
        if has_pos_emb:
            model.transformer.embeddings.position_embeddings.weight.copy_(
                model_pt.transformer.embeddings.position_embeddings.weight[:, rank * partition_dim:(rank + 1) * partition_dim]
            )
        model.transformer.ln_0.weight.copy_(model_pt.transformer.ln_0.weight)
        model.transformer.ln_0.bias.copy_(model_pt.transformer.ln_0.bias)
        for i in range(num_layers):
            model.transformer.layers[i].mixer.Wqkv.weight.copy_(
                rearrange(rearrange(model_pt.transformer.layers[i].mixer.Wqkv.weight, '(three o) i -> three o i', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
                          'three o i -> (three o) i')
            )
            model.transformer.layers[i].mixer.Wqkv.bias.copy_(
                rearrange(rearrange(model_pt.transformer.layers[i].mixer.Wqkv.bias, '(three o) -> three o', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
                          'three o -> (three o)')
            )
            model.transformer.layers[i].mixer.out_proj.weight.copy_(
                model_pt.transformer.layers[i].mixer.out_proj.weight[:, rank * partition_dim:(rank + 1) * partition_dim]
            )
            if rank == 0:
                model.transformer.layers[i].mixer.out_proj.bias.copy_(model_pt.transformer.layers[i].mixer.out_proj.bias)
            model.transformer.layers[i].mlp.fc1.weight.copy_(
                model_pt.transformer.layers[i].mlp.fc1.weight[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim]
            )
            model.transformer.layers[i].mlp.fc1.bias.copy_(
                model_pt.transformer.layers[i].mlp.fc1.bias[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim]
            )
            model.transformer.layers[i].mlp.fc2.weight.copy_(
                model_pt.transformer.layers[i].mlp.fc2.weight[:, rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim]
            )
            if rank == 0:
                model.transformer.layers[i].mlp.fc2.bias.copy_(model_pt.transformer.layers[i].mlp.fc2.bias)
            model.transformer.layers[i].norm1.weight.copy_(model_pt.transformer.layers[i].norm1.weight)
            model.transformer.layers[i].norm1.bias.copy_(model_pt.transformer.layers[i].norm1.bias)
            model.transformer.layers[i].norm2.weight.copy_(model_pt.transformer.layers[i].norm2.weight)
            model.transformer.layers[i].norm2.bias.copy_(model_pt.transformer.layers[i].norm2.bias)
        # Don't need to copy the lm_head weight since it's tied to the word embedding weight
        model.load_state_dict(shard_state_dict_tp(model_pt.state_dict(), config, world_size, rank))
        model.tie_weights()
    with torch.autocast(device_type='cuda', dtype=dtype):
        out = model(input_ids[:, :-1]).logits
@@ -150,62 +113,73 @@ def test_block_parallel(dim, has_pos_emb, world_size, dtype):
    allreduce_sequence_parallel_grad(model, process_group)
    parallel_state.destroy_model_parallel()
    grad_dict = shard_state_dict_tp({k: v.grad for k, v in model_pt.named_parameters()},
                                    config, world_size, rank)
    assert torch.allclose(
        model.transformer.embeddings.word_embeddings.weight.grad,
        model_pt.transformer.embeddings.word_embeddings.weight.grad[rank * partition_vocab_size:(rank + 1) * partition_vocab_size],
        grad_dict['transformer.embeddings.word_embeddings.weight'],
        rtol=rtol, atol=atol * 5
    )
    if has_pos_emb:
        assert torch.allclose(
            model.transformer.embeddings.position_embeddings.weight.grad,
            model_pt.transformer.embeddings.position_embeddings.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim],
            grad_dict['transformer.embeddings.position_embeddings.weight'],
            rtol=rtol, atol=atol
        )
    assert torch.allclose(model.transformer.ln_0.weight.grad, model_pt.transformer.ln_0.weight.grad,
    assert torch.allclose(model.transformer.ln_0.weight.grad, grad_dict['transformer.ln_0.weight'],
                          rtol=rtol, atol=atol)
    assert torch.allclose(model.transformer.ln_0.bias.grad, model_pt.transformer.ln_0.bias.grad,
    assert torch.allclose(model.transformer.ln_0.bias.grad, grad_dict['transformer.ln_0.bias'],
                          rtol=rtol, atol=atol)
    for i in range(num_layers):
        # if rank == 0: breakpoint()
        # torch.distributed.barrier()
        assert torch.allclose(
            model.transformer.layers[i].mixer.Wqkv.weight.grad,
            rearrange(rearrange(model_pt.transformer.layers[i].mixer.Wqkv.weight.grad, '(three o) i -> three o i', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim], 'three o i -> (three o) i'),
            grad_dict[f'transformer.layers.{i}.mixer.Wqkv.weight'],
            rtol=rtol, atol=atol * 10
        )
        assert torch.allclose(
            model.transformer.layers[i].mixer.Wqkv.bias.grad,
            rearrange(rearrange(model_pt.transformer.layers[i].mixer.Wqkv.bias.grad, '(three o) -> three o', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
                      'three o -> (three o)'),
            grad_dict[f'transformer.layers.{i}.mixer.Wqkv.bias'],
            rtol=rtol, atol=atol * 10
        )
        assert torch.allclose(
            model.transformer.layers[i].mixer.out_proj.weight.grad,
            model_pt.transformer.layers[i].mixer.out_proj.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim],
            grad_dict[f'transformer.layers.{i}.mixer.out_proj.weight'],
            rtol=rtol, atol=atol * 10
        )
        if rank == 0:
            assert torch.allclose(model.transformer.layers[i].mixer.out_proj.bias.grad, model_pt.transformer.layers[i].mixer.out_proj.bias.grad, rtol=rtol, atol=atol * 5)
            assert torch.allclose(model.transformer.layers[i].mixer.out_proj.bias.grad,
                                  grad_dict[f'transformer.layers.{i}.mixer.out_proj.bias'],
                                  rtol=rtol, atol=atol * 5)
        assert torch.allclose(
            model.transformer.layers[i].mlp.fc1.weight.grad,
            model_pt.transformer.layers[i].mlp.fc1.weight.grad[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim],
            grad_dict[f'transformer.layers.{i}.mlp.fc1.weight'],
            rtol=rtol, atol=atol * 10
        )
        assert torch.allclose(
            model.transformer.layers[i].mlp.fc1.bias.grad,
            model_pt.transformer.layers[i].mlp.fc1.bias.grad[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim],
            grad_dict[f'transformer.layers.{i}.mlp.fc1.bias'],
            rtol=rtol, atol=atol * 10
        )
        assert torch.allclose(
            model.transformer.layers[i].mlp.fc2.weight.grad,
            model_pt.transformer.layers[i].mlp.fc2.weight.grad[:, rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim],
            grad_dict[f'transformer.layers.{i}.mlp.fc2.weight'],
            rtol=rtol, atol=atol * 10
        )
        if rank == 0:
            assert torch.allclose(model.transformer.layers[i].mlp.fc2.bias.grad, model_pt.transformer.layers[i].mlp.fc2.bias.grad,
                                  rtol=rtol, atol=atol * 5)
            assert torch.allclose(model.transformer.layers[i].mlp.fc2.bias.grad,
                                  grad_dict[f'transformer.layers.{i}.mlp.fc2.bias'],
                                  rtol=rtol, atol=atol * 5)
        assert torch.allclose(model.transformer.layers[i].norm1.weight.grad, model_pt.transformer.layers[i].norm1.weight.grad, rtol=rtol, atol=atol)
        assert torch.allclose(model.transformer.layers[i].norm1.bias.grad, model_pt.transformer.layers[i].norm1.bias.grad, rtol=rtol, atol=atol)
        assert torch.allclose(model.transformer.layers[i].norm2.weight.grad, model_pt.transformer.layers[i].norm2.weight.grad, rtol=rtol, atol=atol)
        assert torch.allclose(model.transformer.layers[i].norm2.bias.grad, model_pt.transformer.layers[i].norm2.bias.grad, rtol=rtol, atol=atol)
        assert torch.allclose(model.transformer.layers[i].norm1.weight.grad,
                              grad_dict[f'transformer.layers.{i}.norm1.weight'],
                              rtol=rtol, atol=atol)
        assert torch.allclose(model.transformer.layers[i].norm1.bias.grad,
                              grad_dict[f'transformer.layers.{i}.norm1.bias'],
                              rtol=rtol, atol=atol)
        assert torch.allclose(model.transformer.layers[i].norm2.weight.grad,
                              grad_dict[f'transformer.layers.{i}.norm2.weight'],
                              rtol=rtol, atol=atol)
        assert torch.allclose(model.transformer.layers[i].norm2.bias.grad,
                              grad_dict[f'transformer.layers.{i}.norm2.bias'],
                              rtol=rtol, atol=atol)
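One detail worth spelling out from shard_state_dict_tp: the fused Wqkv parameters cannot be sliced along the first dimension the way fc1 is, because that dimension stacks the Q, K and V projections. A naive first-dim split would give rank 0 all of Q plus part of K; the shard_qkv_headdim helper first regroups the leading (three d) axis with rearrange and then slices, so every rank ends up with matching slices of Q, K and V. A toy illustration of the difference, using made-up sizes (hidden=4, world_size=2) that are not from the commit:

# Toy demo (not part of the commit) of why shard_qkv_headdim regroups the fused
# QKV dimension before slicing; hidden=4 and world_size=2 are arbitrary toy sizes.
import torch
from einops import rearrange

hidden, world_size, rank = 4, 2, 0
wqkv = torch.arange(3 * hidden * hidden, dtype=torch.float32).reshape(3 * hidden, hidden)

dim_naive = wqkv.shape[0] // world_size
naive = wqkv[rank * dim_naive:(rank + 1) * dim_naive]
# naive is rows 0-5: all four rows of Q plus the first two rows of K.

grouped = rearrange(wqkv, '(three d) i -> three d i', three=3)   # shape (3, hidden, hidden)
dim = grouped.shape[1] // world_size
per_rank = rearrange(grouped[:, rank * dim:(rank + 1) * dim], 'three d i -> (three d) i')
# per_rank is rows 0-1 of Q, rows 0-1 of K and rows 0-1 of V: a matching slice of each projection.
print(naive.shape, per_rank.shape)  # both torch.Size([6, 4]), but built from different rows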