[GPT] Implement Falcon

This commit is contained in:
Tri Dao 2023-07-23 10:29:23 -07:00
parent 684196b8c5
commit d38357dd2f
3 changed files with 495 additions and 0 deletions

122
flash_attn/models/falcon.py Normal file
View File

@ -0,0 +1,122 @@
# Copyright (c) 2023, Tri Dao.
import math
import re
from collections import OrderedDict
import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import GPT2Config, FalconConfig
def remap_state_dict_hf_falcon(state_dict, config):
def key_mapping_layers(key):
return re.sub(r'^transformer.h.', 'transformer.layers.', key)
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
# Word embedding
def key_mapping_emb(key):
return re.sub(r'^transformer.word_embeddings.', 'transformer.embeddings.word_embeddings.', key)
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
if getattr(config, 'tie_word_embeddings'):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
else:
output_embeddings = state_dict.pop('lm_head.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict['lm_head.weight'] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
)
output_embeddings_bias = state_dict.pop('lm_head.bias')
state_dict['lm_head.bias'] = F.pad(
output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
)
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^transformer.layers.(\d+).input_layernorm.',
r'transformer.layers.\1.norm1.', key)
key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.',
r'transformer.layers.\1.norm2.', key)
key = re.sub(r'^transformer.layers.(\d+).ln_attn.', r'transformer.layers.\1.norm1.', key)
key = re.sub(r'^transformer.layers.(\d+).ln_mlp.', r'transformer.layers.\1.norm2.', key)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
def key_mapping_mlp(key):
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.',
r'transformer.layers.\1.mlp.fc1.', key)
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.',
r'transformer.layers.\1.mlp.fc2.', key)
return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
def key_mapping_attn(key):
key = re.sub(r'^transformer.layers.(\d+).self_attention.query_key_value.',
r'transformer.layers.\1.mixer.Wqkv.', key)
key = re.sub(r'^transformer.layers.(\d+).self_attention.dense.',
r'transformer.layers.\1.mixer.out_proj.', key)
return key
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
n_head = config.n_head
n_head_kv = getattr(config, "n_head_kv", 1)
headdim = config.hidden_size // n_head
for l in range(config.n_layer):
# The weights are stored in a different layout compared to our implementation
Wqkv = rearrange(state_dict.pop(f'transformer.layers.{l}.mixer.Wqkv.weight'),
"(group ratio headdim) ... -> group ratio headdim ...",
ratio=n_head // n_head_kv + 2, headdim=headdim)
Wq = rearrange(Wqkv[:, :-2], "group ratio headdim ... -> (group ratio headdim) ...")
Wk = rearrange(Wqkv[:, [-2]], "group ratio headdim ... -> (group ratio headdim) ...")
Wv = rearrange(Wqkv[:, [-1]], "group ratio headdim ... -> (group ratio headdim) ...")
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
return state_dict
def falcon_config_to_gpt2_config(falcon_config: FalconConfig) -> GPT2Config:
# The 40b config uses "n_head_kv" instead of "num_kv_heads"
n_head_kv = getattr(falcon_config, "n_head_kv",
1 if getattr(falcon_config, "multi_query", False)
else falcon_config.n_head)
# HACK: the 40b config has 2 LN per layer instead of 1, but that's not reflected in the config.
# So we have to infer it from the number of heads in the key/value block
parallel_block_tied_norm = n_head_kv == 1
return GPT2Config(
vocab_size=falcon_config.vocab_size,
n_positions=0, # No absolute position embedding
n_embd=falcon_config.hidden_size,
n_layer=falcon_config.n_layer,
n_head=falcon_config.n_head,
n_inner=falcon_config.hidden_size * 4,
activation_function="gelu",
resid_pdrop=falcon_config.hidden_dropout,
embd_pdrop=0.0, # There doesn't seem to be any embedding dropout
attn_pdrop=falcon_config.attention_dropout,
layer_norm_epsilon=falcon_config.layer_norm_epsilon,
initializer_range=falcon_config.initializer_range,
bos_token_id=falcon_config.bos_token_id,
eos_token_id=falcon_config.eos_token_id,
# These are new arguments not in the original GPT2Config
parallel_block=falcon_config.parallel_attn,
n_head_kv=n_head_kv,
parallel_block_tied_norm=parallel_block_tied_norm,
rotary_emb_fraction=1.0,
rotary_emb_interleaved=False,
tie_word_embeddings=True,
qkv_proj_bias=falcon_config.bias,
out_proj_bias=falcon_config.bias,
mlp_fc1_bias=falcon_config.bias,
mlp_fc2_bias=falcon_config.bias,
lm_head_bias=False,
)

View File

@ -27,6 +27,7 @@ from flash_attn.utils.generation import GenerationMixin
from flash_attn.models.opt import remap_state_dict_hf_opt
from flash_attn.models.gptj import remap_state_dict_hf_gptj
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
from flash_attn.models.falcon import remap_state_dict_hf_falcon
try:
from flash_attn.ops.fused_dense import ColumnParallelLinear
@ -241,6 +242,8 @@ class GPTPreTrainedModel(nn.Module):
state_dict = remap_state_dict_hf_gptj(state_dict, config)
elif model_name.startswith('EleutherAI/gpt-neox-'):
state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
elif model_name.startswith('tiiuae/falcon-'):
state_dict = remap_state_dict_hf_falcon(state_dict, config)
else:
raise NotImplementedError(f'Model {model_name} not supported')
if world_size > 1:

370
tests/models/test_falcon.py Normal file
View File

@ -0,0 +1,370 @@
# Copyright (c) 2023, Tri Dao.
import os
import time
from pathlib import Path
current_dir = Path(__file__).parent.absolute()
import torch
import pytest
from einops import rearrange
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
from flash_attn.models.falcon import remap_state_dict_hf_falcon, falcon_config_to_gpt2_config
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import update_graph_cache
@pytest.mark.parametrize('model_name', ["tiiuae/falcon-7b", "tiiuae/falcon-40b"])
def test_falcon_state_dict(model_name):
config = falcon_config_to_gpt2_config(AutoConfig.from_pretrained(model_name,
trust_remote_code=True))
pretrained_state_dict = remap_state_dict_hf_falcon(state_dict_from_pretrained(model_name), config)
model = GPTLMHeadModel(config, device='meta') # Without device='meta' init is very slow
state_dict = model.state_dict()
assert state_dict.keys() == pretrained_state_dict.keys()
for k in state_dict.keys():
assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize('model_name', ["tiiuae/falcon-7b"])
def test_falcon_optimized(model_name):
"""Check that our implementation (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = 'cuda'
config = falcon_config_to_gpt2_config(AutoConfig.from_pretrained(model_name,
trust_remote_code=True))
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused MLP for "gelu" activation
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
model.eval()
torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device=device)
with torch.no_grad():
out = model.transformer(input_ids)
logits = model(input_ids).logits
del model
# Without device_map, the model is loaded on the CPU, which is very slow
model_ref = AutoModelForCausalLM.from_pretrained(
model_name, device_map={"": device}, trust_remote_code=True
)
model_ref.eval()
with torch.no_grad():
out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)
logits_ref = model_ref(input_ids).logits.to(device=device)
del model_ref
model_hf = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True
)
model_hf.eval()
out_hf = model_hf.transformer(input_ids).last_hidden_state
logits_hf = model_hf(input_ids).logits
del model_hf
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()
# torchrun --no_python --nproc_per_node=4 pytest -q -s tests/models/test_falcon.py -k "falcon_parallel_forward"
# We want to run this on a machine with 4 x A100 80GB or 8 x A100 40GB so we have enough
# memory to run the model in fp32.
@pytest.mark.parametrize('world_size', [4])
@pytest.mark.parametrize('model_name', ["tiiuae/falcon-40b"])
def test_falcon_parallel_forward(model_name, world_size):
from apex.transformer import parallel_state
dtype = torch.float16
config = falcon_config_to_gpt2_config(AutoConfig.from_pretrained(model_name,
trust_remote_code=True))
config.use_flash_attn = False
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused MLP for "gelu" activation
config.fused_dropout_add_ln = False
config.residual_in_fp32 = True
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()
pretrained_state_dict = remap_state_dict_hf_falcon(state_dict_from_pretrained(model_name), config)
model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
model.eval()
torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device=device)
with torch.no_grad():
out = model.transformer(input_ids)
out, _ = all_gather_raw(out, process_group=process_group)
out = rearrange(out, "(b s) d -> b s d", b=batch_size)
logits = model(input_ids).logits
logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
logits, _ = all_gather_raw(logits, process_group)
logits = rearrange(logits, '(n b) ... d -> b ... (n d)', b=batch_size)
del model
if rank == 0:
model_hf = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
)
model_hf.eval()
out_hf = model_hf.transformer(input_ids).last_hidden_state.to(device=device)
logits_hf = model_hf(input_ids).logits.to(device=device)
del model_hf
# Without device_map, the model is loaded on the CPU, which is very slow
model_ref = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", trust_remote_code=True
)
model_ref.eval()
with torch.no_grad():
out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)
logits_ref = model_ref(input_ids).logits.to(device=device)
del model_ref
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()
@pytest.mark.parametrize('model_name', ["tiiuae/falcon-7b"])
def test_falcon_generation(model_name):
"""Check that our implementation (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = 'cuda'
config = falcon_config_to_gpt2_config(AutoConfig.from_pretrained(model_name,
trust_remote_code=True))
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused MLP for "gelu" activation
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
tokenizer = AutoTokenizer.from_pretrained(model_name)
eos_token_id = tokenizer.eos_token_id
torch.manual_seed(0)
batch_size = 1
seqlen = 100
max_length = 150
input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
device=device)
model_hf = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True
)
model_hf.eval()
print("HF fp16")
torch.cuda.synchronize()
start = time.time()
out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
return_dict_in_generate=True, output_scores=True)
torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
del model_hf
model_ref = AutoModelForCausalLM.from_pretrained(
model_name, device_map={"": device}, trust_remote_code=True
)
model_ref.eval()
with torch.no_grad():
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1]
del model_ref
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
model.eval()
print('Without CUDA graph')
torch.cuda.synchronize()
start = time.time()
out = model.generate(input_ids=input_ids, max_length=max_length,
eos_token_id=eos_token_id, fused_ft_kernel=True,
return_dict_in_generate=True, output_scores=True, timing=True,
teacher_outputs=out_hf.sequences)
torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print('With CUDA graph')
torch.cuda.synchronize()
start = time.time()
out_cg = model.generate(input_ids=input_ids, max_length=max_length,
fused_ft_kernel=True, cg=True,
return_dict_in_generate=True, output_scores=True, timing=True,
teacher_outputs=out_hf.sequences)
torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
with torch.no_grad():
logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1):-1]
logits_hf = torch.stack(out_hf.scores, dim=1)
logits = torch.stack(out.scores, dim=1)
logits_cg = torch.stack(out_cg.scores, dim=1)
del model
hf_error = (logits_hf - logits_ref).abs().max().item()
assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
print(f'HF fp16 logits max diff: {hf_error}')
print(f'Logits max diff: {(logits - logits_ref).abs().max().item() }')
assert (logits - logits_ref).abs().max().item() < 2 * hf_error
print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }')
assert torch.equal(logits_cg, logits)
# torchrun --no_python --nproc_per_node=4 pytest -q -s tests/models/test_falcon.py -k "falcon_parallel_generation"
# We want to run this on a machine with 4 x A100 80GB or 8 x A100 40GB so we have enough
# memory to run the model in fp32.
@pytest.mark.parametrize('world_size', [4])
@pytest.mark.parametrize('model_name', ["tiiuae/falcon-40b"])
def test_falcon_parallel_generation(model_name, world_size):
"""Check that our implementation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
the HF scores in fp32.
"""
from apex.transformer import parallel_state
dtype = torch.float16
config = falcon_config_to_gpt2_config(AutoConfig.from_pretrained(model_name,
trust_remote_code=True))
config.use_flash_attn = False
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused MLP for "gelu" activation
config.fused_dropout_add_ln = False
config.residual_in_fp32 = True
config.pad_vocab_size_multiple = 8 * world_size
config.sequence_parallel = False # Need to set this to False for generation
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()
torch.manual_seed(0)
batch_size = 1
seqlen = 100
max_length = 150
input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
device=device)
torch.distributed.barrier()
# Need this, otherwise when we capture the graph the process for GPU 1 would run on both
# GPU0 and GPU1 and things would hang
torch.cuda.set_device(device)
pretrained_state_dict = remap_state_dict_hf_falcon(state_dict_from_pretrained(model_name), config)
model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
model.eval()
print('Without CUDA graph')
out = model.generate(
input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
vocab_size=config.vocab_size, fused_ft_kernel=True,
# teacher_outputs=out_hf.sequences,
return_dict_in_generate=True, output_scores=True, timing=True
)
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print('With CUDA graph')
out_cg = model.generate(
input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
vocab_size=config.vocab_size, fused_ft_kernel=True, cg=True,
# teacher_outputs=out_hf.sequences,
return_dict_in_generate=True, output_scores=True, timing=True
)
del model
parallel_state.destroy_model_parallel()
if rank == 0:
model_hf = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
)
model_hf.eval()
print("HF fp16")
torch.cuda.synchronize()
start = time.time()
with torch.inference_mode():
out_hf = model_hf.generate(
input_ids=input_ids, max_length=max_length, return_dict_in_generate=True,
output_scores=True
)
torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
del model_hf
model_ref = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", trust_remote_code=True
)
model_ref.eval()
with torch.inference_mode():
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1]
del model_ref
logits_hf = torch.stack(out_hf.scores, dim=1)
logits = torch.stack(out.scores, dim=1)
logits_cg = torch.stack(out_cg.scores, dim=1)
hf_error = (logits_hf - logits_ref).abs().max().item()
print(f'HF fp16 logits max diff: {hf_error}')
print(f'Logits max diff: {(logits - logits_ref).abs().max().item() }')
assert (logits - logits_ref).abs().max().item() < 2 * hf_error
print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }')
assert torch.equal(logits_cg, logits)