diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py
index dcd3d62..52b546b 100644
--- a/flash_attn/models/gpt.py
+++ b/flash_attn/models/gpt.py
@@ -27,7 +27,7 @@ from flash_attn.modules.mlp import (
     ParallelMLP,
 )
 from flash_attn.ops.activations import sqrelu_fwd
-from flash_attn.utils.distributed import all_gather_raw, sync_shared_params
+from flash_attn.utils.distributed import all_gather_raw, sync_shared_params, get_dim_for_local_rank
 from flash_attn.utils.generation import GenerationMixin
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from transformers import GPT2Config
@@ -62,7 +62,6 @@ try:
 except ImportError:
     FusedDenseSqreluDense = None
 
-
 logger = logging.getLogger(__name__)
@@ -681,41 +680,58 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
     inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
     assert inner_dim % world_size == 0
 
+    n_head = config.n_head
+    n_head_kv = getattr(config, "n_head_kv", n_head)
+
+    embed_dim = config.hidden_size
+    head_dim = embed_dim // n_head
+
     def shard_first_dim(state_dict, key):
         if key in state_dict:
             x = state_dict[key]
             dim = x.shape[0] // world_size
-            state_dict[key] = x[rank * dim : (rank + 1) * dim]
+            state_dict[key] = x[rank * dim: (rank + 1) * dim]
 
-    def shard_last_dim(state_dict, key):
+    def shard_last_dim(state_dict, key, multiple_of=1):
         if key in state_dict:
             x = state_dict[key]
-            dim = x.shape[-1] // world_size
-            state_dict[key] = x[..., rank * dim : (rank + 1) * dim]
+            dim_each_rank = [
+                get_dim_for_local_rank(x.size(-1), world_size, local_rank, multiple_of)
+                for local_rank in range(world_size)
+            ]
+            beg, end = tuple(sum(dim_each_rank[:pos]) for pos in (rank, rank + 1))
+            state_dict[key] = x[..., beg:end]
 
     def shard_gatedmlp_fc1_dim(state_dict, key):
         if key in state_dict:
             x = state_dict[key]
             dim = x.shape[0] // world_size // 2
             state_dict[key] = rearrange(
-                rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim : (rank + 1) * dim],
+                rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim: (rank + 1) * dim],
                 "two o ... -> (two o) ...",
             )
 
     def shard_qkv_headdim(state_dict, key):
         if key in state_dict:
-            n_head = config.n_head
-            n_head_kv = getattr(config, "n_head_kv", n_head)
-            assert n_head % world_size == 0 and n_head_kv % world_size == 0
+            n_head_each_rank = [
+                get_dim_for_local_rank(n_head, world_size, local_rank) for local_rank in range(world_size)
+            ]
+            n_head_kv_each_rank = [
+                get_dim_for_local_rank(n_head_kv, world_size, local_rank) for local_rank in range(world_size)
+            ]
+
+            beg_n_head = sum(n_head_each_rank[:rank])
+            end_n_head = sum(n_head_each_rank[: rank + 1])
+
+            beg_n_head_kv = sum(n_head_kv_each_rank[:rank])
+            end_n_head_kv = sum(n_head_kv_each_rank[: rank + 1])
+
             if n_head_kv == n_head:
                 x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3)
-                dim = x.shape[1] // world_size
                 state_dict[key] = rearrange(
-                    x[:, rank * dim : (rank + 1) * dim], "three d ... -> (three d) ..."
+                    x[:, beg_n_head * head_dim : end_n_head * head_dim], "three d ... -> (three d) ..."
                 )
             else:
-                n_head_per_rank = n_head // world_size
-                n_head_kv_per_rank = n_head_kv // world_size
                 x = rearrange(
                     state_dict[key],
                     "(nheadqkv headdim) ... -> nheadqkv headdim ...",
@@ -724,19 +740,9 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
                 state_dict[key] = rearrange(
                     torch.cat(
                         [
-                            x[rank * n_head_per_rank : (rank + 1) * n_head_per_rank],
-                            x[
-                                n_head
-                                + rank * n_head_kv_per_rank : n_head
-                                + (rank + 1) * n_head_kv_per_rank
-                            ],
-                            x[
-                                n_head
-                                + n_head_kv
-                                + rank * n_head_kv_per_rank : n_head
-                                + n_head_kv
-                                + (rank + 1) * n_head_kv_per_rank
-                            ],
+                            x[beg_n_head:end_n_head],
+                            x[n_head + beg_n_head_kv: n_head + end_n_head_kv],
+                            x[n_head + n_head_kv + beg_n_head_kv: n_head + n_head_kv + end_n_head_kv],
                         ],
                         dim=0,
                     ),
@@ -751,7 +757,9 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
     for i in range(config.num_hidden_layers):
         shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
         shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
-        shard_last_dim(state_dict, f"transformer.layers.{i}.mixer.out_proj.weight")
+        shard_last_dim(
+            state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", multiple_of=head_dim
+        )
         if rank != 0:
             state_dict.pop(f"transformer.layers.{i}.mixer.out_proj.bias", None)
         if config.activation_function in ["glu", "swiglu", "geglu"]:
@@ -816,7 +824,7 @@ def combine_state_dicts_tp(state_dicts, config):
                         torch.cat([x[:n_head_per_rank] for x in xs], dim=0),
                         torch.cat(
                             [
-                                x[n_head_per_rank : n_head_per_rank + n_head_kv_per_rank]
+                                x[n_head_per_rank: n_head_per_rank + n_head_kv_per_rank]
                                 for x in xs
                             ],
                             dim=0,
@@ -922,6 +930,7 @@ def remap_state_dict_megatron(state_dict, config):
         return key
 
     state_dict = OrderedDict((key_mapping_transformer(k), v) for k, v in state_dict.items())
+
     # Word embedding and position embedding
     def key_mapping_pos_emb(key):
         return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key)
diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index 95bb752..592c4d8 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -5,9 +5,10 @@ from functools import partial
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from einops import rearrange, repeat
 
+from flash_attn.utils.distributed import get_dim_for_local_rank
+
 try:
     from flash_attn import (
         flash_attn_kvpacked_func,
@@ -720,22 +721,21 @@ class ParallelMHA(nn.Module):
         self.use_flash_attn = use_flash_attn
         self.checkpointing = checkpointing
         self.process_group = process_group
-        self.world_size = process_group.size() if process_group is not None else 1
+        self.world_size = process_group.size()
+        self.local_rank = torch.distributed.get_rank(process_group)
         self.num_heads = num_heads
+        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
+
         self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
-        self.num_heads_per_rank = num_heads // self.world_size
-        self.num_heads_kv_per_rank = self.num_heads_kv // self.world_size
         assert (
             self.num_heads % self.num_heads_kv == 0
         ), "num_heads must be divisible by num_heads_kv"
-        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
-        assert (
-            self.num_heads_kv % self.world_size == 0
-        ), "num_heads_kv must be divisible by world_size"
+
+        self.num_heads_per_rank = get_dim_for_local_rank(self.num_heads, self.world_size, self.local_rank)
+        self.num_heads_kv_per_rank = get_dim_for_local_rank(self.num_heads_kv, self.world_size, self.local_rank)
         self.head_dim = self.embed_dim // num_heads
         qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
-        kv_dim = 2 * self.head_dim * self.num_heads_kv
 
         if self.rotary_emb_dim > 0:
             assert RotaryEmbedding is not None, "rotary_emb is not installed"
@@ -755,6 +755,7 @@ class ParallelMHA(nn.Module):
             process_group,
             bias=qkv_proj_bias,
             sequence_parallel=sequence_parallel,
+            multiple_of=self.head_dim * 3,
             **factory_kwargs,
         )
         inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
@@ -771,6 +772,7 @@ class ParallelMHA(nn.Module):
             process_group,
             bias=out_proj_bias,
             sequence_parallel=sequence_parallel,
+            multiple_of=self.head_dim,
             **factory_kwargs,
         )
diff --git a/flash_attn/ops/fused_dense.py b/flash_attn/ops/fused_dense.py
index 3353767..fa3ab64 100644
--- a/flash_attn/ops/fused_dense.py
+++ b/flash_attn/ops/fused_dense.py
@@ -226,7 +226,7 @@ class RowParallelLinear(nn.Linear):
         local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
         # Only rank 0 will have bias
         super().__init__(
-            in_features // world_size,
+            local_multiple * multiple_of,
             out_features,
             bias=bias and rank == 0,
             device=device,
diff --git a/flash_attn/utils/distributed.py b/flash_attn/utils/distributed.py
index 09578e3..1ed944f 100644
--- a/flash_attn/utils/distributed.py
+++ b/flash_attn/utils/distributed.py
@@ -125,3 +125,15 @@ def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
     torch.distributed.all_reduce(coalesced, group=process_group)
     for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
         buf.copy_(synced)
+
+
+def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
+    """Get the dim for the local rank derived from splitting dim on world_size processes.
+
+    The split may not be even across the world_size processes.
+    """
+    multiple = dim // multiple_of
+    div = multiple // world_size
+    mod = multiple % world_size
+    local_multiple = div + int(local_rank < mod)
+    return local_multiple * multiple_of
diff --git a/tests/models/test_llama.py b/tests/models/test_llama.py
index 6a45bc9..09b00d1 100644
--- a/tests/models/test_llama.py
+++ b/tests/models/test_llama.py
@@ -16,7 +16,7 @@ import pytest
 
 from einops import rearrange
-from transformers import LlamaTokenizer
+from transformers import LlamaTokenizer, LlamaConfig
 from transformers.models.llama.modeling_llama import LlamaForCausalLM
 
 from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
@@ -255,7 +255,6 @@ def test_llama_generation(model_name, checkpoint_format):
     logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1].to(device=device)
     del model_ref
 
-
     pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
         checkpoint_path, model_name, config, checkpoint_format
     )
@@ -297,8 +296,8 @@ def test_llama_generation(model_name, checkpoint_format):
 
     hf_error = (logits_hf - logits_ref).abs().max().item()
     print(f'HF fp16 logits max diff: {hf_error}')
-    print(f'Logits max diff: {(logits - logits_ref).abs().max().item() }')
-    print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }')
+    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
+    print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}')
     assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
     assert (logits - logits_ref).abs().max().item() < 2 * hf_error
@@ -410,7 +409,101 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
 
     hf_error = (logits_hf - logits_ref).abs().max().item()
     print(f'HF fp16 logits max diff: {hf_error}')
-    print(f'Logits max diff: {(logits - logits_ref).abs().max().item() }')
+    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
     assert (logits - logits_ref).abs().max().item() < 2 * hf_error
-    print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }')
+    print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}')
     assert torch.equal(logits_cg, logits)
+
+
+@torch.no_grad()
+@pytest.mark.parametrize('world_size', [2])
+def test_llama_parallel_uneven_num_heads(world_size):
+    from apex.transformer import parallel_state
+
+    checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR', current_dir.parent.parent / 'checkpoints')) / 'llama'
+    num_attention_heads = world_size + 1
+    model_name = f'teeny-{num_attention_heads}-heads'
+
+    if not torch.distributed.is_initialized():
+        torch.distributed.init_process_group(backend='nccl', init_method='env://')
+    device = f'cuda:{torch.distributed.get_rank()}'
+    assert world_size <= torch.distributed.get_world_size()
+    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
+    rank = parallel_state.get_tensor_model_parallel_rank()
+    process_group = parallel_state.get_tensor_model_parallel_group()
+
+    dtype = torch.float16
+    llama_config = LlamaConfig(
+        hidden_size=256 * num_attention_heads,  # ParallelGatedMlp hidden_features must be divisible by 256
+        intermediate_size=256 * num_attention_heads * 4,
+        num_hidden_layers=4,
+        num_attention_heads=num_attention_heads,
+        initializer_range=0.5,  # Set crazy init range so we don't have near zero weights implying a vacuous test.
+    )
+    config = llama_config_to_gpt2_config(llama_config)
+    config.use_flash_attn = True
+    config.fused_bias_fc = True
+    config.fused_mlp = False  # We don't have fused GatedMLP yet
+    config.fused_dropout_add_ln = True
+    config.residual_in_fp32 = True
+
+    torch.manual_seed(0)
+    batch_size = 2
+    max_seqlen = 256
+    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
+    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
+                              device=device)
+
+    # Create a shared test model.
+    if rank == 0:
+        LlamaForCausalLM(config=llama_config).save_pretrained(checkpoint_path / f"{model_name}-hf")
+    torch.distributed.barrier()
+
+    # Run the standard forward pass test.
+    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
+        checkpoint_path, model_name, config, checkpoint_format="hf"
+    )
+    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
+    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
+    model.eval()
+
+    # TODO: Avoid duplicate code. Modularize the comparison of two forward pass diffs.
+    out = model.transformer(input_ids)
+    out, _ = all_gather_raw(out, process_group=process_group)
+    out = rearrange(out, "(b s) d -> b s d", b=batch_size)
+    logits = model(input_ids).logits
+    logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
+    logits, _ = all_gather_raw(logits, process_group)
+    logits = rearrange(logits, '(n b) ... d -> b ... (n d)', b=batch_size)
+
+    if rank == 0:
+        model_ref = LlamaForCausalLM.from_pretrained(
+            Path(checkpoint_path) / f'{model_name}-hf', device_map="auto"
+        )
+        model_ref.eval()
+        out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
+        logits_ref = model_ref(input_ids).logits.to(device=device)
+        del model_ref
+
+        model_hf = LlamaForCausalLM.from_pretrained(
+            Path(checkpoint_path) / f'{model_name}-hf', torch_dtype=dtype, device_map="auto"
+        )
+        model_hf.eval()
+        out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
+        logits_hf = model_hf(input_ids).logits.to(device=device)
+        del model_hf
+
+        print(f'Output max diff: {(out - out_ref).abs().max().item()}')
+        print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
+        print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
+        print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
+        assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
+
+        print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
+        print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
+        print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
+        print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
+        assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()
+
+        import shutil
+        shutil.rmtree(checkpoint_path / f'{model_name}-hf')
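
Reviewer note (not part of the patch): a minimal, self-contained sketch of the splitting rule that `get_dim_for_local_rank` implements and that `shard_state_dict_tp` relies on when turning per-rank head counts into slice boundaries. The helper body is copied from the hunk above; the concrete sizes (3 heads, head_dim 64, world_size 2) are illustrative assumptions matching the `world_size + 1` heads case exercised by the new test.

```python
# Sketch only: the uneven-split rule, checked with illustrative numbers.

def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
    # Same logic as the helper added in flash_attn/utils/distributed.py above.
    multiple = dim // multiple_of
    div = multiple // world_size
    mod = multiple % world_size
    # The first `mod` ranks receive one extra chunk of size `multiple_of`.
    local_multiple = div + int(local_rank < mod)
    return local_multiple * multiple_of


n_head, head_dim, world_size = 3, 64, 2

# Per-rank head counts: rank 0 gets 2 heads, rank 1 gets 1 head, summing to 3.
heads_per_rank = [get_dim_for_local_rank(n_head, world_size, r) for r in range(world_size)]
assert heads_per_rank == [2, 1]

# shard_qkv_headdim turns these counts into slice boundaries via prefix sums:
# rank r owns heads [sum(heads_per_rank[:r]), sum(heads_per_rank[:r + 1])).
bounds = [(sum(heads_per_rank[:r]), sum(heads_per_rank[:r + 1])) for r in range(world_size)]
assert bounds == [(0, 2), (2, 3)]

# shard_last_dim(..., multiple_of=head_dim) splits out_proj's input dim the same
# way, so each rank's column count stays a whole number of heads.
in_features = n_head * head_dim
cols_per_rank = [
    get_dim_for_local_rank(in_features, world_size, r, multiple_of=head_dim)
    for r in range(world_size)
]
assert cols_per_rank == [2 * head_dim, 1 * head_dim]
```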
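A companion sketch for the `multiple_of` arguments added in `ParallelMHA`: with `num_heads == num_heads_kv`, each attention head accounts for `3 * head_dim` of the packed Wqkv output features (one `head_dim` block each for q, k, and v), so `ColumnParallelLinear` is asked to split the output dimension in multiples of `head_dim * 3`, and `RowParallelLinear` splits `out_proj`'s input dimension in multiples of `head_dim`. The arithmetic below is a consistency check, not part of the patch, and assumes a flash-attn install with this patch applied plus the same illustrative sizes as the previous sketch.

```python
# Consistency check (illustrative: num_heads == num_heads_kv == 3, head_dim = 64, world_size = 2).
from flash_attn.utils.distributed import get_dim_for_local_rank  # added by this patch

n_head = n_head_kv = 3
head_dim, world_size = 64, 2

# qkv_dim as computed in ParallelMHA.__init__ for the packed Wqkv projection.
qkv_dim = head_dim * (n_head + 2 * n_head_kv)  # 576

# ColumnParallelLinear(..., multiple_of=head_dim * 3) sizes each rank's local
# output features in whole-head chunks of q + k + v.
wqkv_rows_per_rank = [
    get_dim_for_local_rank(qkv_dim, world_size, r, multiple_of=head_dim * 3)
    for r in range(world_size)
]
assert wqkv_rows_per_rank == [384, 192]  # 2 heads and 1 head worth of q, k, v

# This matches the shard produced by shard_qkv_headdim: heads_per_rank * 3 * head_dim.
heads_per_rank = [get_dim_for_local_rank(n_head, world_size, r) for r in range(world_size)]
assert wqkv_rows_per_rank == [h * 3 * head_dim for h in heads_per_rank]

# RowParallelLinear(..., multiple_of=head_dim) does the same for out_proj's input dim.
out_proj_cols_per_rank = [
    get_dim_for_local_rank(n_head * head_dim, world_size, r, multiple_of=head_dim)
    for r in range(world_size)
]
assert sum(out_proj_cols_per_rank) == n_head * head_dim
```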