diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py
index 0cb1b06..e746fc7 100644
--- a/flash_attn/models/gpt.py
+++ b/flash_attn/models/gpt.py
@@ -14,6 +14,8 @@ import torch.nn.functional as F
 
 from transformers import GPT2Config
 
+from einops import rearrange
+
 from flash_attn.modules.mha import MHA, ParallelMHA
 from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense, ParallelFusedDenseGeluDense
 from flash_attn.modules.block import Block
@@ -338,3 +340,51 @@ def remap_state_dict_gpt2(state_dict, config):
     state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
 
     return state_dict
+
+
+def shard_state_dict_tp(state_dict, config, world_size, rank):
+    """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
+    with tensor parallel.
+    """
+    vocab_size = config.vocab_size
+    if config.vocab_size % config.pad_vocab_size_multiple != 0:
+        vocab_size += (config.pad_vocab_size_multiple
+                       - (config.vocab_size % config.pad_vocab_size_multiple))
+    assert vocab_size % world_size == 0
+    assert config.hidden_size % world_size == 0
+    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
+    assert inner_dim % world_size == 0
+
+    def shard_first_dim(state_dict, key):
+        x = state_dict[key]
+        dim = x.shape[0] // world_size
+        state_dict[key] = x[rank * dim:(rank + 1) * dim]
+
+    def shard_last_dim(state_dict, key):
+        x = state_dict[key]
+        dim = x.shape[-1] // world_size
+        state_dict[key] = x[..., rank * dim:(rank + 1) * dim]
+
+    def shard_qkv_headdim(state_dict, key):
+        x = rearrange(state_dict[key], '(three d) ... -> three d ...', three=3)
+        dim = x.shape[1] // world_size
+        state_dict[key] = rearrange(x[:, rank * dim:(rank + 1) * dim],
+                                    'three d ... -> (three d) ...')
+
+    shard_first_dim(state_dict, 'transformer.embeddings.word_embeddings.weight')
+    if 'lm_head.weight' in state_dict:
+        shard_first_dim(state_dict, 'lm_head.weight')
+    if 'transformer.embeddings.position_embeddings.weight' in state_dict:
+        shard_last_dim(state_dict, 'transformer.embeddings.position_embeddings.weight')
+    for i in range(config.num_hidden_layers):
+        shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight')
+        shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
+        shard_last_dim(state_dict, f'transformer.layers.{i}.mixer.out_proj.weight')
+        if rank != 0:
+            state_dict.pop(f'transformer.layers.{i}.mixer.out_proj.bias')
+        shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
+        shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias')
+        shard_last_dim(state_dict, f'transformer.layers.{i}.mlp.fc2.weight')
+        if rank != 0:
+            state_dict.pop(f'transformer.layers.{i}.mlp.fc2.bias')
+    return state_dict
diff --git a/tests/models/test_gpt_parallel.py b/tests/models/test_gpt_parallel.py
index aa7de6d..bd91a1f 100644
--- a/tests/models/test_gpt_parallel.py
+++ b/tests/models/test_gpt_parallel.py
@@ -12,7 +12,7 @@
 from transformers import GPT2Config
 
 from apex.transformer import parallel_state
 
-from flash_attn.models.gpt import GPTLMHeadModel
+from flash_attn.models.gpt import GPTLMHeadModel, shard_state_dict_tp
 from flash_attn.losses.cross_entropy import CrossEntropyLoss
 from flash_attn.utils.distributed import allreduce_sequence_parallel_grad
@@ -22,11 +22,11 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 @pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])
 @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
-# @pytest.mark.parametrize('world_size', [1])
+# @pytest.mark.parametrize('world_size', [2])
 @pytest.mark.parametrize('has_pos_emb', [True, False])
 # @pytest.mark.parametrize('has_pos_emb', [True])
 @pytest.mark.parametrize('dim', [1024])
-def test_block_parallel(dim, has_pos_emb, world_size, dtype):
+def test_gpt_parallel(dim, has_pos_emb, world_size, dtype):
     head_dim = 64
     assert dim % head_dim == 0
     num_heads = dim // head_dim
@@ -91,45 +91,8 @@ def test_block_parallel(dim, has_pos_emb, world_size, dtype):
     partition_dim = dim // world_size
     partition_hidden_dim = 4 * dim // world_size
     with torch.no_grad():
-        model.transformer.embeddings.word_embeddings.weight.copy_(
-            model_pt.transformer.embeddings.word_embeddings.weight[rank * partition_vocab_size:(rank + 1) * partition_vocab_size]
-        )
-        if has_pos_emb:
-            model.transformer.embeddings.position_embeddings.weight.copy_(
-                model_pt.transformer.embeddings.position_embeddings.weight[:, rank * partition_dim:(rank + 1) * partition_dim]
-            )
-        model.transformer.ln_0.weight.copy_(model_pt.transformer.ln_0.weight)
-        model.transformer.ln_0.bias.copy_(model_pt.transformer.ln_0.bias)
-        for i in range(num_layers):
-            model.transformer.layers[i].mixer.Wqkv.weight.copy_(
-                rearrange(rearrange(model_pt.transformer.layers[i].mixer.Wqkv.weight, '(three o) i -> three o i', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
-                          'three o i -> (three o) i')
-            )
-            model.transformer.layers[i].mixer.Wqkv.bias.copy_(
-                rearrange(rearrange(model_pt.transformer.layers[i].mixer.Wqkv.bias, '(three o) -> three o', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
-                          'three o -> (three o)')
-            )
-            model.transformer.layers[i].mixer.out_proj.weight.copy_(
-                model_pt.transformer.layers[i].mixer.out_proj.weight[:, rank * partition_dim:(rank + 1) * partition_dim]
-            )
-            if rank == 0:
-                model.transformer.layers[i].mixer.out_proj.bias.copy_(model_pt.transformer.layers[i].mixer.out_proj.bias)
-            model.transformer.layers[i].mlp.fc1.weight.copy_(
-                model_pt.transformer.layers[i].mlp.fc1.weight[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim]
-            )
-            model.transformer.layers[i].mlp.fc1.bias.copy_(
-                model_pt.transformer.layers[i].mlp.fc1.bias[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim]
-            )
-            model.transformer.layers[i].mlp.fc2.weight.copy_(
-                model_pt.transformer.layers[i].mlp.fc2.weight[:, rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim]
-            )
-            if rank == 0:
-                model.transformer.layers[i].mlp.fc2.bias.copy_(model_pt.transformer.layers[i].mlp.fc2.bias)
-            model.transformer.layers[i].norm1.weight.copy_(model_pt.transformer.layers[i].norm1.weight)
-            model.transformer.layers[i].norm1.bias.copy_(model_pt.transformer.layers[i].norm1.bias)
-            model.transformer.layers[i].norm2.weight.copy_(model_pt.transformer.layers[i].norm2.weight)
-            model.transformer.layers[i].norm2.bias.copy_(model_pt.transformer.layers[i].norm2.bias)
-        # Don't need to copy the lm_head weight since it's tied to the word embedding weight
+        model.load_state_dict(shard_state_dict_tp(model_pt.state_dict(), config, world_size, rank))
+        model.tie_weights()
 
     with torch.autocast(device_type='cuda', dtype=dtype):
         out = model(input_ids[:, :-1]).logits
@@ -150,62 +113,73 @@
     allreduce_sequence_parallel_grad(model, process_group)
     parallel_state.destroy_model_parallel()
 
+    grad_dict = shard_state_dict_tp({k: v.grad for k, v in model_pt.named_parameters()},
+                                    config, world_size, rank)
+
     assert torch.allclose(
         model.transformer.embeddings.word_embeddings.weight.grad,
-        model_pt.transformer.embeddings.word_embeddings.weight.grad[rank * partition_vocab_size:(rank + 1) * partition_vocab_size],
+        grad_dict['transformer.embeddings.word_embeddings.weight'],
         rtol=rtol, atol=atol * 5
     )
     if has_pos_emb:
        assert torch.allclose(
            model.transformer.embeddings.position_embeddings.weight.grad,
-           model_pt.transformer.embeddings.position_embeddings.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim],
+           grad_dict['transformer.embeddings.position_embeddings.weight'],
            rtol=rtol, atol=atol
        )
-    assert torch.allclose(model.transformer.ln_0.weight.grad, model_pt.transformer.ln_0.weight.grad,
+    assert torch.allclose(model.transformer.ln_0.weight.grad, grad_dict['transformer.ln_0.weight'],
                           rtol=rtol, atol=atol)
-    assert torch.allclose(model.transformer.ln_0.bias.grad, model_pt.transformer.ln_0.bias.grad,
+    assert torch.allclose(model.transformer.ln_0.bias.grad, grad_dict['transformer.ln_0.bias'],
                           rtol=rtol, atol=atol)
     for i in range(num_layers):
-        # if rank == 0: breakpoint()
-        # torch.distributed.barrier()
         assert torch.allclose(
             model.transformer.layers[i].mixer.Wqkv.weight.grad,
-            rearrange(rearrange(model_pt.transformer.layers[i].mixer.Wqkv.weight.grad, '(three o) i -> three o i', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim], 'three o i -> (three o) i'),
+            grad_dict[f'transformer.layers.{i}.mixer.Wqkv.weight'],
             rtol=rtol, atol=atol * 10
         )
         assert torch.allclose(
             model.transformer.layers[i].mixer.Wqkv.bias.grad,
-            rearrange(rearrange(model_pt.transformer.layers[i].mixer.Wqkv.bias.grad, '(three o) -> three o', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
-                      'three o -> (three o)'),
+            grad_dict[f'transformer.layers.{i}.mixer.Wqkv.bias'],
             rtol=rtol, atol=atol * 10
         )
         assert torch.allclose(
             model.transformer.layers[i].mixer.out_proj.weight.grad,
-            model_pt.transformer.layers[i].mixer.out_proj.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim],
+            grad_dict[f'transformer.layers.{i}.mixer.out_proj.weight'],
             rtol=rtol, atol=atol * 10
         )
         if rank == 0:
-            assert torch.allclose(model.transformer.layers[i].mixer.out_proj.bias.grad, model_pt.transformer.layers[i].mixer.out_proj.bias.grad, rtol=rtol, atol=atol * 5)
+            assert torch.allclose(model.transformer.layers[i].mixer.out_proj.bias.grad,
+                                  grad_dict[f'transformer.layers.{i}.mixer.out_proj.bias'],
+                                  rtol=rtol, atol=atol * 5)
         assert torch.allclose(
             model.transformer.layers[i].mlp.fc1.weight.grad,
-            model_pt.transformer.layers[i].mlp.fc1.weight.grad[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim],
+            grad_dict[f'transformer.layers.{i}.mlp.fc1.weight'],
            rtol=rtol, atol=atol * 10
        )
        assert torch.allclose(
            model.transformer.layers[i].mlp.fc1.bias.grad,
-           model_pt.transformer.layers[i].mlp.fc1.bias.grad[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim],
+           grad_dict[f'transformer.layers.{i}.mlp.fc1.bias'],
            rtol=rtol, atol=atol * 10
        )
        assert torch.allclose(
            model.transformer.layers[i].mlp.fc2.weight.grad,
-           model_pt.transformer.layers[i].mlp.fc2.weight.grad[:, rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim],
+           grad_dict[f'transformer.layers.{i}.mlp.fc2.weight'],
            rtol=rtol, atol=atol * 10
        )
        if rank == 0:
-           assert torch.allclose(model.transformer.layers[i].mlp.fc2.bias.grad, model_pt.transformer.layers[i].mlp.fc2.bias.grad,
-                                 rtol=rtol, atol=atol * 5)
+           assert torch.allclose(model.transformer.layers[i].mlp.fc2.bias.grad,
+                                 grad_dict[f'transformer.layers.{i}.mlp.fc2.bias'],
+                                 rtol=rtol, atol=atol * 5)
-        assert torch.allclose(model.transformer.layers[i].norm1.weight.grad, model_pt.transformer.layers[i].norm1.weight.grad, rtol=rtol, atol=atol)
-        assert torch.allclose(model.transformer.layers[i].norm1.bias.grad, model_pt.transformer.layers[i].norm1.bias.grad, rtol=rtol, atol=atol)
-        assert torch.allclose(model.transformer.layers[i].norm2.weight.grad, model_pt.transformer.layers[i].norm2.weight.grad, rtol=rtol, atol=atol)
-        assert torch.allclose(model.transformer.layers[i].norm2.bias.grad, model_pt.transformer.layers[i].norm2.bias.grad, rtol=rtol, atol=atol)
+        assert torch.allclose(model.transformer.layers[i].norm1.weight.grad,
+                              grad_dict[f'transformer.layers.{i}.norm1.weight'],
+                              rtol=rtol, atol=atol)
+        assert torch.allclose(model.transformer.layers[i].norm1.bias.grad,
+                              grad_dict[f'transformer.layers.{i}.norm1.bias'],
+                              rtol=rtol, atol=atol)
+        assert torch.allclose(model.transformer.layers[i].norm2.weight.grad,
+                              grad_dict[f'transformer.layers.{i}.norm2.weight'],
+                              rtol=rtol, atol=atol)
+        assert torch.allclose(model.transformer.layers[i].norm2.bias.grad,
+                              grad_dict[f'transformer.layers.{i}.norm2.bias'],
+                              rtol=rtol, atol=atol)
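
Usage note (not part of the patch): a minimal sketch of how the new shard_state_dict_tp helper is driven, following the same pattern as test_gpt_parallel above. Each rank shards the full model's state_dict down to its tensor-parallel slice, loads it, and re-ties the embedding and LM-head weights. The torch.distributed/apex setup and the GPTLMHeadModel constructor arguments below are assumptions based on the surrounding test rather than anything this diff shows.

# Sketch only. Assumes a single-node launch with one GPU per rank (e.g. via torchrun)
# and that GPTLMHeadModel accepts a process_group argument for tensor parallelism
# (assumed signature, not shown in this diff).
import torch
from apex.transformer import parallel_state
from transformers import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel, shard_state_dict_tp

torch.distributed.init_process_group(backend='nccl', init_method='env://')
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
torch.cuda.set_device(rank)  # assumes one GPU per rank on a single node
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
process_group = parallel_state.get_tensor_model_parallel_group()

# pad_vocab_size_multiple is an extra attribute that shard_state_dict_tp reads off the config.
config = GPT2Config(n_embd=1024, n_head=16, n_layer=4, vocab_size=50257,
                    pad_vocab_size_multiple=8)

# Full (unsharded) reference model, e.g. built from a regular checkpoint.
model_pt = GPTLMHeadModel(config).cuda()

# Tensor-parallel model: each rank keeps only its shard of the weights.
model = GPTLMHeadModel(config, process_group=process_group).cuda()  # assumed signature
model.load_state_dict(shard_state_dict_tp(model_pt.state_dict(), config, world_size, rank))
model.tie_weights()  # re-tie lm_head to the sharded word embeddings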