diff --git a/picotron/tensor_parallel/tp_communications.py b/picotron/tensor_parallel/tp_communications.py
index a8ad378..f4dbfa3 100644
--- a/picotron/tensor_parallel/tp_communications.py
+++ b/picotron/tensor_parallel/tp_communications.py
@@ -15,10 +15,8 @@ class Reduce(torch.autograd.Function):
     def forward(ctx, input):
         if pgm.process_group_manager.tp_world_size == 1:
             return input
-        # Need to clone apparently: https://github.com/pytorch/pytorch/blob/main/torch/distributed/nn/functional.py#L446
-        output = input.clone()
-        dist.all_reduce(output, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
-        return output
+        dist.all_reduce(input, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
+        return input
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -57,7 +55,5 @@ class Copy(torch.autograd.Function):
     def backward(ctx, grad_output):
         if pgm.process_group_manager.tp_world_size == 1:
             return grad_output
-        # Need to clone apparently: https://github.com/pytorch/pytorch/blob/main/torch/distributed/nn/functional.py#L446
-        grad = grad_output.clone()
-        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
-        return grad
\ No newline at end of file
+        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
+        return grad_output
\ No newline at end of file
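
For context, a minimal sketch of how Copy and Reduce are typically paired in Megatron-style tensor-parallel linear layers. The ColumnParallelLinear/RowParallelLinear classes below are illustrative assumptions, not picotron's actual modules; only Copy, Reduce, and the module path come from the patch, and the forward/backward semantics not visible in the hunks (Copy's forward, Reduce's backward) are assumed to follow the usual identity/all-reduce pairing.

import torch
import torch.nn as nn
import torch.nn.functional as F

from picotron.tensor_parallel.tp_communications import Copy, Reduce

class ColumnParallelLinear(nn.Module):
    # Weight is sharded along the output dimension across TP ranks.
    def __init__(self, in_features, out_features_per_rank):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features_per_rank, in_features))

    def forward(self, x):
        # Copy: pass-through in forward (assumed); its backward all-reduces the
        # input gradient across the TP group, as shown in this patch.
        return F.linear(Copy.apply(x), self.weight)

class RowParallelLinear(nn.Module):
    # Weight is sharded along the input dimension; partial outputs are summed.
    def __init__(self, in_features_per_rank, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features_per_rank))

    def forward(self, x):
        # Reduce: all-reduces the partial matmul results in forward, as shown
        # in this patch; its backward is assumed to be a pass-through.
        return Reduce.apply(F.linear(x, self.weight))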