Tweak CrossEntropyLoss to take process_group in init
parent b4018a5028
commit c6ecd40a59

@@ -106,7 +106,7 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
 class CrossEntropyLoss(nn.Module):
 
     def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
-                 inplace_backward=False):
+                 inplace_backward=False, process_group=None):
         super().__init__()
         if reduction not in ['mean', 'none']:
             raise NotImplementedError("Only support reduction = 'mean' or 'none'")
@@ -114,13 +114,14 @@ class CrossEntropyLoss(nn.Module):
         self.reduction = reduction
         self.label_smoothing = label_smoothing
         self.inplace_backward = inplace_backward
+        self.process_group = process_group
 
-    def forward(self, input, target, process_group=None):
+    def forward(self, input, target):
         assert input.is_cuda and target.is_cuda
         # SoftmaxCrossEntropyLoss implicitly casts to float
         loss = SoftmaxCrossEntropyLossFn.apply(
             input, target, self.label_smoothing, self.ignore_index, self.inplace_backward,
-            process_group
+            self.process_group
         )
         if self.reduction == 'mean':
             return loss.sum() / (target != self.ignore_index).sum()
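
For reference, a minimal usage sketch of the new interface. The import path and the distributed setup below are assumptions (neither is shown in this diff); the point is that the process group is now bound once at construction instead of being passed on every forward call.

import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss  # assumed import path

# Single-process case: process_group=None behaves like an ordinary cross entropy loss.
loss_fn = CrossEntropyLoss(inplace_backward=True, process_group=None)
logits = torch.randn(8, 50264, device='cuda', requires_grad=True)
labels = torch.randint(0, 50264, (8,), device='cuda')
loss = loss_fn(logits, labels)   # forward() no longer takes process_group
loss.backward()

# Tensor-parallel case (sketch only, requires torch.distributed to be initialized):
#   tp_group = torch.distributed.new_group(ranks=list(range(world_size)))
#   loss_fn = CrossEntropyLoss(inplace_backward=True, process_group=tp_group)
#   loss = loss_fn(local_vocab_shard_logits, labels)
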
@@ -28,6 +28,7 @@ from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import BertEmbeddings
 from flash_attn.bert_padding import unpad_input, pad_input
 from flash_attn.bert_padding import index_first_axis, index_first_axis_residual
+from flash_attn.utils.pretrained import state_dict_from_pretrained
 
 try:
     from flash_attn.ops.fused_dense import FusedDense
@@ -439,12 +440,6 @@ class BertForPreTraining(BertPreTrainedModel):
         )
 
 
-def state_dict_from_pretrained(model_name):
-    from transformers.utils import WEIGHTS_NAME
-    from transformers.utils.hub import cached_file
-    return torch.load(cached_file(model_name, WEIGHTS_NAME))
-
-
 def remap_state_dict(state_dict, config):
     # LayerNorm
     def key_mapping_ln_gamma_beta(key):

@@ -87,11 +87,15 @@ def sync_sequence_parallel_params(model: torch.nn.Module, process_group: ProcessGroup):
         )
 
 
+# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
 def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
     # We want to iterate over parameters with _sequence_parallel=True in the same order,
     # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
     params_seqparallel = {name: p for name, p in model.named_parameters()
                           if getattr(p, '_sequence_parallel', False)}
-    for _, p in sorted(params_seqparallel.items()):
-        with torch.no_grad():
-            torch.distributed.all_reduce(p.grad, group=process_group)
+    grads = [p.grad for _, p in sorted(params_seqparallel.items())]
+    with torch.no_grad():
+        coalesced = torch._utils._flatten_dense_tensors(grads)
+        torch.distributed.all_reduce(coalesced, group=process_group)
+        for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
+            buf.copy_(synced)
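
The change above replaces one all_reduce per sequence-parallel parameter with a single all_reduce over a flattened buffer, following the referenced Megatron-LM optimizer code. A minimal sketch of the coalescing pattern with dummy tensors (the collective call itself needs an initialized process group, so it is left as a comment):

import torch

grads = [torch.randn(3, 4), torch.randn(5)]           # stand-ins for the p.grad tensors
flat = torch._utils._flatten_dense_tensors(grads)     # one contiguous buffer for all grads
# torch.distributed.all_reduce(flat, group=process_group)  # a single collective instead of one per tensor
for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(flat, grads)):
    buf.copy_(synced)                                 # write the (reduced) values back in place
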
flash_attn/utils/pretrained.py (new file, 8 lines)

@@ -0,0 +1,8 @@
+import torch
+
+from transformers.utils import WEIGHTS_NAME
+from transformers.utils.hub import cached_file
+
+
+def state_dict_from_pretrained(model_name):
+    return torch.load(cached_file(model_name, WEIGHTS_NAME))
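
The helper that previously lived in the BERT model file is now importable from its own module. A usage sketch (requires the transformers package and either network access or a warm local Hugging Face cache for the checkpoint):

from flash_attn.utils.pretrained import state_dict_from_pretrained

# Fetches pytorch_model.bin (WEIGHTS_NAME) for the checkpoint via the Hugging Face
# hub cache and loads it with torch.load as a plain state dict.
state_dict = state_dict_from_pretrained("bert-base-uncased")
print(list(state_dict)[:5])
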
@@ -24,7 +24,7 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 @pytest.mark.parametrize('vocab_size', [50264])
 @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
 # @pytest.mark.parametrize('world_size', [2])
-def test_cross_entropy_loss_apex(vocab_size, world_size, smoothing, inplace_backward, dtype):
+def test_cross_entropy_loss_parallel(vocab_size, world_size, smoothing, inplace_backward, dtype):
     assert vocab_size % world_size == 0
     rtol, atol = ((1e-5, 1e-6) if dtype == torch.float32
                   else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3)))

||||
@ -12,8 +12,8 @@ from transformers.models.bert.modeling_bert import BertModel as BertModelHF
|
||||
from transformers.models.bert.modeling_bert import BertForPreTraining as BertForPreTrainingHF
|
||||
|
||||
from flash_attn.models.bert import BertModel, BertForPreTraining
|
||||
from flash_attn.models.bert import state_dict_from_pretrained
|
||||
from flash_attn.models.bert import remap_state_dict
|
||||
from flash_attn.utils.pretrained import state_dict_from_pretrained
|
||||
|
||||
|
||||
@pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
|
||||
|
||||