[GPT] Refactor function to shard state_dict for TensorParallel

parent 65b4064b2a
commit ef1ba918c6
@@ -14,6 +14,8 @@ import torch.nn.functional as F
 
 from transformers import GPT2Config
 
+from einops import rearrange
+
 from flash_attn.modules.mha import MHA, ParallelMHA
 from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense, ParallelFusedDenseGeluDense
 from flash_attn.modules.block import Block
@@ -338,3 +340,51 @@ def remap_state_dict_gpt2(state_dict, config):
     state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
 
     return state_dict
+
+
+def shard_state_dict_tp(state_dict, config, world_size, rank):
+    """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
+    with tensor parallel.
+    """
+    vocab_size = config.vocab_size
+    if config.vocab_size % config.pad_vocab_size_multiple != 0:
+        vocab_size += (config.pad_vocab_size_multiple
+                       - (config.vocab_size % config.pad_vocab_size_multiple))
+    assert vocab_size % world_size == 0
+    assert config.hidden_size % world_size == 0
+    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
+    assert inner_dim % world_size == 0
+
+    def shard_first_dim(state_dict, key):
+        x = state_dict[key]
+        dim = x.shape[0] // world_size
+        state_dict[key] = x[rank * dim:(rank + 1) * dim]
+
+    def shard_last_dim(state_dict, key):
+        x = state_dict[key]
+        dim = x.shape[-1] // world_size
+        state_dict[key] = x[..., rank * dim:(rank + 1) * dim]
+
+    def shard_qkv_headdim(state_dict, key):
+        x = rearrange(state_dict[key], '(three d) ... -> three d ...', three=3)
+        dim = x.shape[1] // world_size
+        state_dict[key] = rearrange(x[:, rank * dim:(rank + 1) * dim],
+                                    'three d ... -> (three d) ...')
+
+    shard_first_dim(state_dict, 'transformer.embeddings.word_embeddings.weight')
+    if 'lm_head.weight' in state_dict:
+        shard_first_dim(state_dict, 'lm_head.weight')
+    if 'transformer.embeddings.position_embeddings.weight' in state_dict:
+        shard_last_dim(state_dict, 'transformer.embeddings.position_embeddings.weight')
+    for i in range(config.num_hidden_layers):
+        shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight')
+        shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
+        shard_last_dim(state_dict, f'transformer.layers.{i}.mixer.out_proj.weight')
+        if rank != 0:
+            state_dict.pop(f'transformer.layers.{i}.mixer.out_proj.bias')
+        shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
+        shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias')
+        shard_last_dim(state_dict, f'transformer.layers.{i}.mlp.fc2.weight')
+        if rank != 0:
+            state_dict.pop(f'transformer.layers.{i}.mlp.fc2.bias')
+    return state_dict
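The core trick in the new function is shard_qkv_headdim: the packed Wqkv parameter stacks the q, k, and v projections along the output dimension, so slicing it directly along the first dimension would hand each rank whole projections instead of a slice of each. The rearrange calls unpack the three projections, slice each identically for the given rank, and repack them. Below is a minimal standalone sketch of that behavior (not part of the commit, with made-up toy sizes hidden_size=8, world_size=2):

# Standalone illustration (not from the commit): why Wqkv needs the rearrange-based split.
# Toy sizes are made up: hidden_size=8, world_size=2.
import torch
from einops import rearrange

hidden_size, world_size = 8, 2
# Packed QKV projection weight: q, k, v stacked along the output (first) dimension.
wqkv = torch.arange(3 * hidden_size * hidden_size, dtype=torch.float32).reshape(3 * hidden_size, hidden_size)

def shard_qkv(x, rank):
    # Unpack to (three, d, ...) so q, k and v are each sliced the same way for this rank.
    x3 = rearrange(x, '(three d) ... -> three d ...', three=3)
    dim = x3.shape[1] // world_size
    return rearrange(x3[:, rank * dim:(rank + 1) * dim], 'three d ... -> (three d) ...')

shards = [shard_qkv(wqkv, r) for r in range(world_size)]
# Every rank ends up with a (3 * hidden_size // world_size, hidden_size) slice that still
# contains a matching piece of q, k, and v.
assert all(s.shape == (3 * hidden_size // world_size, hidden_size) for s in shards)
# Re-assembling the q pieces from all ranks recovers the original q projection.
q_full = wqkv.chunk(3, dim=0)[0]
q_rebuilt = torch.cat([s.chunk(3, dim=0)[0] for s in shards], dim=0)
assert torch.equal(q_rebuilt, q_full)

The same slicing works for the Wqkv bias because the '(three d) ...' pattern also matches a 1-D tensor. The test file below is updated to use the new helper.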
@@ -12,7 +12,7 @@ from transformers import GPT2Config
 
 from apex.transformer import parallel_state
 
-from flash_attn.models.gpt import GPTLMHeadModel
+from flash_attn.models.gpt import GPTLMHeadModel, shard_state_dict_tp
 from flash_attn.losses.cross_entropy import CrossEntropyLoss
 from flash_attn.utils.distributed import allreduce_sequence_parallel_grad
@@ -22,11 +22,11 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 @pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])
 @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
 # @pytest.mark.parametrize('world_size', [1])
 # @pytest.mark.parametrize('world_size', [2])
 @pytest.mark.parametrize('has_pos_emb', [True, False])
 # @pytest.mark.parametrize('has_pos_emb', [True])
 @pytest.mark.parametrize('dim', [1024])
-def test_block_parallel(dim, has_pos_emb, world_size, dtype):
+def test_gpt_parallel(dim, has_pos_emb, world_size, dtype):
     head_dim = 64
     assert dim % head_dim == 0
     num_heads = dim // head_dim
@@ -91,45 +91,8 @@ def test_block_parallel(dim, has_pos_emb, world_size, dtype):
     partition_dim = dim // world_size
     partition_hidden_dim = 4 * dim // world_size
-    with torch.no_grad():
-        model.transformer.embeddings.word_embeddings.weight.copy_(
-            model_pt.transformer.embeddings.word_embeddings.weight[rank * partition_vocab_size:(rank + 1) * partition_vocab_size]
-        )
-        if has_pos_emb:
-            model.transformer.embeddings.position_embeddings.weight.copy_(
-                model_pt.transformer.embeddings.position_embeddings.weight[:, rank * partition_dim:(rank + 1) * partition_dim]
-            )
-        model.transformer.ln_0.weight.copy_(model_pt.transformer.ln_0.weight)
-        model.transformer.ln_0.bias.copy_(model_pt.transformer.ln_0.bias)
-        for i in range(num_layers):
-            model.transformer.layers[i].mixer.Wqkv.weight.copy_(
-                rearrange(rearrange(model_pt.transformer.layers[i].mixer.Wqkv.weight, '(three o) i -> three o i', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
-                          'three o i -> (three o) i')
-            )
-            model.transformer.layers[i].mixer.Wqkv.bias.copy_(
-                rearrange(rearrange(model_pt.transformer.layers[i].mixer.Wqkv.bias, '(three o) -> three o', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
-                          'three o -> (three o)')
-            )
-            model.transformer.layers[i].mixer.out_proj.weight.copy_(
-                model_pt.transformer.layers[i].mixer.out_proj.weight[:, rank * partition_dim:(rank + 1) * partition_dim]
-            )
-            if rank == 0:
-                model.transformer.layers[i].mixer.out_proj.bias.copy_(model_pt.transformer.layers[i].mixer.out_proj.bias)
-            model.transformer.layers[i].mlp.fc1.weight.copy_(
-                model_pt.transformer.layers[i].mlp.fc1.weight[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim]
-            )
-            model.transformer.layers[i].mlp.fc1.bias.copy_(
-                model_pt.transformer.layers[i].mlp.fc1.bias[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim]
-            )
-            model.transformer.layers[i].mlp.fc2.weight.copy_(
-                model_pt.transformer.layers[i].mlp.fc2.weight[:, rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim]
-            )
-            if rank == 0:
-                model.transformer.layers[i].mlp.fc2.bias.copy_(model_pt.transformer.layers[i].mlp.fc2.bias)
-            model.transformer.layers[i].norm1.weight.copy_(model_pt.transformer.layers[i].norm1.weight)
-            model.transformer.layers[i].norm1.bias.copy_(model_pt.transformer.layers[i].norm1.bias)
-            model.transformer.layers[i].norm2.weight.copy_(model_pt.transformer.layers[i].norm2.weight)
-            model.transformer.layers[i].norm2.bias.copy_(model_pt.transformer.layers[i].norm2.bias)
     # Don't need to copy the lm_head weight since it's tied to the word embedding weight
+    model.load_state_dict(shard_state_dict_tp(model_pt.state_dict(), config, world_size, rank))
+    model.tie_weights()
 
     with torch.autocast(device_type='cuda', dtype=dtype):
         out = model(input_ids[:, :-1]).logits
@@ -150,62 +113,73 @@ def test_block_parallel(dim, has_pos_emb, world_size, dtype):
     allreduce_sequence_parallel_grad(model, process_group)
     parallel_state.destroy_model_parallel()
 
+    grad_dict = shard_state_dict_tp({k: v.grad for k, v in model_pt.named_parameters()},
+                                    config, world_size, rank)
+
     assert torch.allclose(
         model.transformer.embeddings.word_embeddings.weight.grad,
-        model_pt.transformer.embeddings.word_embeddings.weight.grad[rank * partition_vocab_size:(rank + 1) * partition_vocab_size],
+        grad_dict['transformer.embeddings.word_embeddings.weight'],
         rtol=rtol, atol=atol * 5
     )
     if has_pos_emb:
         assert torch.allclose(
             model.transformer.embeddings.position_embeddings.weight.grad,
-            model_pt.transformer.embeddings.position_embeddings.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim],
+            grad_dict['transformer.embeddings.position_embeddings.weight'],
             rtol=rtol, atol=atol
         )
-    assert torch.allclose(model.transformer.ln_0.weight.grad, model_pt.transformer.ln_0.weight.grad,
+    assert torch.allclose(model.transformer.ln_0.weight.grad, grad_dict['transformer.ln_0.weight'],
                           rtol=rtol, atol=atol)
-    assert torch.allclose(model.transformer.ln_0.bias.grad, model_pt.transformer.ln_0.bias.grad,
+    assert torch.allclose(model.transformer.ln_0.bias.grad, grad_dict['transformer.ln_0.bias'],
                           rtol=rtol, atol=atol)
     for i in range(num_layers):
         # if rank == 0: breakpoint()
         # torch.distributed.barrier()
         assert torch.allclose(
             model.transformer.layers[i].mixer.Wqkv.weight.grad,
-            rearrange(rearrange(model_pt.transformer.layers[i].mixer.Wqkv.weight.grad, '(three o) i -> three o i', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim], 'three o i -> (three o) i'),
+            grad_dict[f'transformer.layers.{i}.mixer.Wqkv.weight'],
             rtol=rtol, atol=atol * 10
         )
        assert torch.allclose(
            model.transformer.layers[i].mixer.Wqkv.bias.grad,
-            rearrange(rearrange(model_pt.transformer.layers[i].mixer.Wqkv.bias.grad, '(three o) -> three o', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
-                      'three o -> (three o)'),
+            grad_dict[f'transformer.layers.{i}.mixer.Wqkv.bias'],
             rtol=rtol, atol=atol * 10
         )
         assert torch.allclose(
             model.transformer.layers[i].mixer.out_proj.weight.grad,
-            model_pt.transformer.layers[i].mixer.out_proj.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim],
+            grad_dict[f'transformer.layers.{i}.mixer.out_proj.weight'],
             rtol=rtol, atol=atol * 10
         )
         if rank == 0:
-            assert torch.allclose(model.transformer.layers[i].mixer.out_proj.bias.grad, model_pt.transformer.layers[i].mixer.out_proj.bias.grad, rtol=rtol, atol=atol * 5)
+            assert torch.allclose(model.transformer.layers[i].mixer.out_proj.bias.grad,
+                                  grad_dict[f'transformer.layers.{i}.mixer.out_proj.bias'],
+                                  rtol=rtol, atol=atol * 5)
         assert torch.allclose(
             model.transformer.layers[i].mlp.fc1.weight.grad,
-            model_pt.transformer.layers[i].mlp.fc1.weight.grad[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim],
+            grad_dict[f'transformer.layers.{i}.mlp.fc1.weight'],
             rtol=rtol, atol=atol * 10
         )
         assert torch.allclose(
             model.transformer.layers[i].mlp.fc1.bias.grad,
-            model_pt.transformer.layers[i].mlp.fc1.bias.grad[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim],
+            grad_dict[f'transformer.layers.{i}.mlp.fc1.bias'],
             rtol=rtol, atol=atol * 10
         )
         assert torch.allclose(
             model.transformer.layers[i].mlp.fc2.weight.grad,
-            model_pt.transformer.layers[i].mlp.fc2.weight.grad[:, rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim],
+            grad_dict[f'transformer.layers.{i}.mlp.fc2.weight'],
             rtol=rtol, atol=atol * 10
         )
         if rank == 0:
-            assert torch.allclose(model.transformer.layers[i].mlp.fc2.bias.grad, model_pt.transformer.layers[i].mlp.fc2.bias.grad,
-                                  rtol=rtol, atol=atol * 5)
+            assert torch.allclose(model.transformer.layers[i].mlp.fc2.bias.grad,
+                                  grad_dict[f'transformer.layers.{i}.mlp.fc2.bias'],
+                                  rtol=rtol, atol=atol * 5)
 
-        assert torch.allclose(model.transformer.layers[i].norm1.weight.grad, model_pt.transformer.layers[i].norm1.weight.grad, rtol=rtol, atol=atol)
-        assert torch.allclose(model.transformer.layers[i].norm1.bias.grad, model_pt.transformer.layers[i].norm1.bias.grad, rtol=rtol, atol=atol)
-        assert torch.allclose(model.transformer.layers[i].norm2.weight.grad, model_pt.transformer.layers[i].norm2.weight.grad, rtol=rtol, atol=atol)
-        assert torch.allclose(model.transformer.layers[i].norm2.bias.grad, model_pt.transformer.layers[i].norm2.bias.grad, rtol=rtol, atol=atol)
+        assert torch.allclose(model.transformer.layers[i].norm1.weight.grad,
+                              grad_dict[f'transformer.layers.{i}.norm1.weight'],
+                              rtol=rtol, atol=atol)
+        assert torch.allclose(model.transformer.layers[i].norm1.bias.grad,
+                              grad_dict[f'transformer.layers.{i}.norm1.bias'],
+                              rtol=rtol, atol=atol)
+        assert torch.allclose(model.transformer.layers[i].norm2.weight.grad,
+                              grad_dict[f'transformer.layers.{i}.norm2.weight'],
+                              rtol=rtol, atol=atol)
+        assert torch.allclose(model.transformer.layers[i].norm2.bias.grad,
+                              grad_dict[f'transformer.layers.{i}.norm2.bias'],
+                              rtol=rtol, atol=atol)
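Taken together with the test changes above, the intended loading flow per rank is: build the tensor-parallel model, shard the full (unsharded) state_dict locally with shard_state_dict_tp, load it, and re-tie the output head. A rough sketch of that flow follows; the distributed/apex setup and the GPTLMHeadModel constructor arguments (e.g. process_group=...) are reconstructed from the surrounding test and should be treated as assumptions rather than part of this commit:

# Rough per-rank loading flow (a sketch, not part of the commit). The distributed/apex
# setup and the GPTLMHeadModel constructor arguments are assumptions reconstructed from
# the surrounding test; only the shard_state_dict_tp / load_state_dict / tie_weights
# usage is taken from the diff above.
import torch
from apex.transformer import parallel_state
from transformers import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel, shard_state_dict_tp

torch.distributed.init_process_group(backend='nccl')   # assumes a torchrun-style launch
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
process_group = parallel_state.get_tensor_model_parallel_group()

# Toy config for illustration; pad_vocab_size_multiple keeps the padded vocab divisible
# by world_size, which shard_state_dict_tp asserts.
config = GPT2Config(n_embd=1024, n_head=16, n_layer=4,
                    pad_vocab_size_multiple=8 * world_size)

# Reference (unsharded) model; in practice its state_dict could come from a checkpoint.
model_pt = GPTLMHeadModel(config)
full_state_dict = model_pt.state_dict()

# Tensor-parallel model on this rank (constructor kwargs are an assumption).
model = GPTLMHeadModel(config, process_group=process_group)

# The part this commit adds: shard on the fly, load only this rank's slice,
# then re-tie lm_head to the (sharded) word embedding.
model.load_state_dict(shard_state_dict_tp(full_state_dict, config, world_size, rank))
model.tie_weights()

Sharding on the fly like this avoids writing world_size-specific checkpoints: every rank reads the same full state_dict and keeps only its slice.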