diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py index 032388b..566697b 100644 --- a/flash_attn/models/gpt.py +++ b/flash_attn/models/gpt.py @@ -18,12 +18,13 @@ from einops import rearrange from flash_attn.modules.mha import MHA, ParallelMHA from flash_attn.modules.mlp import Mlp, FusedMLP, ParallelFusedMLP -from flash_attn.modules.block import Block +from flash_attn.modules.block import Block, ParallelBlock from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings from flash_attn.utils.distributed import sync_shared_params, all_gather_raw from flash_attn.utils.pretrained import state_dict_from_pretrained from flash_attn.utils.generation import GenerationMixin -from flash_attn.models.opt import remap_state_dict_opt +from flash_attn.models.opt import remap_state_dict_hf_opt +from flash_attn.models.gptj import remap_state_dict_hf_gptj try: from flash_attn.ops.fused_dense import ColumnParallelLinear @@ -36,9 +37,10 @@ except ImportError: dropout_add_layer_norm = None try: - from flash_attn.ops.triton.mlp import FusedDenseSqreluDense + from flash_attn.ops.triton.mlp import FusedDenseSqreluDense, sqrelu_fwd except ImportError: FusedDenseSqreluDense = None + sqrelu_fwd = None logger = logging.getLogger(__name__) @@ -54,8 +56,11 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt dwconv = getattr(config, 'attn_dwconv', False) if dwconv: assert process_group is None, 'TensorParallel MHA does not support dwconv yet' + qkv_proj_bias = getattr(config, 'qkv_proj_bias', True) + out_proj_bias = getattr(config, 'out_proj_bias', True) rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim) - rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', 0) + rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', None) + rotary_emb_interleaved = getattr(config, 'rotary_emb_interleaved', False) use_flash_attn = getattr(config, 'use_flash_attn', False) fused_bias_fc = getattr(config, 'fused_bias_fc', False) if not fused_bias_fc: @@ -66,9 +71,12 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt parallel_kwargs = ({'process_group': process_group, 'sequence_parallel': getattr(config, 'sequence_parallel', True)} if process_group is not None else {}) - mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads, dropout=config.attn_pdrop, + mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads, + qkv_proj_bias=qkv_proj_bias, out_proj_bias=out_proj_bias, + dropout=config.attn_pdrop, softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx, rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base, + rotary_emb_interleaved=rotary_emb_interleaved, use_flash_attn=use_flash_attn, **serial_kwargs, **parallel_kwargs, **factory_kwargs) return mixer_cls @@ -88,8 +96,12 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp if process_group is not None: assert fused_mlp, 'Tensor Parallel is only implemented for FusedMLP' if not fused_mlp and not fused_dense_sqrelu_dense: + assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu'] if config.activation_function == 'relu': activation = partial(F.relu, inplace=True) + elif config.activation_function == 'sqrelu': + assert sqrelu_fwd is not None, 'sqrelu_fwd is not implemented' + activation = sqrelu_fwd else: approximate = ('tanh' if config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none') @@ -132,12 +144,27 @@ 
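The `create_mlp_cls` change above routes `activation_function == 'sqrelu'` to the Triton `sqrelu_fwd` kernel imported at the top of the file. For orientation only, squared ReLU is simply `relu(x) ** 2`; a minimal unfused sketch (an illustration, not the fused kernel the diff actually uses):

```python
import torch
import torch.nn.functional as F

def sqrelu_reference(x: torch.Tensor) -> torch.Tensor:
    """Reference (unfused) squared ReLU: relu(x) ** 2.

    Illustration only -- the diff uses the fused Triton `sqrelu_fwd` when it is
    importable, and asserts if 'sqrelu' is requested without it.
    """
    r = F.relu(x)
    return r * r
```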
def create_block(config, layer_idx=None, process_group=None, device=None, dtype= residual_in_fp32 = getattr(config, 'residual_in_fp32', False) resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop prenorm = getattr(config, 'prenorm', True) - block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls, - prenorm=prenorm, resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop, - fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False), - residual_in_fp32=residual_in_fp32, - sequence_parallel=sequence_parallel and process_group is not None, - mark_shared_params=process_group is not None) + parallel_block = getattr(config, 'parallel_block', False) + if not parallel_block: + block = Block( + config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls, + prenorm=prenorm, resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop, + fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False), + residual_in_fp32=residual_in_fp32, + sequence_parallel=sequence_parallel and process_group is not None, + mark_shared_params=process_group is not None + ) + else: + assert prenorm + block = ParallelBlock( + config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls, + resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop, + tied_norm=getattr(config, 'parallel_block_tied_norm', False), + fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False), + residual_in_fp32=residual_in_fp32, + sequence_parallel=sequence_parallel and process_group is not None, + mark_shared_params=process_group is not None + ) block.layer_idx = layer_idx return block @@ -172,9 +199,12 @@ class GPTPreTrainedModel(nn.Module): model_name, device='cpu', dtype=dtype ) if model_name.startswith('gpt2'): - state_dict = remap_state_dict_gpt2(state_dict, config) + state_dict = remap_state_dict_hf_gpt2(state_dict, config) elif model_name.startswith('facebook/opt'): - state_dict = remap_state_dict_opt(state_dict, config) + state_dict = remap_state_dict_hf_opt(state_dict, config) + elif model_name.startswith('EleutherAI/gpt-j-'): + state_dict = remap_state_dict_hf_gptj(state_dict, config) + strict = False # We have rotary_emb.inf_freq buffers not in the GPT-J checkpoint else: raise NotImplementedError(f'Model {model_name} not supported') if world_size > 1: @@ -223,6 +253,8 @@ class GPTModel(GPTPreTrainedModel): # These 2 options are for OPT-350m self.prenorm = getattr(config, 'prenorm', True) word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None) + # For GPT-J, GPT-NeoX + self.parallel_block = getattr(config, 'parallel_block', False) if process_group is None: self.embeddings = GPT2Embeddings( @@ -276,6 +308,8 @@ class GPTModel(GPTPreTrainedModel): embedding_kwargs = ({'combine_batch_seqlen_dim': True} if self.process_group is not None and self.sequence_parallel else {}) hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs) + if self.parallel_block: + hidden_states2 = None residual = None mixer_kwargs = ({'seqlen': input_ids.shape[1]} if self.process_group is not None and self.sequence_parallel else {}) @@ -283,15 +317,27 @@ class GPTModel(GPTPreTrainedModel): mixer_kwargs['inference_params'] = inference_params for layer in self.layers: if self.prenorm: - hidden_states, residual = layer(hidden_states, residual, mixer_kwargs=mixer_kwargs) + if not self.parallel_block: + hidden_states, residual = layer(hidden_states, residual, + mixer_kwargs=mixer_kwargs) + else: + hidden_states, hidden_states2, 
residual = layer( + hidden_states, hidden_states2, residual, mixer_kwargs=mixer_kwargs + ) else: hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs) if self.prenorm: if not self.fused_dropout_add_ln: dropped = self.drop_f(hidden_states) - residual = (dropped + residual) if residual is not None else dropped + if not self.parallel_block: + residual = (dropped + residual) if residual is not None else dropped + else: + dropped2 = self.drop_f(hidden_states2) + residual = ((residual + dropped + dropped2) + if residual is not None else dropped + dropped2) hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype)) else: + assert not self.parallel_block # Set prenorm=False here since we don't need the residual hidden_states = dropout_add_layer_norm( hidden_states, residual, self.ln_f.weight, self.ln_f.bias, @@ -308,6 +354,7 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin): super().__init__(config) self.process_group = process_group self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs) + self.tie_word_embeddings = getattr(config, 'tie_word_embeddings', True) pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) @@ -319,12 +366,13 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin): else: self.project_out = None if process_group is None: - self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False, **factory_kwargs) + self.lm_head = nn.Linear(embed_dim, vocab_size, bias=not self.tie_word_embeddings, + **factory_kwargs) else: if ColumnParallelLinear is None: raise ImportError('fused_dense_lib is not installed') self.lm_head = ColumnParallelLinear( - embed_dim, vocab_size, process_group, bias=False, + embed_dim, vocab_size, process_group, bias=not self.tie_word_embeddings, sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs ) # Initialize weights and apply final processing @@ -333,7 +381,8 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin): self.tie_weights() def tie_weights(self): - self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight + if self.tie_word_embeddings: + self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight if self.process_group is not None: sync_shared_params(self, self.process_group) @@ -381,7 +430,95 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin): return super().load_state_dict(state_dict, strict=strict) -def remap_state_dict_gpt2(state_dict, config): +def shard_state_dict_tp(state_dict, config, world_size, rank): + """Convert the state_dict of a standard GPT model to the state_dict of a GPT model + with tensor parallel. + """ + pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) + vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) + assert vocab_size % world_size == 0 + assert config.hidden_size % world_size == 0 + inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size + assert inner_dim % world_size == 0 + + def shard_first_dim(state_dict, key): + x = state_dict[key] + dim = x.shape[0] // world_size + state_dict[key] = x[rank * dim:(rank + 1) * dim] + + def shard_last_dim(state_dict, key): + x = state_dict[key] + dim = x.shape[-1] // world_size + state_dict[key] = x[..., rank * dim:(rank + 1) * dim] + + def shard_qkv_headdim(state_dict, key): + x = rearrange(state_dict[key], '(three d) ... 
-> three d ...', three=3) + dim = x.shape[1] // world_size + state_dict[key] = rearrange(x[:, rank * dim:(rank + 1) * dim], + 'three d ... -> (three d) ...') + + shard_first_dim(state_dict, 'transformer.embeddings.word_embeddings.weight') + if 'lm_head.weight' in state_dict: + shard_first_dim(state_dict, 'lm_head.weight') + if 'transformer.embeddings.position_embeddings.weight' in state_dict: + shard_last_dim(state_dict, 'transformer.embeddings.position_embeddings.weight') + for i in range(config.num_hidden_layers): + shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight') + shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias') + shard_last_dim(state_dict, f'transformer.layers.{i}.mixer.out_proj.weight') + if rank != 0: + state_dict.pop(f'transformer.layers.{i}.mixer.out_proj.bias') + shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight') + shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias') + shard_last_dim(state_dict, f'transformer.layers.{i}.mlp.fc2.weight') + if rank != 0: + state_dict.pop(f'transformer.layers.{i}.mlp.fc2.bias') + return state_dict + + +def combine_state_dicts_tp(state_dicts, config): + """Convert the state_dict of a standard GPT model to the state_dict of a GPT model + with tensor parallel. + """ + world_size = len(state_dicts) + keys = state_dicts[0].keys() + pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) + vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) + assert vocab_size % world_size == 0 + assert config.hidden_size % world_size == 0 + inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size + assert inner_dim % world_size == 0 + + # The word embeddings from Megatron are weird, for each shard only the first + # vocab_size // world_size coordinates are nonzero. + def combine_word_embeddings(state_dicts, state_dict, key): + assert all(s[key].shape[0] == vocab_size for s in state_dicts) + state_dict[key] = torch.cat([s[key][:vocab_size // world_size] for s in state_dicts], dim=0) + + def combine_dim(state_dicts, state_dict, key, dim=-1): + state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim) + + def combine_qkv_headdim(state_dicts, state_dict, key): + xs = [rearrange(s[key], '(three d) ... -> three d ...', three=3) for s in state_dicts] + state_dict[key] = rearrange(torch.cat(xs, dim=1), 'three d ... 
-> (three d) ...') + + state_dict = state_dicts[0].copy() # don't modify state_dict[0] inplace + combine_word_embeddings(state_dicts, state_dict, 'transformer.embeddings.word_embeddings.weight') + if 'lm_head.weight' in state_dict: + combine_word_embeddings(state_dicts, state_dict, 'lm_head.weight') + if 'transformer.embeddings.position_embeddings.weight' in state_dict: + combine_dim(state_dicts, state_dict, 'transformer.embeddings.position_embeddings.weight', -1) + for i in range(config.num_hidden_layers): + combine_qkv_headdim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight') + combine_qkv_headdim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias') + combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.out_proj.weight', -1) + combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.weight', 0) + combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.bias', 0) + combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc2.weight', -1) + return state_dict + + +def remap_state_dict_hf_gpt2(state_dict, config): # Word embedding and position embedding def key_mapping_pos_emb(key): return re.sub(r'^wpe.', 'transformer.embeddings.position_embeddings.', key) @@ -430,47 +567,67 @@ def remap_state_dict_gpt2(state_dict, config): return state_dict -def shard_state_dict_tp(state_dict, config, world_size, rank): - """Convert the state_dict of a standard GPT model to the state_dict of a GPT model - with tensor parallel. - """ +def remap_state_dict_megatron(state_dict, config): + def key_mapping_transformer(key): + key = re.sub(r'^language_model.encoder.', 'transformer.', key) + key = re.sub(r'^language_model.', 'transformer.', key) + return key + state_dict = OrderedDict((key_mapping_transformer(k), v) for k, v in state_dict.items()) + # Word embedding and position embedding + def key_mapping_pos_emb(key): + return re.sub(r'^wpe.', 'transformer.embeddings.position_embeddings.', key) + state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items()) + word_embeddings = state_dict.pop('transformer.embedding.word_embeddings.weight') + # It's possible that vocab_size is padded to be a multiple of 8, for example. 
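The `shard_qkv_headdim` / `combine_qkv_headdim` helpers above split and rejoin the packed QKV weight along the per-projection dimension, so sharding followed by combining is the identity. A small self-contained round-trip check with made-up sizes (a hypothetical sketch, not part of the library):

```python
import torch
from einops import rearrange

# Hypothetical sizes: hidden = 8 and a tensor-parallel world size of 2.
hidden, world_size = 8, 2
wqkv = torch.randn(3 * hidden, hidden)  # packed as (3 * hidden, hidden)

def shard_qkv(x: torch.Tensor, rank: int) -> torch.Tensor:
    # Split q, k, v apart, take this rank's slice of each, then re-pack.
    x3 = rearrange(x, '(three d) ... -> three d ...', three=3)
    dim = x3.shape[1] // world_size
    return rearrange(x3[:, rank * dim:(rank + 1) * dim], 'three d ... -> (three d) ...')

def combine_qkv(shards) -> torch.Tensor:
    # Inverse of shard_qkv: unpack each shard, concatenate per projection, re-pack.
    xs = [rearrange(s, '(three d) ... -> three d ...', three=3) for s in shards]
    return rearrange(torch.cat(xs, dim=1), 'three d ... -> (three d) ...')

shards = [shard_qkv(wqkv, r) for r in range(world_size)]
assert torch.equal(combine_qkv(shards), wqkv)  # sharding then combining recovers the original
```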
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) - assert vocab_size % world_size == 0 - assert config.hidden_size % world_size == 0 - inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size - assert inner_dim % world_size == 0 + state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad( + word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) + ) + state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight'] - def shard_first_dim(state_dict, key): - x = state_dict[key] - dim = x.shape[0] // world_size - state_dict[key] = x[rank * dim:(rank + 1) * dim] + # LayerNorm + def key_mapping_ln(key): + key = re.sub(r'^transformer.final_layernorm.(weight|bias)', r'transformer.ln_f.\1', key) + key = re.sub(r'^transformer.layers.(\d+).input_layernorm.(weight|bias)', + r'transformer.layers.\1.norm1.\2', key) + key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.(weight|bias)', + r'transformer.layers.\1.norm2.\2', key) + return key + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) - def shard_last_dim(state_dict, key): - x = state_dict[key] - dim = x.shape[-1] // world_size - state_dict[key] = x[..., rank * dim:(rank + 1) * dim] + # MLP + def key_mapping_mlp(key): + key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.(weight|bias)', + r'transformer.layers.\1.mlp.fc1.\2', key) + key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.(weight|bias)', + r'transformer.layers.\1.mlp.fc2.\2', key) + return key + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) - def shard_qkv_headdim(state_dict, key): - x = rearrange(state_dict[key], '(three d) ... -> three d ...', three=3) - dim = x.shape[1] // world_size - state_dict[key] = rearrange(x[:, rank * dim:(rank + 1) * dim], - 'three d ... -> (three d) ...') + # Attention + def key_mapping_attn(key): + key = re.sub(r'^transformer.layers.(\d+).self_attention.rotary_emb.inv_freq', + r'transformer.layers.\1.mixer.rotary_emb.inv_freq', key) + key = re.sub(r'^transformer.layers.(\d+).self_attention.query_key_value.(weight|bias)', + r'transformer.layers.\1.mixer.Wqkv.\2', key) + key = re.sub(r'^transformer.layers.(\d+).self_attention.dense.(weight|bias)', + r'transformer.layers.\1.mixer.out_proj.\2', key) + return key + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + # Megatron stores Wqkv as ((nheads 3 headdim), hidden_dim) + # while we store Wqkv as ((3 nheads headdim), hidden_dim) + headdim = config.hidden_size // config.num_attention_heads + for d in range(config.num_hidden_layers): + Wqkv = state_dict.pop(f'transformer.layers.{d}.mixer.Wqkv.weight') + state_dict[f'transformer.layers.{d}.mixer.Wqkv.weight'] = rearrange( + Wqkv, '(nheads three headdim) ... 
-> (three nheads headdim) ...', + three=3, headdim=headdim + ) + bqkv = state_dict.pop(f'transformer.layers.{d}.mixer.Wqkv.bias') + state_dict[f'transformer.layers.{d}.mixer.Wqkv.bias'] = rearrange( + bqkv, '(nheads three headdim) -> (three nheads headdim)', + three=3, headdim=headdim + ) - shard_first_dim(state_dict, 'transformer.embeddings.word_embeddings.weight') - if 'lm_head.weight' in state_dict: - shard_first_dim(state_dict, 'lm_head.weight') - if 'transformer.embeddings.position_embeddings.weight' in state_dict: - shard_last_dim(state_dict, 'transformer.embeddings.position_embeddings.weight') - for i in range(config.num_hidden_layers): - shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight') - shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias') - shard_last_dim(state_dict, f'transformer.layers.{i}.mixer.out_proj.weight') - if rank != 0: - state_dict.pop(f'transformer.layers.{i}.mixer.out_proj.bias') - shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight') - shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias') - shard_last_dim(state_dict, f'transformer.layers.{i}.mlp.fc2.weight') - if rank != 0: - state_dict.pop(f'transformer.layers.{i}.mlp.fc2.bias') return state_dict diff --git a/flash_attn/models/gptj.py b/flash_attn/models/gptj.py new file mode 100644 index 0000000..ff85998 --- /dev/null +++ b/flash_attn/models/gptj.py @@ -0,0 +1,95 @@ +# Copyright (c) 2023, Tri Dao. + +import math +import re + +from collections import OrderedDict + +import torch +import torch.nn.functional as F + +from transformers import GPT2Config, GPTJConfig + + +def remap_state_dict_hf_gptj(state_dict, config): + def key_mapping_layers(key): + return re.sub(r'^transformer.h.', 'transformer.layers.', key) + state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items()) + # Word embedding + def key_mapping_emb(key): + return re.sub(r'^transformer.wte.', 'transformer.embeddings.word_embeddings.', key) + state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) + word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight') + # It's possible that vocab_size is padded to be a multiple of 8, for example. + pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) + vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) + state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad( + word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) + ) + if getattr(config, 'tie_word_embeddings'): + state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight'] + else: + output_embeddings = state_dict.pop('lm_head.weight') + # It's possible that vocab_size is padded to be a multiple of 8, for example. 
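For the Megatron remapping above: Megatron packs Wqkv head-major as `(nheads, 3, headdim)` along the output dimension, while this codebase expects projection-major `(3, nheads, headdim)`. A tiny standalone sketch of that permutation with made-up sizes (illustration only):

```python
import torch
from einops import rearrange

# Hypothetical sizes: 2 heads of dim 3, hidden size 4.
nheads, headdim, hidden = 2, 3, 4
wqkv_megatron = torch.randn(nheads * 3 * headdim, hidden)  # rows grouped per head: [q0 k0 v0 q1 k1 v1]

# Same rearrange as in remap_state_dict_megatron: regroup rows as [q0 q1 k0 k1 v0 v1].
wqkv_flash = rearrange(wqkv_megatron, '(nheads three headdim) d -> (three nheads headdim) d',
                       three=3, headdim=headdim)
assert wqkv_flash.shape == wqkv_megatron.shape
```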
+ state_dict['lm_head.weight'] = F.pad( + output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) + ) + + # LayerNorm + def key_mapping_ln(key): + return re.sub(r'^transformer.layers.(\d+).ln_1.', r'transformer.layers.\1.norm1.', key) + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + # MLP + def key_mapping_mlp(key): + key = re.sub(r'^transformer.layers.(\d+).mlp.fc_in.', r'transformer.layers.\1.mlp.fc1.', key) + key = re.sub(r'^transformer.layers.(\d+).mlp.fc_out.', r'transformer.layers.\1.mlp.fc2.', key) + return key + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # Attention + for l in range(config.n_layer): + Wq = state_dict.pop(f'transformer.layers.{l}.attn.q_proj.weight') + Wk = state_dict.pop(f'transformer.layers.{l}.attn.k_proj.weight') + Wv = state_dict.pop(f'transformer.layers.{l}.attn.v_proj.weight') + state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat( + [Wq, Wk, Wv], dim=0 + ) + # We don't store these biases + state_dict.pop(f'transformer.layers.{l}.attn.bias') + state_dict.pop(f'transformer.layers.{l}.attn.masked_bias') + def key_mapping_attn(key): + return re.sub(r'^transformer.layers.(\d+).attn.out_proj.', + r'transformer.layers.\1.mixer.out_proj.', key) + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + + return state_dict + + +def gptj_config_to_gpt2_config(gptj_config: GPTJConfig) -> GPT2Config: + headdim = gptj_config.n_embd // gptj_config.n_head + return GPT2Config( + vocab_size=gptj_config.vocab_size, + n_positions=0, # No absolute position embedding + n_embd=gptj_config.n_embd, + n_layer=gptj_config.n_layer, + n_head=gptj_config.n_head, + n_inner=gptj_config.n_inner, + activation_function=gptj_config.activation_function, + resid_pdrop=gptj_config.resid_pdrop, + embd_pdrop=gptj_config.embd_pdrop, + attn_pdrop=gptj_config.attn_pdrop, + layer_norm_epsilon=gptj_config.layer_norm_epsilon, + initializer_range=gptj_config.initializer_range, + bos_token_id=gptj_config.bos_token_id, + eos_token_id=gptj_config.eos_token_id, + # These are new arguments not in the original GPT2Config + prenorm=True, + parallel_block=True, + parallel_block_tied_norm=True, + rotary_emb_fraction=gptj_config.rotary_dim / headdim, + rotary_emb_interleaved=True, + tie_word_embeddings=False, + qkv_proj_bias=False, + out_proj_bias=False, + ) diff --git a/flash_attn/models/opt.py b/flash_attn/models/opt.py index 79740cd..93bf5ed 100644 --- a/flash_attn/models/opt.py +++ b/flash_attn/models/opt.py @@ -11,7 +11,7 @@ import torch.nn.functional as F from transformers import GPT2Config, OPTConfig -def remap_state_dict_opt(state_dict, config): +def remap_state_dict_hf_opt(state_dict, config): def key_mapping_model(key): key = re.sub(r'^model.decoder.', 'transformer.', key) # The OPT-350m model uses '^decoder' instead of '^model.decoder' diff --git a/flash_attn/modules/block.py b/flash_attn/modules/block.py index 763c7be..f1c4fcd 100644 --- a/flash_attn/modules/block.py +++ b/flash_attn/modules/block.py @@ -190,3 +190,93 @@ class Block(nn.Module): rowscale=rowscale2, prenorm=False ) return hidden_states + + +class ParallelBlock(nn.Module): + """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX, + and PaLM. 
+ """ + + def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm, + dropout_cls=nn.Dropout, resid_dropout1=0., resid_dropout2=0., + tied_norm=False, fused_dropout_add_ln=False, residual_in_fp32=False, + sequence_parallel=False, mark_shared_params=False): + """ + This Block has a slightly different structure compared to a regular + prenorm Transformer block. + The standard block is: LN -> MHA / MLP -> Dropout -> Add. + [Ref: https://arxiv.org/abs/2002.04745] + Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both + the hidden_states (output1 of the MHA / MLP) and the residual. + This is for performance reasons, as we can fuse the dropout, add and LayerNorm. + The residual needs to be provided (except for the very first block). + """ + super().__init__() + self.tied_norm = tied_norm + self.fused_dropout_add_ln = fused_dropout_add_ln + assert not self.fused_dropout_add_ln, 'This is not implemented for ParallelBlock yet' + self.residual_in_fp32 = residual_in_fp32 + if mixer_cls is None: + mixer_cls = partial(MHA, num_heads=dim // 64) + if mlp_cls is None: + mlp_cls = partial(Mlp, hidden_features=4 * dim) + self.mixer = mixer_cls(dim) + self.dropout1 = dropout_cls(resid_dropout1) + self.norm1 = norm_cls(dim) + self.mlp = mlp_cls(dim) + self.dropout2 = dropout_cls(resid_dropout2) + if not self.tied_norm: + self.norm2 = norm_cls(dim) + + if self.fused_dropout_add_ln: + assert dropout_add_layer_norm is not None, 'dropout_add_ln is not installed' + assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout) + + # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0, + # then the input to each worker in the tensor parallel group will be different. + # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers. + # For now this is not an issue because we always use sequence_parallel=True during training + # and only use sequence_parallel=False during inference. + + # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads. + if sequence_parallel: + for p in self.norm1.parameters(): + p._sequence_parallel = True + if hasattr(self, 'norm2'): + for p in self.norm2.parameters(): + p._sequence_parallel = True + # Mark the norm parameters as "shared_params" so that we sync their values at init. + if mark_shared_params: + for p in self.norm1.parameters(): + p._shared_params = True + if hasattr(self, 'norm2'): + for p in self.norm2.parameters(): + p._shared_params = True + + def forward(self, hidden_states1: Tensor, hidden_states2: Optional[Tensor] = None, + residual: Optional[Tensor] = None, mixer_kwargs=None): + r"""Pass the input through the encoder layer. + + Args: + hidden_states1: the output of the previous attention (mixer) or embedding layer. + hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1). + residual. 
+ """ + dropped1 = self.dropout1(hidden_states1) + # For the very 1st block, we only want 1 dropout, not two different dropouts + if hidden_states2 is not None: + dropped2 = self.dropout2(hidden_states2) + residual = ((residual + dropped1 + dropped2) + if residual is not None else dropped1 + dropped2) + else: + residual = (residual + dropped1) if residual is not None else dropped1 + hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype)) + hidden_states2 = (self.norm2(residual.to(dtype=self.norm2.weight.dtype)) + if not self.tied_norm else hidden_states1) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + if mixer_kwargs is None: + mixer_kwargs = {} + hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs) + hidden_states2 = self.mlp(hidden_states2) + return hidden_states1, hidden_states2, residual diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py index 4eb5aaf..4434b2f 100644 --- a/flash_attn/modules/mha.py +++ b/flash_attn/modules/mha.py @@ -347,9 +347,10 @@ class MHA(nn.Module): """Multi-head self-attention and cross-attention """ - def __init__(self, embed_dim, num_heads, cross_attn=False, bias=True, dropout=0.0, - softmax_scale=None, causal=False, layer_idx=None, dwconv=False, rotary_emb_dim=0, - rotary_emb_scale_base=0, + def __init__(self, embed_dim, num_heads, cross_attn=False, + qkv_proj_bias=True, out_proj_bias=True, + dropout=0.0, softmax_scale=None, causal=False, layer_idx=None, dwconv=False, + rotary_emb_dim=0, rotary_emb_scale_base=None, rotary_emb_interleaved=False, fused_bias_fc=False, use_flash_attn=False, return_residual=False, checkpointing=False, device=None, dtype=None) -> None: """ @@ -377,7 +378,7 @@ class MHA(nn.Module): assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet' assert RotaryEmbedding is not None, 'rotary_emb is not installed' self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, - device=device) + interleaved=rotary_emb_interleaved, device=device) if fused_bias_fc and FusedDense is None: raise ImportError('fused_dense is not installed') @@ -388,29 +389,32 @@ class MHA(nn.Module): inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention if not self.cross_attn: if not self.return_residual: - self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) + self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=qkv_proj_bias, + **factory_kwargs) else: - self.Wqkv = linear_resid_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) + self.Wqkv = linear_resid_cls(embed_dim, 3 * embed_dim, bias=qkv_proj_bias, + **factory_kwargs) if self.dwconv: self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2, groups=3 * embed_dim) else: - self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs) + self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs) if not self.return_residual: - self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs) + self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=qkv_proj_bias, + **factory_kwargs) else: - self.Wkv = linear_resid_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs) + self.Wkv = linear_resid_cls(embed_dim, 2 * embed_dim, bias=qkv_proj_bias, + **factory_kwargs) if self.dwconv: self.dwconv_q = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=2, - groups=embed_dim) + groups=embed_dim) self.dwconv_kv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, kernel_size=3, padding=2, 
- groups=2 * embed_dim) + groups=2 * embed_dim) self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) - # output projection always have the bias (for now) - self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs) + self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs) def _update_kv_cache(self, kv, inference_params): """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim) @@ -526,9 +530,10 @@ class ParallelMHA(nn.Module): """Multi-head self-attention and cross-attention """ - def __init__(self, embed_dim, num_heads, process_group, bias=True, dropout=0.0, - softmax_scale=None, causal=False, layer_idx=None, rotary_emb_dim=0, - rotary_emb_scale_base=0, use_flash_attn=False, checkpointing=False, + def __init__(self, embed_dim, num_heads, process_group, qkv_proj_bias=True, out_proj_bias=True, + dropout=0.0, softmax_scale=None, causal=False, layer_idx=None, + rotary_emb_dim=0, rotary_emb_scale_base=None, rotary_emb_interleaved=False, + use_flash_attn=False, checkpointing=False, sequence_parallel=True, device=None, dtype=None) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() @@ -546,11 +551,12 @@ class ParallelMHA(nn.Module): if self.rotary_emb_dim > 0: assert RotaryEmbedding is not None, 'rotary_emb is not installed' self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, - device=device) + interleaved=rotary_emb_interleaved, device=device) if ColumnParallelLinear is None or RowParallelLinear is None: raise ImportError('fused_dense is not installed') - self.Wqkv = ColumnParallelLinear(embed_dim, 3 * embed_dim, process_group, bias=bias, + self.Wqkv = ColumnParallelLinear(embed_dim, 3 * embed_dim, process_group, + bias=qkv_proj_bias, sequence_parallel=sequence_parallel, **factory_kwargs) inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention @@ -558,8 +564,8 @@ class ParallelMHA(nn.Module): attention_dropout=dropout) self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) - # output projection always have the bias (for now) self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group, + bias=out_proj_bias, sequence_parallel=sequence_parallel, **factory_kwargs) def forward(self, x, seqlen=None, inference_params=None, **kwargs): diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py index 6b043e2..233d079 100644 --- a/flash_attn/utils/generation.py +++ b/flash_attn/utils/generation.py @@ -71,8 +71,8 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0): def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, - eos_token_id=None, vocab_size=None, tensor_parallel=1, fused_ft_kernel=False, - cg=False, timing=False): + eos_token_id=None, teacher_outputs=None, vocab_size=None, tensor_parallel=1, + fused_ft_kernel=False, cg=False, timing=False): """Decoding, either greedy or with top-k or top-p sampling. If top-k = 0, don't limit the number of candidates (pure sampling). Top-k and top-p can be used together. 
If top_k > 0 and top_p > 0, then top-k is applied first, @@ -87,6 +87,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, scores: tuples of (batch, vocab_size) """ batch_size, seqlen_og = input_ids.shape + teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0 if cg: assert fused_ft_kernel if not hasattr(model, '_decoding_cache'): @@ -111,7 +112,10 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, if vocab_size is not None: logits = logits[..., :vocab_size] scores.append(logits) - next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature) + if teacher_outputs is None or teacher_output_len <= seqlen_og: + next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature) + else: + next_token = teacher_outputs[:, seqlen_og] sequences = [next_token] inference_params.sequence_len_offset = seqlen_og while True: @@ -126,7 +130,10 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, if vocab_size is not None: logits = logits[..., :vocab_size] scores.append(logits) - next_token = sample(logits, top_k=top_k, temperature=temperature) + if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset + 1: + next_token = sample(logits, top_k=top_k, temperature=temperature) + else: + next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1] sequences.append(next_token) inference_params.sequence_len_offset += 1 if eos_token_id is not None and (next_token == eos_token_id).all(): diff --git a/tests/models/test_gpt.py b/tests/models/test_gpt.py index 98a9960..64c6e45 100644 --- a/tests/models/test_gpt.py +++ b/tests/models/test_gpt.py @@ -7,7 +7,7 @@ from transformers import GPT2Config from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF from flash_attn.models.gpt import GPTLMHeadModel -from flash_attn.models.gpt import remap_state_dict_gpt2 +from flash_attn.models.gpt import remap_state_dict_hf_gpt2 from flash_attn.utils.pretrained import state_dict_from_pretrained @@ -15,7 +15,7 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained # @pytest.mark.parametrize('model_name', ["gpt2"]) def test_gpt2_state_dict(model_name): config = GPT2Config.from_pretrained(model_name) - pretrained_state_dict = remap_state_dict_gpt2(state_dict_from_pretrained(model_name), config) + pretrained_state_dict = remap_state_dict_hf_gpt2(state_dict_from_pretrained(model_name), config) model = GPTLMHeadModel(config) state_dict = model.state_dict() assert state_dict.keys() == pretrained_state_dict.keys() diff --git a/tests/models/test_gpt_generation.py b/tests/models/test_gpt_generation.py index 42bac36..dbbe05d 100644 --- a/tests/models/test_gpt_generation.py +++ b/tests/models/test_gpt_generation.py @@ -12,8 +12,8 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHead from transformers.models.opt.modeling_opt import OPTForCausalLM from flash_attn.models.gpt import GPTLMHeadModel -from flash_attn.models.gpt import remap_state_dict_gpt2 -from flash_attn.models.opt import remap_state_dict_opt, opt_config_to_gpt2_config +from flash_attn.models.gpt import remap_state_dict_hf_gpt2 +from flash_attn.models.opt import remap_state_dict_hf_opt, opt_config_to_gpt2_config from flash_attn.utils.pretrained import state_dict_from_pretrained from flash_attn.utils.distributed import all_gather_raw from flash_attn.utils.generation import update_graph_cache diff --git 
a/tests/models/test_gpt_generation_parallel.py b/tests/models/test_gpt_generation_parallel.py index 50130ad..7525644 100644 --- a/tests/models/test_gpt_generation_parallel.py +++ b/tests/models/test_gpt_generation_parallel.py @@ -12,7 +12,7 @@ from transformers import GPT2Config, GPT2Tokenizer from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF from flash_attn.models.gpt import GPTLMHeadModel -from flash_attn.models.gpt import remap_state_dict_gpt2 +from flash_attn.models.gpt import remap_state_dict_hf_gpt2 from flash_attn.utils.pretrained import state_dict_from_pretrained from flash_attn.utils.distributed import all_gather_raw diff --git a/tests/models/test_gptj.py b/tests/models/test_gptj.py new file mode 100644 index 0000000..8e9b6df --- /dev/null +++ b/tests/models/test_gptj.py @@ -0,0 +1,80 @@ +import re + +import torch +import pytest + +from transformers import GPTJConfig +from transformers.models.gptj.modeling_gptj import GPTJForCausalLM + +from flash_attn.models.gpt import GPTLMHeadModel +from flash_attn.models.gptj import remap_state_dict_hf_gptj, gptj_config_to_gpt2_config +from flash_attn.utils.pretrained import state_dict_from_pretrained + + +@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-j-6B"]) +def test_gptj_state_dict(model_name): + config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name)) + pretrained_state_dict = remap_state_dict_hf_gptj(state_dict_from_pretrained(model_name), config) + model = GPTLMHeadModel(config, device='meta') # Without device='meta' init is very slow + state_dict = model.state_dict() + rotary_inv_freq_keys = {f'transformer.layers.{l}.mixer.rotary_emb.inv_freq' + for l in range(config.n_layer)} + assert state_dict.keys() == pretrained_state_dict.keys() | rotary_inv_freq_keys + for k in state_dict.keys() - rotary_inv_freq_keys: + assert state_dict[k].shape == pretrained_state_dict[k].shape + + +@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-j-6B"]) +def test_gptj_optimized(model_name): + """Check that our implementation of GPT-J (with all optimizations enabled) matches the + HF implementation: the output of our forward pass in fp16 should be around the same as the HF + forward pass in fp16, when compared to the HF forward pass in fp32. 
+ """ + dtype = torch.float16 + device = 'cuda' + config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name)) + config.use_flash_attn = False # FlashAttention doesn't support hdim 256 yet + config.fused_bias_fc = True + config.fused_mlp = True + config.fused_dropout_add_ln = False # We don't support parallel block yet + # Only prenorm supports residual_in_fp32 + config.residual_in_fp32 = True + + model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype) + model.eval() + + torch.manual_seed(0) + batch_size = 2 + max_seqlen = 256 + seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda') + input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, + device='cuda') + with torch.no_grad(): + out = model.transformer(input_ids) + logits = model(input_ids).logits + del model + + model_ref = GPTJForCausalLM.from_pretrained(model_name).to(device=device) + model_ref.eval() + with torch.no_grad(): + out_ref = model_ref.transformer(input_ids).last_hidden_state + logits_ref = model_ref(input_ids).logits + del model_ref + + model_hf = GPTJForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device) + model_hf.eval() + out_hf = model_hf.transformer(input_ids).last_hidden_state + logits_hf = model_hf(input_ids).logits + del model_hf + + print(f'Output max diff: {(out - out_ref).abs().max().item()}') + print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') + print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}') + print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}') + assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item() + + print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}') + print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}') + print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}') + print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}') + assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item() diff --git a/tests/models/test_opt.py b/tests/models/test_opt.py index 04ebfe5..82099c9 100644 --- a/tests/models/test_opt.py +++ b/tests/models/test_opt.py @@ -7,7 +7,7 @@ from transformers import OPTConfig from transformers.models.opt.modeling_opt import OPTForCausalLM from flash_attn.models.gpt import GPTLMHeadModel -from flash_attn.models.opt import remap_state_dict_opt, opt_config_to_gpt2_config +from flash_attn.models.opt import remap_state_dict_hf_opt, opt_config_to_gpt2_config from flash_attn.utils.pretrained import state_dict_from_pretrained @@ -15,7 +15,7 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained # @pytest.mark.parametrize('model_name', ["facebook/opt-350m"]) def test_opt_state_dict(model_name): config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name)) - pretrained_state_dict = remap_state_dict_opt(state_dict_from_pretrained(model_name), config) + pretrained_state_dict = remap_state_dict_hf_opt(state_dict_from_pretrained(model_name), config) model = GPTLMHeadModel(config) state_dict = model.state_dict() assert state_dict.keys() == pretrained_state_dict.keys()