Tweak CrossEntropyLoss to take process_group in init

Tri Dao 2022-12-27 09:49:59 -08:00
parent b4018a5028
commit c6ecd40a59
6 changed files with 22 additions and 14 deletions

View File

@@ -106,7 +106,7 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
 
 class CrossEntropyLoss(nn.Module):
     def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
-                 inplace_backward=False):
+                 inplace_backward=False, process_group=None):
         super().__init__()
         if reduction not in ['mean', 'none']:
             raise NotImplementedError("Only support reduction = 'mean' or 'none'")
@@ -114,13 +114,14 @@ class CrossEntropyLoss(nn.Module):
         self.reduction = reduction
         self.label_smoothing = label_smoothing
         self.inplace_backward = inplace_backward
+        self.process_group = process_group
 
-    def forward(self, input, target, process_group=None):
+    def forward(self, input, target):
         assert input.is_cuda and target.is_cuda
         # SoftmaxCrossEntropyLoss implicitly casts to float
         loss = SoftmaxCrossEntropyLossFn.apply(
             input, target, self.label_smoothing, self.ignore_index, self.inplace_backward,
-            process_group
+            self.process_group
         )
         if self.reduction == 'mean':
            return loss.sum() / (target != self.ignore_index).sum()
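
A minimal usage sketch of the new interface, not part of the diff: it assumes the module lives at flash_attn.losses.cross_entropy and that any tensor-parallel process group is created elsewhere. The group is now fixed at construction time instead of being passed to forward().

import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss  # assumed module path

tp_group = None  # stand-in; normally a torch.distributed group over the vocab-parallel ranks
loss_fn = CrossEntropyLoss(label_smoothing=0.1, inplace_backward=True, process_group=tp_group)
logits = torch.randn(8, 50264, device='cuda', requires_grad=True)  # this rank's (shard of the) logits
labels = torch.randint(0, 50264, (8,), device='cuda')
loss = loss_fn(logits, labels)  # process_group is no longer an argument of forward()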

View File

@@ -28,6 +28,7 @@ from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import BertEmbeddings
 from flash_attn.bert_padding import unpad_input, pad_input
 from flash_attn.bert_padding import index_first_axis, index_first_axis_residual
+from flash_attn.utils.pretrained import state_dict_from_pretrained
 
 try:
     from flash_attn.ops.fused_dense import FusedDense
@@ -439,12 +440,6 @@ class BertForPreTraining(BertPreTrainedModel):
         )
 
 
-def state_dict_from_pretrained(model_name):
-    from transformers.utils import WEIGHTS_NAME
-    from transformers.utils.hub import cached_file
-    return torch.load(cached_file(model_name, WEIGHTS_NAME))
-
-
 def remap_state_dict(state_dict, config):
     # LayerNorm
     def key_mapping_ln_gamma_beta(key):

View File

@@ -87,11 +87,15 @@ def sync_sequence_parallel_params(model: torch.nn.Module, process_group: ProcessGroup):
             )
 
 
+# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
 def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
     # We want to iterate over parameters with _sequence_parallel=True in the same order,
     # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
     params_seqparallel = {name: p for name, p in model.named_parameters()
                           if getattr(p, '_sequence_parallel', False)}
-    for _, p in sorted(params_seqparallel.items()):
-        with torch.no_grad():
-            torch.distributed.all_reduce(p.grad, group=process_group)
+    grads = [p.grad for _, p in sorted(params_seqparallel.items())]
+    with torch.no_grad():
+        coalesced = torch._utils._flatten_dense_tensors(grads)
+        torch.distributed.all_reduce(coalesced, group=process_group)
+        for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
+            buf.copy_(synced)
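
For reference, a standalone sketch of the coalescing pattern used above (single process, so an in-place scale stands in for the all-reduce; _flatten_dense_tensors and _unflatten_dense_tensors are private PyTorch helpers):

import torch

grads = [torch.randn(3, 4), torch.randn(5)]  # stand-ins for per-parameter .grad tensors
coalesced = torch._utils._flatten_dense_tensors(grads)  # one contiguous 1-D buffer (17 elements)
coalesced.mul_(2.0)  # placeholder for torch.distributed.all_reduce(coalesced, group=...)
for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
    buf.copy_(synced)  # write the reduced values back into the original tensors

This issues one collective per model instead of one per parameter.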

View File

@@ -0,0 +1,8 @@
+import torch
+
+from transformers.utils import WEIGHTS_NAME
+from transformers.utils.hub import cached_file
+
+
+def state_dict_from_pretrained(model_name):
+    return torch.load(cached_file(model_name, WEIGHTS_NAME))
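
A minimal sketch of calling the new helper (assumes network access to the Hugging Face Hub; the checkpoint name is only an example taken from the tests below):

from flash_attn.utils.pretrained import state_dict_from_pretrained

state_dict = state_dict_from_pretrained("bert-base-uncased")  # downloads/caches pytorch_model.bin and loads it
print(f"{len(state_dict)} tensors loaded")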

View File

@@ -24,7 +24,7 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 @pytest.mark.parametrize('vocab_size', [50264])
 @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
 # @pytest.mark.parametrize('world_size', [2])
-def test_cross_entropy_loss_apex(vocab_size, world_size, smoothing, inplace_backward, dtype):
+def test_cross_entropy_loss_parallel(vocab_size, world_size, smoothing, inplace_backward, dtype):
     assert vocab_size % world_size == 0
     rtol, atol = ((1e-5, 1e-6) if dtype == torch.float32
                   else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3)))

View File

@@ -12,8 +12,8 @@ from transformers.models.bert.modeling_bert import BertModel as BertModelHF
 from transformers.models.bert.modeling_bert import BertForPreTraining as BertForPreTrainingHF
 
 from flash_attn.models.bert import BertModel, BertForPreTraining
-from flash_attn.models.bert import state_dict_from_pretrained
 from flash_attn.models.bert import remap_state_dict
+from flash_attn.utils.pretrained import state_dict_from_pretrained
 
 
 @pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])