From aaa4a083e92fd3b71fb7f68f83431bdc082dc494 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Tue, 3 Dec 2024 16:26:41 +0000 Subject: [PATCH] remove clone() in tp communications as torch.compile will optimize this out anyway --- picotron/tensor_parallel/tp_communications.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/picotron/tensor_parallel/tp_communications.py b/picotron/tensor_parallel/tp_communications.py index a8ad378..f4dbfa3 100644 --- a/picotron/tensor_parallel/tp_communications.py +++ b/picotron/tensor_parallel/tp_communications.py @@ -15,10 +15,8 @@ class Reduce(torch.autograd.Function): def forward(ctx, input): if pgm.process_group_manager.tp_world_size == 1: return input - # Need to clone apparently: https://github.com/pytorch/pytorch/blob/main/torch/distributed/nn/functional.py#L446 - output = input.clone() - dist.all_reduce(output, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group) - return output + dist.all_reduce(input, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group) + return input @staticmethod def backward(ctx, grad_output): @@ -57,7 +55,5 @@ class Copy(torch.autograd.Function): def backward(ctx, grad_output): if pgm.process_group_manager.tp_world_size == 1: return grad_output - # Need to clone apparently: https://github.com/pytorch/pytorch/blob/main/torch/distributed/nn/functional.py#L446 - grad = grad_output.clone() - dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group) - return grad \ No newline at end of file + dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group) + return grad_output \ No newline at end of file