diff --git a/flash_attn/losses/cross_entropy.py b/flash_attn/losses/cross_entropy.py
index 48e2f2f..bc2df84 100644
--- a/flash_attn/losses/cross_entropy.py
+++ b/flash_attn/losses/cross_entropy.py
@@ -106,7 +106,7 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
 class CrossEntropyLoss(nn.Module):
 
     def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
-                 inplace_backward=False):
+                 inplace_backward=False, process_group=None):
         super().__init__()
         if reduction not in ['mean', 'none']:
             raise NotImplementedError("Only support reduction = 'mean' or 'none'")
@@ -114,13 +114,14 @@ class CrossEntropyLoss(nn.Module):
         self.reduction = reduction
         self.label_smoothing = label_smoothing
         self.inplace_backward = inplace_backward
+        self.process_group = process_group
 
-    def forward(self, input, target, process_group=None):
+    def forward(self, input, target):
         assert input.is_cuda and target.is_cuda
         # SoftmaxCrossEntropyLoss implicitly casts to float
         loss = SoftmaxCrossEntropyLossFn.apply(
             input, target, self.label_smoothing, self.ignore_index, self.inplace_backward,
-            process_group
+            self.process_group
         )
         if self.reduction == 'mean':
             return loss.sum() / (target != self.ignore_index).sum()
diff --git a/flash_attn/models/bert.py b/flash_attn/models/bert.py
index b08769e..1c013d0 100644
--- a/flash_attn/models/bert.py
+++ b/flash_attn/models/bert.py
@@ -28,6 +28,7 @@ from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import BertEmbeddings
 from flash_attn.bert_padding import unpad_input, pad_input
 from flash_attn.bert_padding import index_first_axis, index_first_axis_residual
+from flash_attn.utils.pretrained import state_dict_from_pretrained
 
 try:
     from flash_attn.ops.fused_dense import FusedDense
@@ -439,12 +440,6 @@ class BertForPreTraining(BertPreTrainedModel):
         )
 
 
-def state_dict_from_pretrained(model_name):
-    from transformers.utils import WEIGHTS_NAME
-    from transformers.utils.hub import cached_file
-    return torch.load(cached_file(model_name, WEIGHTS_NAME))
-
-
 def remap_state_dict(state_dict, config):
     # LayerNorm
     def key_mapping_ln_gamma_beta(key):
diff --git a/flash_attn/utils/distributed.py b/flash_attn/utils/distributed.py
index 6722921..16c8a28 100644
--- a/flash_attn/utils/distributed.py
+++ b/flash_attn/utils/distributed.py
@@ -87,11 +87,15 @@ def sync_sequence_parallel_params(model: torch.nn.Module, process_group: Process
     )
 
 
+# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
 def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
     # We want to iterate over parameters with _sequence_parallel=True in the same order,
     # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
     params_seqparallel = {name: p for name, p in model.named_parameters()
                           if getattr(p, '_sequence_parallel', False)}
-    for _, p in sorted(params_seqparallel.items()):
-        with torch.no_grad():
-            torch.distributed.all_reduce(p.grad, group=process_group)
+    grads = [p.grad for _, p in sorted(params_seqparallel.items())]
+    with torch.no_grad():
+        coalesced = torch._utils._flatten_dense_tensors(grads)
+        torch.distributed.all_reduce(coalesced, group=process_group)
+        for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
+            buf.copy_(synced)
diff --git a/flash_attn/utils/pretrained.py b/flash_attn/utils/pretrained.py
new file mode 100644
index 0000000..2547f0b
--- /dev/null
+++ b/flash_attn/utils/pretrained.py
@@ -0,0 +1,8 @@
+import torch
+
+from transformers.utils import WEIGHTS_NAME
+from transformers.utils.hub import cached_file
+
+
+def state_dict_from_pretrained(model_name):
+    return torch.load(cached_file(model_name, WEIGHTS_NAME))
diff --git a/tests/losses/test_cross_entropy_parallel.py b/tests/losses/test_cross_entropy_parallel.py
index ac49bb3..bb2043b 100644
--- a/tests/losses/test_cross_entropy_parallel.py
+++ b/tests/losses/test_cross_entropy_parallel.py
@@ -24,7 +24,7 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 @pytest.mark.parametrize('vocab_size', [50264])
 @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
 # @pytest.mark.parametrize('world_size', [2])
-def test_cross_entropy_loss_apex(vocab_size, world_size, smoothing, inplace_backward, dtype):
+def test_cross_entropy_loss_parallel(vocab_size, world_size, smoothing, inplace_backward, dtype):
     assert vocab_size % world_size == 0
     rtol, atol = ((1e-5, 1e-6) if dtype == torch.float32
                   else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3)))
diff --git a/tests/models/test_bert.py b/tests/models/test_bert.py
index 499cd50..61525be 100644
--- a/tests/models/test_bert.py
+++ b/tests/models/test_bert.py
@@ -12,8 +12,8 @@ from transformers.models.bert.modeling_bert import BertModel as BertModelHF
 from transformers.models.bert.modeling_bert import BertForPreTraining as BertForPreTrainingHF
 
 from flash_attn.models.bert import BertModel, BertForPreTraining
-from flash_attn.models.bert import state_dict_from_pretrained
 from flash_attn.models.bert import remap_state_dict
+from flash_attn.utils.pretrained import state_dict_from_pretrained
 
 
 @pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
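
For reviewers, a minimal usage sketch of the changed CrossEntropyLoss API: process_group now moves from forward() to the constructor. The tensor shapes and hyperparameters below are illustrative, not taken from this diff; with process_group=None this is the ordinary single-GPU path, and in a tensor-parallel run the TP group would be passed at construction time instead.

# Minimal sketch (assumes a CUDA device); shapes and values are illustrative.
import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss

logits = torch.randn(8, 50264, device='cuda', requires_grad=True)
labels = torch.randint(0, 50264, (8,), device='cuda')

# Before this change: process_group was passed per call, loss_fn(logits, labels, process_group=...)
# After this change: it is fixed when the module is built.
loss_fn = CrossEntropyLoss(label_smoothing=0.1, inplace_backward=True, process_group=None)
loss = loss_fn(logits, labels)  # reduction='mean' by default
loss.backward()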
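
And a sketch of the relocated pretrained-weights helper: state_dict_from_pretrained now lives in flash_attn.utils.pretrained rather than flash_attn.models.bert. The BertConfig handling below is an assumption for illustration, not part of this diff; only the two flash_attn calls are confirmed by the patch.

# Sketch of loading HF BERT weights via the relocated helper (assumes HF hub access).
from transformers import BertConfig

from flash_attn.models.bert import remap_state_dict
from flash_attn.utils.pretrained import state_dict_from_pretrained

model_name = "bert-base-uncased"
config = BertConfig.from_pretrained(model_name)  # illustrative config source

state_dict = state_dict_from_pretrained(model_name)  # torch.load of the cached WEIGHTS_NAME file
state_dict = remap_state_dict(state_dict, config)    # remap HF parameter names to flash_attn's layout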