Remove clone() in TP communications, as torch.compile will optimize it out anyway
parent 09dfd1676f
commit aaa4a083e9
@@ -15,10 +15,8 @@ class Reduce(torch.autograd.Function):
     def forward(ctx, input):
         if pgm.process_group_manager.tp_world_size == 1:
             return input
-        # Need to clone apparently: https://github.com/pytorch/pytorch/blob/main/torch/distributed/nn/functional.py#L446
-        output = input.clone()
-        dist.all_reduce(output, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
-        return output
+        dist.all_reduce(input, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
+        return input
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -57,7 +55,5 @@ class Copy(torch.autograd.Function):
     def backward(ctx, grad_output):
         if pgm.process_group_manager.tp_world_size == 1:
             return grad_output
-        # Need to clone apparently: https://github.com/pytorch/pytorch/blob/main/torch/distributed/nn/functional.py#L446
-        grad = grad_output.clone()
-        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
-        return grad
+        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
+        return grad_output
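
For context, a minimal sketch of how the two autograd Functions read after this change. Only the hunks above are confirmed by the diff; the pgm import (the repo's process-group manager module) and the identity Reduce.backward / Copy.forward are assumptions based on the usual pairing of these tensor-parallel primitives.

import torch
import torch.distributed as dist
# assumption: pgm is the repo's process-group manager module, e.g.
# import process_group_manager as pgm

class Reduce(torch.autograd.Function):
    """All-reduce in the forward pass; backward is assumed to be the identity."""
    @staticmethod
    def forward(ctx, input):
        if pgm.process_group_manager.tp_world_size == 1:
            return input
        # in-place all-reduce; the clone() removed by this commit is left to torch.compile
        dist.all_reduce(input, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
        return input

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output  # assumed identity

class Copy(torch.autograd.Function):
    """Forward is assumed to be the identity; the gradient is all-reduced in the backward pass."""
    @staticmethod
    def forward(ctx, input):
        return input  # assumed identity

    @staticmethod
    def backward(ctx, grad_output):
        if pgm.process_group_manager.tp_world_size == 1:
            return grad_output
        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
        return grad_output

Note that dist.all_reduce writes into its argument in place, so the removed clone() was what kept the caller's tensor untouched; the commit message's rationale is that torch.compile optimizes the extra copy away anyway.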