remove clone() in tp communications as torch.compile will optimize this out anyway

This commit is contained in:
ferdinand.mom 2024-12-03 16:26:41 +00:00
parent 09dfd1676f
commit aaa4a083e9

View File

@ -15,10 +15,8 @@ class Reduce(torch.autograd.Function):
def forward(ctx, input):
if pgm.process_group_manager.tp_world_size == 1:
return input
# Need to clone apparently: https://github.com/pytorch/pytorch/blob/main/torch/distributed/nn/functional.py#L446
output = input.clone()
dist.all_reduce(output, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
return output
dist.all_reduce(input, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
return input
@staticmethod
def backward(ctx, grad_output):
@ -57,7 +55,5 @@ class Copy(torch.autograd.Function):
def backward(ctx, grad_output):
if pgm.process_group_manager.tp_world_size == 1:
return grad_output
# Need to clone apparently: https://github.com/pytorch/pytorch/blob/main/torch/distributed/nn/functional.py#L446
grad = grad_output.clone()
dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
return grad
dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
return grad_output