From 184b992dcb2a0890adaa19eb9b541c3e4f9d2a08 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 28 Jul 2023 15:52:48 -1000 Subject: [PATCH] [GPT] Implement parallel LLaMa --- flash_attn/models/gpt.py | 17 +++- tests/models/test_falcon.py | 2 - tests/models/test_llama.py | 190 +++++++++++++++++++++++++++++------- 3 files changed, 171 insertions(+), 38 deletions(-) diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py index 745b5d5..56d76ff 100644 --- a/flash_attn/models/gpt.py +++ b/flash_attn/models/gpt.py @@ -527,6 +527,15 @@ def shard_state_dict_tp(state_dict, config, world_size, rank): dim = x.shape[-1] // world_size state_dict[key] = x[..., rank * dim:(rank + 1) * dim] + def shard_gatedmlp_fc1_dim(state_dict, key): + if key in state_dict: + x = state_dict[key] + dim = x.shape[0] // world_size // 2 + state_dict[key] = rearrange( + rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim:(rank + 1) * dim], + "two o ... -> (two o) ..." + ) + def shard_qkv_headdim(state_dict, key): if key in state_dict: n_head = config.n_head @@ -559,8 +568,12 @@ def shard_state_dict_tp(state_dict, config, world_size, rank): shard_last_dim(state_dict, f'transformer.layers.{i}.mixer.out_proj.weight') if rank != 0: state_dict.pop(f'transformer.layers.{i}.mixer.out_proj.bias', None) - shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight') - shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias') + if config.activation_function in ["glu", "swiglu", "geglu"]: + shard_gatedmlp_fc1_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight') + shard_gatedmlp_fc1_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias') + else: + shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight') + shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias') shard_last_dim(state_dict, f'transformer.layers.{i}.mlp.fc2.weight') if rank != 0: state_dict.pop(f'transformer.layers.{i}.mlp.fc2.bias', None) diff --git a/tests/models/test_falcon.py b/tests/models/test_falcon.py index edcd93b..66b97b9 100644 --- a/tests/models/test_falcon.py +++ b/tests/models/test_falcon.py @@ -300,8 +300,6 @@ def test_falcon_parallel_generation(model_name, world_size): input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device) - torch.distributed.barrier() - # Need this, otherwise when we capture the graph the process for GPU 1 would run on both # GPU0 and GPU1 and things would hang torch.cuda.set_device(device) diff --git a/tests/models/test_llama.py b/tests/models/test_llama.py index 36807a8..739e9f4 100644 --- a/tests/models/test_llama.py +++ b/tests/models/test_llama.py @@ -13,12 +13,15 @@ current_dir = Path(__file__).parent.absolute() import torch import pytest +from einops import rearrange + from transformers import LlamaConfig, LlamaTokenizer from transformers.models.llama.modeling_llama import LlamaForCausalLM -from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp +from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp from flash_attn.models.llama import remap_state_dict_meta_llama, llama_config_to_gpt2_config from flash_attn.models.llama import config_from_checkpoint, state_dicts_from_checkpoint +from flash_attn.utils.distributed import all_gather_raw from flash_attn.utils.pretrained import state_dict_from_pretrained from flash_attn.utils.generation import update_graph_cache @@ -38,6 +41,7 @@ def test_llama_state_dict(model_name): @pytest.mark.parametrize('model_name', ["7B", "13B"]) +# @pytest.mark.parametrize('model_name', ["7B"]) def test_llama_optimized(model_name): """Check that our implementation of LLaMa (with all optimizations enabled) matches the HF implementation: the output of our forward pass in fp16 should be around the same as the HF @@ -59,7 +63,7 @@ def test_llama_optimized(model_name): pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts] pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config) model = GPTLMHeadModel(config, device=device, dtype=dtype) - model.load_state_dict(pretrained_state_dict, strict=False) + model.load_state_dict(pretrained_state_dict) model.eval() torch.manual_seed(0) @@ -86,8 +90,9 @@ def test_llama_optimized(model_name): model_hf = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf', torch_dtype=dtype, device_map={"": device}) model_hf.eval() - out_hf = model_hf.model(input_ids).last_hidden_state - logits_hf = model_hf(input_ids).logits + with torch.no_grad(): + out_hf = model_hf.model(input_ids).last_hidden_state + logits_hf = model_hf(input_ids).logits del model_hf print(f'Output max diff: {(out - out_ref).abs().max().item()}') @@ -104,7 +109,6 @@ def test_llama_optimized(model_name): # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel" -@pytest.mark.skip(reason="Tensor Parallel is not implemented for GatedMLP yet") @pytest.mark.parametrize('world_size', [2]) @pytest.mark.parametrize('model_name', ["13B"]) def test_llama_parallel(model_name, world_size): @@ -118,7 +122,6 @@ def test_llama_parallel(model_name, world_size): current_dir.parent.parent / 'checkpoints')) / 'llama' dtype = torch.float16 - device = 'cuda' config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name)) config.use_flash_attn = True config.fused_bias_fc = True @@ -139,8 +142,7 @@ def test_llama_parallel(model_name, world_size): pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config) model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype) - model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank), - strict=False) + model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank)) model.eval() torch.manual_seed(0) @@ -151,39 +153,49 @@ def test_llama_parallel(model_name, world_size): device=device) with torch.no_grad(): out = model.transformer(input_ids) + out, _ = all_gather_raw(out, process_group=process_group) + out = rearrange(out, "(b s) d -> b s d", b=batch_size) logits = model(input_ids).logits + logits = rearrange(logits, "(b s) d -> b s d", b=batch_size) + logits, _ = all_gather_raw(logits, process_group) + logits = rearrange(logits, '(n b) ... d -> b ... (n d)', b=batch_size) del model - # Without device_map, the model is loaded on the CPU, which is very slow - model_ref = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf', - device_map='auto') - model_ref.eval() - with torch.no_grad(): - out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device) - logits_ref = model_ref(input_ids).logits.to(device=device) - del model_ref + if rank == 0: + # Without device_map, the model is loaded on the CPU, which is very slow + model_ref = LlamaForCausalLM.from_pretrained( + Path(checkpoint_path) / f'{model_name}-hf', device_map="auto" + ) + model_ref.eval() + with torch.no_grad(): + out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device) + logits_ref = model_ref(input_ids).logits.to(device=device) + del model_ref - model_hf = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf', - torch_dtype=dtype, device_map="auto") - model_hf.eval() - out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device) - logits_hf = model_hf(input_ids).logits.to(device=device) - del model_hf + model_hf = LlamaForCausalLM.from_pretrained( + Path(checkpoint_path) / f'{model_name}-hf', torch_dtype=dtype, device_map="auto" + ) + model_hf.eval() + with torch.no_grad(): + out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device) + logits_hf = model_hf(input_ids).logits.to(device=device) + del model_hf - print(f'Output max diff: {(out - out_ref).abs().max().item()}') - print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') - print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}') - print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}') - assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item() + print(f'Output max diff: {(out - out_ref).abs().max().item()}') + print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') + print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}') + print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}') + assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item() - print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}') - print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}') - print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}') - print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}') - assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item() + print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}') + print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}') + print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}') + print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}') + assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item() -@pytest.mark.parametrize('model_name', ["7B", "13B"]) +# @pytest.mark.parametrize('model_name', ["7B", "13B"]) +@pytest.mark.parametrize('model_name', ["7B"]) def test_llama_generation(model_name): checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR', current_dir.parent.parent / 'checkpoints')) / 'llama' @@ -231,7 +243,7 @@ def test_llama_generation(model_name): pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts] pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config) model = GPTLMHeadModel(config, device=device, dtype=dtype) - model.load_state_dict(pretrained_state_dict, strict=False) + model.load_state_dict(pretrained_state_dict) model.eval() print('Without CUDA graph') @@ -274,3 +286,113 @@ def test_llama_generation(model_name): assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error assert (logits - logits_ref).abs().max().item() < 2 * hf_error assert torch.equal(logits_cg, logits) + + +# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "llama_parallel_generation" +@pytest.mark.parametrize('world_size', [2]) +@pytest.mark.parametrize('model_name', ["13B"]) +def test_llama_parallel_generation(model_name, world_size): + """Check that our implementation matches the HF implementation: + the scores in fp16 should be around the same as the HF scores in fp16, when compared to + the HF scores in fp32. + """ + from apex.transformer import parallel_state + + checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR', + current_dir.parent.parent / 'checkpoints')) / 'llama' + + dtype = torch.float16 + config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name)) + config.use_flash_attn = False + config.fused_bias_fc = True + config.fused_mlp = False # We don't have fused GatedMLP yet + config.fused_dropout_add_ln = False + config.residual_in_fp32 = True + config.pad_vocab_size_multiple = 8 * world_size + config.sequence_parallel = False # Need to set this to False for generation + + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend='nccl', init_method='env://') + device = f'cuda:{torch.distributed.get_rank()}' + assert world_size <= torch.distributed.get_world_size() + parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size) + rank = parallel_state.get_tensor_model_parallel_rank() + process_group = parallel_state.get_tensor_model_parallel_group() + + torch.manual_seed(0) + batch_size = 1 + seqlen = 100 + max_length = 150 + input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, + device=device) + + # Need this, otherwise when we capture the graph the process for GPU 1 would run on both + # GPU0 and GPU1 and things would hang + torch.cuda.set_device(device) + + ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name) + pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts] + pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config) + + model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype) + model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank)) + model.eval() + + print('Without CUDA graph') + out = model.generate( + input_ids=input_ids, max_length=max_length, tensor_parallel=world_size, + vocab_size=config.vocab_size, fused_ft_kernel=True, + # teacher_outputs=out_hf.sequences, + return_dict_in_generate=True, output_scores=True, timing=True + ) + + # Capture graph outside the timing loop + batch_size, seqlen_og = input_ids.shape + model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length) + print('With CUDA graph') + out_cg = model.generate( + input_ids=input_ids, max_length=max_length, tensor_parallel=world_size, + vocab_size=config.vocab_size, fused_ft_kernel=True, cg=True, + # teacher_outputs=out_hf.sequences, + return_dict_in_generate=True, output_scores=True, timing=True + ) + del model + parallel_state.destroy_model_parallel() + + if rank == 0: + # Without device_map, the model is loaded on the CPU, which is very slow + model_hf = LlamaForCausalLM.from_pretrained( + Path(checkpoint_path) / f'{model_name}-hf', torch_dtype=dtype, device_map="auto" + ) + model_hf.eval() + print("HF fp16") + torch.cuda.synchronize() + start = time.time() + with torch.inference_mode(): + out_hf = model_hf.generate( + input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, + output_scores=True + ) + torch.cuda.synchronize() + print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms') + del model_hf + + model_ref = LlamaForCausalLM.from_pretrained( + Path(checkpoint_path) / f'{model_name}-hf', device_map="auto" + ) + model_ref.eval() + with torch.inference_mode(): + logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1] + del model_ref + logits_hf = torch.stack(out_hf.scores, dim=1) + + logits = torch.stack(out.scores, dim=1) + logits_cg = torch.stack(out_cg.scores, dim=1) + + hf_error = (logits_hf - logits_ref).abs().max().item() + print(f'HF fp16 logits max diff: {hf_error}') + print(f'Logits max diff: {(logits - logits_ref).abs().max().item() }') + assert (logits - logits_ref).abs().max().item() < 2 * hf_error + print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }') + assert torch.equal(logits_cg, logits)