[OPT] Load fp16 weights on CPU before moving to GPU

Tri Dao 2023-01-22 17:01:32 -08:00
parent 33e0860c9c
commit 78b7a1dc18
6 changed files with 27 additions and 12 deletions

View File

@@ -166,9 +166,10 @@ class GPTPreTrainedModel(nn.Module):
        """
        # Instantiate model.
        model = cls(config, *args, device=device, dtype=dtype, **kwargs)
-        # If we're going to shard the model, then don't load fp32 weights to GPU.
+        # Load state_dict in cpu because we already initialized the model in GPU, and we don't
+        # want extra stuff taking up more GPU memory
        state_dict = state_dict_from_pretrained(
-            model_name, device=device if world_size == 1 else None, dtype=dtype
+            model_name, device='cpu', dtype=dtype
        )
        if model_name.startswith('gpt2'):
            state_dict = remap_state_dict_gpt2(state_dict, config)
@@ -178,7 +179,6 @@ class GPTPreTrainedModel(nn.Module):
            raise NotImplementedError(f'Model {model_name} not supported')
        if world_size > 1:
            state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
-        state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
        load_return = model.load_state_dict(state_dict, strict=strict)
        logger.info(load_return)
        return model
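
The net effect of the two hunks above: the checkpoint is materialized on CPU, optionally sharded for tensor parallelism, and only copied onto the GPU when `load_state_dict` writes into the already-initialized GPU model. A minimal sketch of that pattern (the helper name and arguments here are illustrative, not the flash-attn API):

```python
import torch
import torch.nn as nn

def load_checkpoint_cpu_first(model: nn.Module, ckpt_path: str,
                              dtype: torch.dtype = torch.float16):
    # Keep the checkpoint on CPU so it never sits in GPU memory next to the
    # model that was already constructed on the GPU.
    state_dict = torch.load(ckpt_path, map_location='cpu')
    # Convert dtype while still on CPU; load_state_dict then copies each
    # CPU tensor into the corresponding GPU parameter (copy_ works across devices).
    state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
    return model.load_state_dict(state_dict, strict=True)
```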

View File

@@ -43,6 +43,8 @@ def remap_state_dict_opt(state_dict, config):
    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
+        # The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm'
+        key = re.sub(r'^transformer.layer_norm.', r'transformer.ln_f.', key)
        key = re.sub(r'^transformer.layers.(\d+).self_attn_layer_norm.',
                     r'transformer.layers.\1.norm1.', key)
        key = re.sub(r'^transformer.layers.(\d+).final_layer_norm.',
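
As a quick illustration of the added mapping: by this point in `remap_state_dict_opt` the checkpoint keys have presumably already had their `decoder.` prefix renamed to `transformer.`, so both spellings of the final LayerNorm end up at `transformer.ln_f`. A standalone sketch (regexes copied from the hunk; the surrounding function is simplified):

```python
import re

def map_final_ln(key: str) -> str:
    # Most OPT checkpoints name the final LayerNorm 'final_layer_norm';
    # the OPT-175B checkpoint names it just 'layer_norm'.
    key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
    key = re.sub(r'^transformer.layer_norm.', r'transformer.ln_f.', key)
    return key

assert map_final_ln('transformer.final_layer_norm.weight') == 'transformer.ln_f.weight'
assert map_final_ln('transformer.layer_norm.weight') == 'transformer.ln_f.weight'
```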

View File

@@ -196,7 +196,7 @@ class DecodingCGCache:
@torch.inference_mode()
def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1,
-                       dtype=None):
+                       dtype=None, n_warmups=2):
    if cache is None:
        cache = DecodingCGCache()
    param_example = next(iter(model.parameters()))
@@ -228,7 +228,8 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
        if s_type not in cache.callables:
            seqlen = min(max(seqlen_og, seqlen_type_to_seqlen(s_type)), max_seqlen)
            cache.callables[s_type] = capture_graph(
-                model, cache.inference_params, batch_size, seqlen_og, seqlen, mempool=cache.mempool
+                model, cache.inference_params, batch_size, seqlen_og, seqlen, mempool=cache.mempool,
+                n_warmups=n_warmups
            )

    def dispatch(input_ids, position_ids, seqlen):
@@ -239,7 +240,8 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
    return cache


-def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, mempool=None):
+def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, mempool=None,
+                  n_warmups=2):
    assert max_seqlen >= seqlen_og
    device = next(iter(model.parameters())).device
    input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
@@ -250,10 +252,15 @@ def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, me
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
-        for _ in range(2):
+        for _ in range(n_warmups):
            logits = model(input_ids, position_ids=position_ids,
                           inference_params=inference_params).logits[:, -1]
        s.synchronize()
+        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
+        # which requires that graph launch and non-captured launch not overlap (I think
+        # that's how I interpret the documentation). I'm not sure if this is required.
+        if torch.distributed.is_initialized():
+            torch.distributed.barrier()
    torch.cuda.current_stream().wait_stream(s)
    # Captures the graph
    # To allow capture, automatically sets a side stream as the current stream in the context
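
For readers unfamiliar with the warmup step that `n_warmups` now controls: CUDA graph capture requires all lazy initialization (memory pools, cuBLAS handles, kernel autotuning) to happen before capture, which is why the forward pass is run a few times on a side stream first. Below is a minimal, self-contained sketch of that warmup-then-capture pattern using a generic model and input rather than the flash-attn decoding setup:

```python
import torch

def capture_forward(model, example_input, n_warmups=2, mempool=None):
    # Warm up on a side stream so lazy initialization happens outside capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            out = model(example_input)
        s.synchronize()
    torch.cuda.current_stream().wait_stream(s)

    # Capture one forward pass; static_input / static_output are the fixed
    # buffers callers must reuse when replaying the graph.
    graph = torch.cuda.CUDAGraph()
    static_input = example_input.clone()
    with torch.cuda.graph(graph, pool=mempool):
        static_output = model(static_input)

    def run(new_input):
        static_input.copy_(new_input)
        graph.replay()
        return static_output

    return run
```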

View File

@@ -7,6 +7,8 @@ from transformers.utils.hub import cached_file, get_checkpoint_shard_files


def state_dict_from_pretrained(model_name, device=None, dtype=None):
+    # If not fp32, then we don't want to load directly to the GPU
+    mapped_device = 'cpu' if dtype not in [torch.float32, None] else device
    is_sharded = False
    resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
                                        _raise_exceptions_for_missing_entries=False)
@@ -25,9 +27,11 @@ def state_dict_from_pretrained(model_name, device=None, dtype=None):
        )
        state_dict = {}
        for sharded_file in resolved_archive_file:
-            state_dict.update(torch.load(sharded_file, map_location=device))
+            state_dict.update(torch.load(sharded_file, map_location=mapped_device))
    else:
        state_dict = torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
+    # Convert dtype before moving to GPU to save memory
    if dtype is not None:
-        state_dict = {k: v.to(dtype) for k, v in state_dict.items()}
+        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
+    state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
    return state_dict
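
The comment "Convert dtype before moving to GPU to save memory" is the core of this change: with an fp32 checkpoint and an fp16 target dtype, casting on CPU first means only the half-precision tensors ever occupy GPU memory. A rough standalone illustration (the path and tensor sizes are made up):

```python
import torch

# Pretend this is a fp32 checkpoint on disk.
torch.save({f'layer{i}.weight': torch.randn(1024, 1024) for i in range(4)}, 'ckpt.pt')

# Old behaviour: map the full fp32 state dict straight into GPU memory,
# alongside the model that already lives there.
sd_fp32_on_gpu = torch.load('ckpt.pt', map_location='cuda')

# New behaviour for non-fp32 dtypes: keep the checkpoint on CPU, cast there,
# and only then move (or let load_state_dict copy) the fp16 tensors to the GPU.
sd_cpu = torch.load('ckpt.pt', map_location='cpu')
sd_fp16_on_gpu = {k: v.to(dtype=torch.float16).to(device='cuda') for k, v in sd_cpu.items()}
```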

View File

@@ -114,7 +114,7 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
@pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b", "facebook/opt-2.7b", "facebook/opt-6.7b"])
# @pytest.mark.parametrize('model_name', ["facebook/opt-6.7b"])
# @pytest.mark.parametrize('model_name', ["facebook/opt-125m"])
def test_greedy_decode_opt(model_name):
    """Check that our implementation of OPT generation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
@@ -145,7 +145,7 @@ def test_greedy_decode_opt(model_name):
    input_ids = tokenizer("Hello, my dog is cute and",
                          return_tensors="pt").input_ids.to(device=device)
-    max_length = 30
+    max_length = 60
    # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
    # max_length = input_ids.shape[1] + 40
@@ -192,7 +192,7 @@ def test_greedy_decode_opt(model_name):
    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
    if verbose:
        print(out_cg.sequences)
-        print(tokenizer.batch_decode(out.sequences.tolist()))
+        print(tokenizer.batch_decode(out_cg.sequences.tolist()))
    del model

View File

@@ -129,3 +129,5 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
    assert torch.all(out.sequences == out_hf.sequences)
    assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
    parallel_state.destroy_model_parallel()
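
The tolerance idiom in the assertion above is worth spelling out: rather than using an absolute threshold, the fp16 tensor-parallel scores are compared against a fp32 reference and only required to stay within a small multiple of the error that HF's own fp16 implementation shows against the same reference. A sketch of that check as a helper (not part of the repo, names are illustrative):

```python
import torch

def close_enough(scores, scores_hf, scores_ref, factor=3.0):
    # scores, scores_hf: fp16 outputs from our implementation and from HF;
    # scores_ref: fp32 reference. Require our worst-case deviation from the
    # reference to be within `factor` times HF's own deviation.
    our_err = (scores - scores_ref).abs().max().item()
    hf_err = (scores_hf - scores_ref).abs().max().item()
    return our_err < factor * hf_err
```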