[OPT] Load fp16 weights on CPU before moving to GPU

This commit is contained in:
parent 33e0860c9c
commit 78b7a1dc18

@@ -166,9 +166,10 @@ class GPTPreTrainedModel(nn.Module):
         """
         # Instantiate model.
         model = cls(config, *args, device=device, dtype=dtype, **kwargs)
-        # If we're going to shard the model, then don't load fp32 weights to GPU.
+        # Load state_dict in cpu because we already initialized the model in GPU, and we don't
+        # want extra stuff taking up more GPU memory
         state_dict = state_dict_from_pretrained(
-            model_name, device=device if world_size == 1 else None, dtype=dtype
+            model_name, device='cpu', dtype=dtype
         )
         if model_name.startswith('gpt2'):
             state_dict = remap_state_dict_gpt2(state_dict, config)
@@ -178,7 +179,6 @@ class GPTPreTrainedModel(nn.Module):
             raise NotImplementedError(f'Model {model_name} not supported')
         if world_size > 1:
             state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
-            state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
         load_return = model.load_state_dict(state_dict, strict=strict)
         logger.info(load_return)
         return model
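For reference, a minimal sketch of the loading pattern these two hunks switch to: instantiate the model directly on the GPU in the target dtype, pull the checkpoint onto CPU, cast it there, and let load_state_dict copy tensor by tensor. TinyModel and the checkpoint path are hypothetical stand-ins, not part of this repository.

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self, device=None, dtype=None):
        super().__init__()
        self.proj = nn.Linear(1024, 1024, device=device, dtype=dtype)

    def forward(self, x):
        return self.proj(x)

def load_pretrained_sketch(ckpt_path, device='cuda', dtype=torch.float16):
    # 1) Instantiate the model directly on the GPU in the target dtype.
    model = TinyModel(device=device, dtype=dtype)
    # 2) Load the checkpoint onto CPU so it does not claim extra GPU memory.
    state_dict = torch.load(ckpt_path, map_location='cpu')
    # 3) Cast on CPU; load_state_dict then copies tensor by tensor into the
    #    already-allocated GPU parameters, so the full checkpoint never sits
    #    on the GPU next to the initialized model.
    state_dict = {k: v.to(dtype) for k, v in state_dict.items()}
    model.load_state_dict(state_dict, strict=True)
    return model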
@@ -43,6 +43,8 @@ def remap_state_dict_opt(state_dict, config):
     # LayerNorm
     def key_mapping_ln(key):
         key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
+        # The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm'
+        key = re.sub(r'^transformer.layer_norm.', r'transformer.ln_f.', key)
         key = re.sub(r'^transformer.layers.(\d+).self_attn_layer_norm.',
                      r'transformer.layers.\1.norm1.', key)
         key = re.sub(r'^transformer.layers.(\d+).final_layer_norm.',
@@ -196,7 +196,7 @@ class DecodingCGCache:

 @torch.inference_mode()
 def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1,
-                       dtype=None):
+                       dtype=None, n_warmups=2):
     if cache is None:
         cache = DecodingCGCache()
     param_example = next(iter(model.parameters()))
@@ -228,7 +228,8 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
         if s_type not in cache.callables:
             seqlen = min(max(seqlen_og, seqlen_type_to_seqlen(s_type)), max_seqlen)
             cache.callables[s_type] = capture_graph(
-                model, cache.inference_params, batch_size, seqlen_og, seqlen, mempool=cache.mempool
+                model, cache.inference_params, batch_size, seqlen_og, seqlen, mempool=cache.mempool,
+                n_warmups=n_warmups
             )

     def dispatch(input_ids, position_ids, seqlen):
@@ -239,7 +240,8 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
     return cache


-def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, mempool=None):
+def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, mempool=None,
+                  n_warmups=2):
     assert max_seqlen >= seqlen_og
     device = next(iter(model.parameters())).device
     input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
@@ -250,10 +252,15 @@ def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, me
     s = torch.cuda.Stream()
     s.wait_stream(torch.cuda.current_stream())
     with torch.cuda.stream(s):
-        for _ in range(2):
+        for _ in range(n_warmups):
             logits = model(input_ids, position_ids=position_ids,
                            inference_params=inference_params).logits[:, -1]
         s.synchronize()
+        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
+        # which requires that graph launch and non-captured launch to not overlap (I think,
+        # that's how I interpret the documentation). I'm not sure if this is required.
+        if torch.distributed.is_initialized():
+            torch.distributed.barrier()
     torch.cuda.current_stream().wait_stream(s)
     # Captures the graph
     # To allow capture, automatically sets a side stream as the current stream in the context
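As an aside, the warmup-then-capture pattern that the new n_warmups argument parameterizes can be sketched as below. This is a simplified stand-in built on the stock torch.cuda.CUDAGraph API, not the capture_graph implementation above; net, x, and the buffer handling are illustrative assumptions.

import torch

def capture_sketch(net, x, n_warmups=2):
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        # Warmup iterations on a side stream let lazy initialization
        # (cuBLAS/cuDNN workspaces, autotuning) happen before capture,
        # since new allocations during graph capture can fail.
        for _ in range(n_warmups):
            y = net(x)
        s.synchronize()
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        y = net(x)

    def run(new_x):
        x.copy_(new_x)   # write into the captured static input buffer
        graph.replay()
        return y         # read from the captured static output buffer
    return run

# Hypothetical usage:
# net = torch.nn.Linear(256, 256, device='cuda', dtype=torch.float16)
# x = torch.randn(8, 256, device='cuda', dtype=torch.float16)
# run = capture_sketch(net, x, n_warmups=2)
# out = run(torch.randn(8, 256, device='cuda', dtype=torch.float16))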
@@ -7,6 +7,8 @@ from transformers.utils.hub import cached_file, get_checkpoint_shard_files


 def state_dict_from_pretrained(model_name, device=None, dtype=None):
+    # If not fp32, then we don't want to load directly to the GPU
+    mapped_device = 'cpu' if dtype not in [torch.float32, None] else device
     is_sharded = False
     resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
                                         _raise_exceptions_for_missing_entries=False)
@@ -25,9 +27,11 @@ def state_dict_from_pretrained(model_name, device=None, dtype=None):
         )
         state_dict = {}
         for sharded_file in resolved_archive_file:
-            state_dict.update(torch.load(sharded_file, map_location=device))
+            state_dict.update(torch.load(sharded_file, map_location=mapped_device))
     else:
         state_dict = torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
+    # Convert dtype before moving to GPU to save memory
     if dtype is not None:
-        state_dict = {k: v.to(dtype) for k, v in state_dict.items()}
+        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
+    state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
     return state_dict
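The memory-saving idea behind the new mapped_device variable, restated as a standalone sketch that omits the Hub download and sharding logic of state_dict_from_pretrained (the function name and path argument are assumptions, not part of the diff):

import torch

def load_state_dict_sketch(path, device=None, dtype=None):
    # When the target dtype is not fp32, load to CPU first, cast there, and only
    # then move to the requested device, so the uncast (typically larger fp32)
    # weights never occupy GPU memory.
    mapped_device = 'cpu' if dtype not in [torch.float32, None] else device
    state_dict = torch.load(path, map_location=mapped_device)
    if dtype is not None:
        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
    state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
    return state_dict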
@@ -114,7 +114,7 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):


 @pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b", "facebook/opt-2.7b", "facebook/opt-6.7b"])
-# @pytest.mark.parametrize('model_name', ["facebook/opt-6.7b"])
+# @pytest.mark.parametrize('model_name', ["facebook/opt-125m"])
 def test_greedy_decode_opt(model_name):
     """Check that our implementation of OPT generation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
@@ -145,7 +145,7 @@ def test_greedy_decode_opt(model_name):

     input_ids = tokenizer("Hello, my dog is cute and",
                           return_tensors="pt").input_ids.to(device=device)
-    max_length = 30
+    max_length = 60
     # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
     # max_length = input_ids.shape[1] + 40
@@ -192,7 +192,7 @@ def test_greedy_decode_opt(model_name):
     print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
     if verbose:
         print(out_cg.sequences)
-        print(tokenizer.batch_decode(out.sequences.tolist()))
+        print(tokenizer.batch_decode(out_cg.sequences.tolist()))

     del model

@@ -129,3 +129,5 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
     assert torch.all(out.sequences == out_hf.sequences)

+    assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
+
     parallel_state.destroy_model_parallel()
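The added tolerance check accepts the tensor-parallel fp16 scores if their maximum deviation from the fp32 reference is within 3x the deviation of the HF fp16 scores. A small illustrative helper capturing the same check (the function name is an assumption, not part of the test file):

import torch

def scores_close_enough(scores, scores_hf, scores_ref, factor=3.0):
    # Maximum deviation of our fp16 scores from the fp32 reference...
    err = (torch.stack(scores, 1) - torch.stack(scores_ref, 1)).abs().max().item()
    # ...must stay within `factor` times the deviation of the HF fp16 scores.
    err_hf = (torch.stack(scores_hf, 1) - torch.stack(scores_ref, 1)).abs().max().item()
    return err < factor * err_hf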