diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py
index b85981d..032388b 100644
--- a/flash_attn/models/gpt.py
+++ b/flash_attn/models/gpt.py
@@ -166,9 +166,10 @@ class GPTPreTrainedModel(nn.Module):
         """
         # Instantiate model.
         model = cls(config, *args, device=device, dtype=dtype, **kwargs)
-        # If we're going to shard the model, then don't load fp32 weights to GPU.
+        # Load state_dict in cpu because we already initialized the model in GPU, and we don't
+        # want extra stuff taking up more GPU memory
         state_dict = state_dict_from_pretrained(
-            model_name, device=device if world_size == 1 else None, dtype=dtype
+            model_name, device='cpu', dtype=dtype
         )
         if model_name.startswith('gpt2'):
             state_dict = remap_state_dict_gpt2(state_dict, config)
@@ -178,7 +179,6 @@ class GPTPreTrainedModel(nn.Module):
             raise NotImplementedError(f'Model {model_name} not supported')
         if world_size > 1:
             state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
-            state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
         load_return = model.load_state_dict(state_dict, strict=strict)
         logger.info(load_return)
         return model
diff --git a/flash_attn/models/opt.py b/flash_attn/models/opt.py
index 88d7c52..79740cd 100644
--- a/flash_attn/models/opt.py
+++ b/flash_attn/models/opt.py
@@ -43,6 +43,8 @@ def remap_state_dict_opt(state_dict, config):
     # LayerNorm
     def key_mapping_ln(key):
         key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
+        # The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm'
+        key = re.sub(r'^transformer.layer_norm.', r'transformer.ln_f.', key)
         key = re.sub(r'^transformer.layers.(\d+).self_attn_layer_norm.',
                      r'transformer.layers.\1.norm1.', key)
         key = re.sub(r'^transformer.layers.(\d+).final_layer_norm.',
diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py
index f264f4c..6b043e2 100644
--- a/flash_attn/utils/generation.py
+++ b/flash_attn/utils/generation.py
@@ -196,7 +196,7 @@ class DecodingCGCache:
 
 @torch.inference_mode()
 def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1,
-                       dtype=None):
+                       dtype=None, n_warmups=2):
     if cache is None:
         cache = DecodingCGCache()
     param_example = next(iter(model.parameters()))
@@ -228,7 +228,8 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
         if s_type not in cache.callables:
             seqlen = min(max(seqlen_og, seqlen_type_to_seqlen(s_type)), max_seqlen)
             cache.callables[s_type] = capture_graph(
-                model, cache.inference_params, batch_size, seqlen_og, seqlen, mempool=cache.mempool
+                model, cache.inference_params, batch_size, seqlen_og, seqlen, mempool=cache.mempool,
+                n_warmups=n_warmups
             )
 
     def dispatch(input_ids, position_ids, seqlen):
@@ -239,7 +240,8 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
     return cache
 
 
-def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, mempool=None):
+def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, mempool=None,
+                  n_warmups=2):
     assert max_seqlen >= seqlen_og
     device = next(iter(model.parameters())).device
     input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
@@ -250,10 +252,15 @@ def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, me
     s = torch.cuda.Stream()
     s.wait_stream(torch.cuda.current_stream())
     with torch.cuda.stream(s):
-        for _ in range(2):
+        for _ in range(n_warmups):
             logits = model(input_ids, position_ids=position_ids,
                            inference_params=inference_params).logits[:, -1]
         s.synchronize()
+        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
+        # which requires that graph launch and non-captured launch to not overlap (I think,
+        # that's how I interpret the documentation). I'm not sure if this is required.
+        if torch.distributed.is_initialized():
+            torch.distributed.barrier()
     torch.cuda.current_stream().wait_stream(s)
     # Captures the graph
     # To allow capture, automatically sets a side stream as the current stream in the context
diff --git a/flash_attn/utils/pretrained.py b/flash_attn/utils/pretrained.py
index 4b170a3..c5b7459 100644
--- a/flash_attn/utils/pretrained.py
+++ b/flash_attn/utils/pretrained.py
@@ -7,6 +7,8 @@ from transformers.utils.hub import cached_file, get_checkpoint_shard_files
 
 
 def state_dict_from_pretrained(model_name, device=None, dtype=None):
+    # If not fp32, then we don't want to load directly to the GPU
+    mapped_device = 'cpu' if dtype not in [torch.float32, None] else device
     is_sharded = False
     resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
                                         _raise_exceptions_for_missing_entries=False)
@@ -25,9 +27,11 @@ def state_dict_from_pretrained(model_name, device=None, dtype=None):
         )
         state_dict = {}
         for sharded_file in resolved_archive_file:
-            state_dict.update(torch.load(sharded_file, map_location=device))
+            state_dict.update(torch.load(sharded_file, map_location=mapped_device))
     else:
         state_dict = torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
+    # Convert dtype before moving to GPU to save memory
     if dtype is not None:
-        state_dict = {k: v.to(dtype) for k, v in state_dict.items()}
+        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
+    state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
     return state_dict
diff --git a/tests/models/test_gpt_generation.py b/tests/models/test_gpt_generation.py
index a347387..42bac36 100644
--- a/tests/models/test_gpt_generation.py
+++ b/tests/models/test_gpt_generation.py
@@ -114,7 +114,7 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
 
 @pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b",
                                         "facebook/opt-2.7b", "facebook/opt-6.7b"])
-# @pytest.mark.parametrize('model_name', ["facebook/opt-6.7b"])
+# @pytest.mark.parametrize('model_name', ["facebook/opt-125m"])
 def test_greedy_decode_opt(model_name):
     """Check that our implementation of OPT generation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
@@ -145,7 +145,7 @@
 
     input_ids = tokenizer("Hello, my dog is cute and",
                           return_tensors="pt").input_ids.to(device=device)
-    max_length = 30
+    max_length = 60
     # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
     # max_length = input_ids.shape[1] + 40
 
@@ -192,7 +192,7 @@
     print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
     if verbose:
         print(out_cg.sequences)
-        print(tokenizer.batch_decode(out.sequences.tolist()))
+        print(tokenizer.batch_decode(out_cg.sequences.tolist()))
 
     del model
 
diff --git a/tests/models/test_gpt_generation_parallel.py b/tests/models/test_gpt_generation_parallel.py
index 5817a91..50130ad 100644
--- a/tests/models/test_gpt_generation_parallel.py
+++ b/tests/models/test_gpt_generation_parallel.py
@@ -129,3 +129,5 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
     assert torch.all(out.sequences == out_hf.sequences)
     assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (
         torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
+
+    parallel_state.destroy_model_parallel()
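Reviewer note (not part of the patch): the flash_attn/utils/pretrained.py change amounts to loading the checkpoint on CPU, converting its dtype there, and only then moving the tensors to the requested device, so the full-precision copy never sits in GPU memory. A minimal standalone sketch of that pattern; the helper name load_checkpoint and its arguments are hypothetical, not part of the library:

    import torch

    def load_checkpoint(path, device='cuda', dtype=torch.float16):
        # Keep the (possibly fp32) checkpoint on CPU while loading.
        state_dict = torch.load(path, map_location='cpu')
        # Convert dtype on CPU first; the converted tensors are smaller,
        # so moving them to the GPU afterwards needs less peak GPU memory.
        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
        state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
        return state_dict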
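Reviewer note (not part of the patch): the generation.py change makes the number of warmup iterations before CUDA graph capture configurable and, when torch.distributed is initialized, adds a barrier so the warmup's collective launches are drained before capture starts. A minimal sketch of that warmup-then-capture pattern, assuming a single-decoding-step callable named step (hypothetical; the real capture_graph above takes the model and decoding buffers instead):

    import torch

    def capture(step, n_warmups=2):
        # Run the workload a few times on a side stream so lazy initialization
        # (cuDNN algorithm selection, NCCL communicators, ...) happens outside capture.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(n_warmups):
                step()
            s.synchronize()
            # Avoid overlapping graph capture with non-captured collective launches
            # (relevant when running with NCCL_GRAPH_MIXING_SUPPORT=0).
            if torch.distributed.is_initialized():
                torch.distributed.barrier()
        torch.cuda.current_stream().wait_stream(s)
        # Capture one step into a CUDA graph that can be replayed later.
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            step()
        return graph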