From ebef9a36e34a9573422501ae07f82d091cce87b9 Mon Sep 17 00:00:00 2001
From: zzhhjjj
Date: Wed, 20 Nov 2024 01:58:44 +0000
Subject: [PATCH] remove redundancy

---
 picotron/data_parallel/data_parallel.py |  3 ---
 picotron/model.py                       | 11 +----------
 2 files changed, 1 insertion(+), 13 deletions(-)

diff --git a/picotron/data_parallel/data_parallel.py b/picotron/data_parallel/data_parallel.py
index aeedfa1..e009bd9 100644
--- a/picotron/data_parallel/data_parallel.py
+++ b/picotron/data_parallel/data_parallel.py
@@ -82,9 +82,6 @@ class DataParallelBucket(nn.Module):
     def backward(self, input_tensor, output_tensor, output_tensor_grad):
         return self.module.backward(input_tensor, output_tensor, output_tensor_grad)
 
-    def get_flops(self, *args, **kwargs):
-        return self.module.get_flops(*args, **kwargs)
-
     def register_backward_hook(self):
         """
         Registers a backward hook to manually accumulate and synchronize gradients.
diff --git a/picotron/model.py b/picotron/model.py
index 1e84b7e..6cc2859 100644
--- a/picotron/model.py
+++ b/picotron/model.py
@@ -201,13 +201,4 @@ class Llama(nn.Module):
         x = self.final_norm(x)
         logits = self.final_proj(x)
 
-        return logits # [batch_size, seq_length, vocab_size]
-
-    # https://github.com/karpathy/nanoGPT/blob/9755682b981a45507f6eb9b11eadef8cb83cebd5/model.py#L289-L303
-    # TODO: Need to check the formula.
-    def get_flops(self, fwdbwd_per_iter, dt, num_params):
-        L, H, T = self.num_layers , self.hidden_size, self.max_position_embeddings
-        flops_per_fwdbwd = 6 * num_params * T + 12* L* H* T ** 2
-        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
-        flops_achieved = flops_per_iter * (1.0/dt) # per second
-        return flops_achieved
\ No newline at end of file
+        return logits # [batch_size, seq_length, vocab_size]
\ No newline at end of file
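
Note: the deleted get_flops carried the nanoGPT-style throughput estimate (6 FLOPs per parameter per token for forward + backward, plus 12*L*H*T^2 for attention). If that estimate is still wanted after this removal, it could live as a standalone helper instead of a model method. A minimal sketch, assuming a hypothetical function name and explicit arguments in place of the model attributes num_layers / hidden_size / max_position_embeddings:

    # Standalone sketch of the removed nanoGPT-style FLOPs estimate.
    # The helper name and explicit arguments are assumptions, not picotron API.
    def estimate_flops_per_second(num_params, num_layers, hidden_size, seq_len,
                                  fwdbwd_per_iter, dt):
        # ~6 FLOPs per parameter per token for forward + backward,
        # plus 12 * L * H * T^2 for the attention matmuls.
        flops_per_fwdbwd = (6 * num_params * seq_len
                            + 12 * num_layers * hidden_size * seq_len ** 2)
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        return flops_per_iter / dt  # achieved FLOPs/s over a dt-second iteration window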