* uneql rank. * trim. * enable passing in number of heads for each rank. * simplify. * simplify. * cleanup. * fix col parallel. * fix bug with row parallel. * fit out proj. * refac. * fix sharding logic. * refac sharding. * refac. * support multiple of. * make fn reuseable. * fix bug in dimensions. * scaffold. * test uneven heads. * fix test by adding barrier. * refac. * reuse code. * clean up.
This commit is contained in:
parent
ada4710d70
commit
bb4cded17b
@ -27,7 +27,7 @@ from flash_attn.modules.mlp import (
|
||||
ParallelMLP,
|
||||
)
|
||||
from flash_attn.ops.activations import sqrelu_fwd
|
||||
from flash_attn.utils.distributed import all_gather_raw, sync_shared_params
|
||||
from flash_attn.utils.distributed import all_gather_raw, sync_shared_params, get_dim_for_local_rank
|
||||
from flash_attn.utils.generation import GenerationMixin
|
||||
from flash_attn.utils.pretrained import state_dict_from_pretrained
|
||||
from transformers import GPT2Config
|
||||
@ -62,7 +62,6 @@ try:
|
||||
except ImportError:
|
||||
FusedDenseSqreluDense = None
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -681,41 +680,58 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
|
||||
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
|
||||
assert inner_dim % world_size == 0
|
||||
|
||||
n_head = config.n_head
|
||||
n_head_kv = getattr(config, "n_head_kv", n_head)
|
||||
|
||||
embed_dim = config.hidden_size
|
||||
head_dim = embed_dim // n_head
|
||||
|
||||
def shard_first_dim(state_dict, key):
|
||||
if key in state_dict:
|
||||
x = state_dict[key]
|
||||
dim = x.shape[0] // world_size
|
||||
state_dict[key] = x[rank * dim : (rank + 1) * dim]
|
||||
state_dict[key] = x[rank * dim: (rank + 1) * dim]
|
||||
|
||||
def shard_last_dim(state_dict, key):
|
||||
def shard_last_dim(state_dict, key, multiple_of=1):
|
||||
if key in state_dict:
|
||||
x = state_dict[key]
|
||||
dim = x.shape[-1] // world_size
|
||||
state_dict[key] = x[..., rank * dim : (rank + 1) * dim]
|
||||
dim_each_rank = [
|
||||
get_dim_for_local_rank(x.size(-1), world_size, local_rank, multiple_of)
|
||||
for local_rank in range(world_size)
|
||||
]
|
||||
beg, end = tuple(sum(dim_each_rank[:pos]) for pos in (rank, rank + 1))
|
||||
state_dict[key] = x[..., beg:end]
|
||||
|
||||
def shard_gatedmlp_fc1_dim(state_dict, key):
|
||||
if key in state_dict:
|
||||
x = state_dict[key]
|
||||
dim = x.shape[0] // world_size // 2
|
||||
state_dict[key] = rearrange(
|
||||
rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim : (rank + 1) * dim],
|
||||
rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim: (rank + 1) * dim],
|
||||
"two o ... -> (two o) ...",
|
||||
)
|
||||
|
||||
def shard_qkv_headdim(state_dict, key):
|
||||
if key in state_dict:
|
||||
n_head = config.n_head
|
||||
n_head_kv = getattr(config, "n_head_kv", n_head)
|
||||
assert n_head % world_size == 0 and n_head_kv % world_size == 0
|
||||
n_head_each_rank = [
|
||||
get_dim_for_local_rank(n_head, world_size, local_rank) for local_rank in range(world_size)
|
||||
]
|
||||
n_head_kv_each_rank = [
|
||||
get_dim_for_local_rank(n_head_kv, world_size, local_rank) for local_rank in range(world_size)
|
||||
]
|
||||
|
||||
beg_n_head = sum(n_head_each_rank[:rank])
|
||||
end_n_head = sum(n_head_each_rank[: rank + 1])
|
||||
|
||||
beg_n_head_kv = sum(n_head_kv_each_rank[:rank])
|
||||
end_n_head_kv = sum(n_head_kv_each_rank[: rank + 1])
|
||||
|
||||
if n_head_kv == n_head:
|
||||
x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3)
|
||||
dim = x.shape[1] // world_size
|
||||
state_dict[key] = rearrange(
|
||||
x[:, rank * dim : (rank + 1) * dim], "three d ... -> (three d) ..."
|
||||
x[:, beg_n_head * head_dim : end_n_head * head_dim], "three d ... -> (three d) ..."
|
||||
)
|
||||
else:
|
||||
n_head_per_rank = n_head // world_size
|
||||
n_head_kv_per_rank = n_head_kv // world_size
|
||||
x = rearrange(
|
||||
state_dict[key],
|
||||
"(nheadqkv headdim) ... -> nheadqkv headdim ...",
|
||||
@ -724,19 +740,9 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
|
||||
state_dict[key] = rearrange(
|
||||
torch.cat(
|
||||
[
|
||||
x[rank * n_head_per_rank : (rank + 1) * n_head_per_rank],
|
||||
x[
|
||||
n_head
|
||||
+ rank * n_head_kv_per_rank : n_head
|
||||
+ (rank + 1) * n_head_kv_per_rank
|
||||
],
|
||||
x[
|
||||
n_head
|
||||
+ n_head_kv
|
||||
+ rank * n_head_kv_per_rank : n_head
|
||||
+ n_head_kv
|
||||
+ (rank + 1) * n_head_kv_per_rank
|
||||
],
|
||||
x[beg_n_head:end_n_head],
|
||||
x[n_head + beg_n_head_kv: n_head + end_n_head_kv],
|
||||
x[n_head + n_head_kv + beg_n_head_kv: n_head + n_head_kv + end_n_head_kv],
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
@ -751,7 +757,9 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
|
||||
for i in range(config.num_hidden_layers):
|
||||
shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
|
||||
shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
|
||||
shard_last_dim(state_dict, f"transformer.layers.{i}.mixer.out_proj.weight")
|
||||
shard_last_dim(
|
||||
state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", multiple_of=head_dim
|
||||
)
|
||||
if rank != 0:
|
||||
state_dict.pop(f"transformer.layers.{i}.mixer.out_proj.bias", None)
|
||||
if config.activation_function in ["glu", "swiglu", "geglu"]:
|
||||
@ -816,7 +824,7 @@ def combine_state_dicts_tp(state_dicts, config):
|
||||
torch.cat([x[:n_head_per_rank] for x in xs], dim=0),
|
||||
torch.cat(
|
||||
[
|
||||
x[n_head_per_rank : n_head_per_rank + n_head_kv_per_rank]
|
||||
x[n_head_per_rank: n_head_per_rank + n_head_kv_per_rank]
|
||||
for x in xs
|
||||
],
|
||||
dim=0,
|
||||
@ -922,6 +930,7 @@ def remap_state_dict_megatron(state_dict, config):
|
||||
return key
|
||||
|
||||
state_dict = OrderedDict((key_mapping_transformer(k), v) for k, v in state_dict.items())
|
||||
|
||||
# Word embedding and position embedding
|
||||
def key_mapping_pos_emb(key):
|
||||
return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key)
|
||||
|
||||
@ -5,9 +5,10 @@ from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from flash_attn.utils.distributed import get_dim_for_local_rank
|
||||
|
||||
try:
|
||||
from flash_attn import (
|
||||
flash_attn_kvpacked_func,
|
||||
@ -720,22 +721,21 @@ class ParallelMHA(nn.Module):
|
||||
self.use_flash_attn = use_flash_attn
|
||||
self.checkpointing = checkpointing
|
||||
self.process_group = process_group
|
||||
self.world_size = process_group.size() if process_group is not None else 1
|
||||
self.world_size = process_group.size()
|
||||
self.local_rank = torch.distributed.get_rank(process_group)
|
||||
|
||||
self.num_heads = num_heads
|
||||
assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
|
||||
|
||||
self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
|
||||
self.num_heads_per_rank = num_heads // self.world_size
|
||||
self.num_heads_kv_per_rank = self.num_heads_kv // self.world_size
|
||||
assert (
|
||||
self.num_heads % self.num_heads_kv == 0
|
||||
), "num_heads must be divisible by num_heads_kv"
|
||||
assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
|
||||
assert (
|
||||
self.num_heads_kv % self.world_size == 0
|
||||
), "num_heads_kv must be divisible by world_size"
|
||||
|
||||
self.num_heads_per_rank = get_dim_for_local_rank(self.num_heads, self.world_size, self.local_rank)
|
||||
self.num_heads_kv_per_rank = get_dim_for_local_rank(self.num_heads, self.world_size, self.local_rank)
|
||||
self.head_dim = self.embed_dim // num_heads
|
||||
qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
|
||||
kv_dim = 2 * self.head_dim * self.num_heads_kv
|
||||
|
||||
if self.rotary_emb_dim > 0:
|
||||
assert RotaryEmbedding is not None, "rotary_emb is not installed"
|
||||
@ -755,6 +755,7 @@ class ParallelMHA(nn.Module):
|
||||
process_group,
|
||||
bias=qkv_proj_bias,
|
||||
sequence_parallel=sequence_parallel,
|
||||
multiple_of=self.head_dim * 3,
|
||||
**factory_kwargs,
|
||||
)
|
||||
inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
|
||||
@ -771,6 +772,7 @@ class ParallelMHA(nn.Module):
|
||||
process_group,
|
||||
bias=out_proj_bias,
|
||||
sequence_parallel=sequence_parallel,
|
||||
multiple_of=self.head_dim,
|
||||
**factory_kwargs,
|
||||
)
|
||||
|
||||
|
||||
@ -226,7 +226,7 @@ class RowParallelLinear(nn.Linear):
|
||||
local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
|
||||
# Only rank 0 will have bias
|
||||
super().__init__(
|
||||
in_features // world_size,
|
||||
local_multiple * multiple_of,
|
||||
out_features,
|
||||
bias=bias and rank == 0,
|
||||
device=device,
|
||||
|
||||
@ -125,3 +125,15 @@ def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: Proc
|
||||
torch.distributed.all_reduce(coalesced, group=process_group)
|
||||
for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
|
||||
buf.copy_(synced)
|
||||
|
||||
|
||||
def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
|
||||
"""Get the dim for the local rank derived from splitting dim on world_size processes.
|
||||
|
||||
The split may not be even across the world_size processes.
|
||||
"""
|
||||
multiple = dim // multiple_of
|
||||
div = multiple // world_size
|
||||
mod = multiple % world_size
|
||||
local_multiple = div + int(local_rank < mod)
|
||||
return local_multiple * multiple_of
|
||||
|
||||
@ -16,7 +16,7 @@ import pytest
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from transformers import LlamaTokenizer
|
||||
from transformers import LlamaTokenizer, LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
||||
|
||||
from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
|
||||
@ -255,7 +255,6 @@ def test_llama_generation(model_name, checkpoint_format):
|
||||
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1].to(device=device)
|
||||
del model_ref
|
||||
|
||||
|
||||
pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
|
||||
checkpoint_path, model_name, config, checkpoint_format
|
||||
)
|
||||
@ -297,8 +296,8 @@ def test_llama_generation(model_name, checkpoint_format):
|
||||
hf_error = (logits_hf - logits_ref).abs().max().item()
|
||||
|
||||
print(f'HF fp16 logits max diff: {hf_error}')
|
||||
print(f'Logits max diff: {(logits - logits_ref).abs().max().item() }')
|
||||
print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }')
|
||||
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
|
||||
print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}')
|
||||
|
||||
assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
|
||||
assert (logits - logits_ref).abs().max().item() < 2 * hf_error
|
||||
@ -410,7 +409,101 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
|
||||
|
||||
hf_error = (logits_hf - logits_ref).abs().max().item()
|
||||
print(f'HF fp16 logits max diff: {hf_error}')
|
||||
print(f'Logits max diff: {(logits - logits_ref).abs().max().item() }')
|
||||
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
|
||||
assert (logits - logits_ref).abs().max().item() < 2 * hf_error
|
||||
print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }')
|
||||
print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}')
|
||||
assert torch.equal(logits_cg, logits)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@pytest.mark.parametrize('world_size', [2])
|
||||
def test_llama_parallel_uneven_num_heads(world_size):
|
||||
from apex.transformer import parallel_state
|
||||
|
||||
checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR', current_dir.parent.parent / 'checkpoints')) / 'llama'
|
||||
num_attention_heads = world_size + 1
|
||||
model_name = f'teeny-{num_attention_heads}-heads'
|
||||
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group(backend='nccl', init_method='env://')
|
||||
device = f'cuda:{torch.distributed.get_rank()}'
|
||||
assert world_size <= torch.distributed.get_world_size()
|
||||
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
|
||||
rank = parallel_state.get_tensor_model_parallel_rank()
|
||||
process_group = parallel_state.get_tensor_model_parallel_group()
|
||||
|
||||
dtype = torch.float16
|
||||
llama_config = LlamaConfig(
|
||||
hidden_size=256 * num_attention_heads, # ParallelGatedMlp hidden_features must be divisible by 256
|
||||
intermediate_size=256 * num_attention_heads * 4,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=num_attention_heads,
|
||||
initializer_range=0.5, # Set crazy init range so we don't have near zero weights implying a vacuous test.
|
||||
)
|
||||
config = llama_config_to_gpt2_config(llama_config)
|
||||
config.use_flash_attn = True
|
||||
config.fused_bias_fc = True
|
||||
config.fused_mlp = False # We don't have fused GatedMLP yet
|
||||
config.fused_dropout_add_ln = True
|
||||
config.residual_in_fp32 = True
|
||||
|
||||
torch.manual_seed(0)
|
||||
batch_size = 2
|
||||
max_seqlen = 256
|
||||
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
|
||||
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
# Create a shared test model.
|
||||
if rank == 0:
|
||||
LlamaForCausalLM(config=llama_config).save_pretrained(checkpoint_path / f"{model_name}-hf")
|
||||
torch.distributed.barrier()
|
||||
|
||||
# Run the standard forward pass test.
|
||||
pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
|
||||
checkpoint_path, model_name, config, checkpoint_format="hf"
|
||||
)
|
||||
model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
|
||||
model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
|
||||
model.eval()
|
||||
|
||||
# TODO: Avoid duplicate code. Modularize the comparison of two forward pass diffs.
|
||||
out = model.transformer(input_ids)
|
||||
out, _ = all_gather_raw(out, process_group=process_group)
|
||||
out = rearrange(out, "(b s) d -> b s d", b=batch_size)
|
||||
logits = model(input_ids).logits
|
||||
logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
|
||||
logits, _ = all_gather_raw(logits, process_group)
|
||||
logits = rearrange(logits, '(n b) ... d -> b ... (n d)', b=batch_size)
|
||||
|
||||
if rank == 0:
|
||||
model_ref = LlamaForCausalLM.from_pretrained(
|
||||
Path(checkpoint_path) / f'{model_name}-hf', device_map="auto"
|
||||
)
|
||||
model_ref.eval()
|
||||
out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
|
||||
logits_ref = model_ref(input_ids).logits.to(device=device)
|
||||
del model_ref
|
||||
|
||||
model_hf = LlamaForCausalLM.from_pretrained(
|
||||
Path(checkpoint_path) / f'{model_name}-hf', torch_dtype=dtype, device_map="auto"
|
||||
)
|
||||
model_hf.eval()
|
||||
out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
|
||||
logits_hf = model_hf(input_ids).logits.to(device=device)
|
||||
del model_hf
|
||||
|
||||
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
|
||||
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
|
||||
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
|
||||
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
|
||||
assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
|
||||
|
||||
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
|
||||
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
|
||||
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
|
||||
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
|
||||
assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()
|
||||
|
||||
import shutil
|
||||
shutil.rmtree(checkpoint_path / f'{model_name}-hf')
|
||||
|
||||
Loading…
Reference in New Issue
Block a user