remove redundancy

zzhhjjj 2024-11-20 01:58:44 +00:00
parent 16d85cdb3a
commit ebef9a36e3
2 changed files with 1 addition and 13 deletions


@@ -82,9 +82,6 @@ class DataParallelBucket(nn.Module):
     def backward(self, input_tensor, output_tensor, output_tensor_grad):
         return self.module.backward(input_tensor, output_tensor, output_tensor_grad)
-    def get_flops(self, *args, **kwargs):
-        return self.module.get_flops(*args, **kwargs)
     def register_backward_hook(self):
         """
         Registers a backward hook to manually accumulate and synchronize gradients.

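The deleted `get_flops` above was a one-line passthrough to the wrapped model. For context only (this diff does not show the repository doing this), the sketch below illustrates one common way a wrapper module can avoid such per-method delegates: fall back to the wrapped module in `__getattr__`. The `Wrapper` class and its names are hypothetical.

```python
import torch.nn as nn

class Wrapper(nn.Module):
    """Hypothetical wrapper that forwards unknown attribute lookups to the wrapped module."""
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def __getattr__(self, name):
        # nn.Module.__getattr__ resolves registered parameters/buffers/submodules
        # (including self.module); anything it cannot find is looked up on the
        # wrapped module, so explicit delegates like get_flops are not needed.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)
```

With this pattern, `Wrapper(model).get_flops(...)` would resolve to `model.get_flops` without a dedicated passthrough.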

@@ -201,13 +201,4 @@ class Llama(nn.Module):
         x = self.final_norm(x)
         logits = self.final_proj(x)
         return logits # [batch_size, seq_length, vocab_size]
-    # https://github.com/karpathy/nanoGPT/blob/9755682b981a45507f6eb9b11eadef8cb83cebd5/model.py#L289-L303
-    # TODO: Need to check the formula.
-    def get_flops(self, fwdbwd_per_iter, dt, num_params):
-        L, H, T = self.num_layers, self.hidden_size, self.max_position_embeddings
-        flops_per_fwdbwd = 6 * num_params * T + 12 * L * H * T ** 2
-        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
-        flops_achieved = flops_per_iter * (1.0 / dt) # per second
-        return flops_achieved
-        return logits # [batch_size, seq_length, vocab_size]
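For reference, the removed estimate follows the nanoGPT/PaLM-style approximation linked in the comment: roughly 6·num_params FLOPs per token for the parameter matmuls (forward plus backward) plus a 12·L·H·T attention term, multiplied by the T tokens in a sequence, which gives the 6·num_params·T + 12·L·H·T² the deleted code computed. Below is a minimal standalone sketch of that estimate; the function name and argument names are hypothetical and not part of the repository.

```python
def estimate_flops_per_second(num_params: int, num_layers: int, hidden_size: int,
                              seq_len: int, fwdbwd_per_iter: int, dt: float) -> float:
    """Rough achieved FLOPs/s for fwd+bwd passes, nanoGPT/PaLM-style estimate."""
    # ~6*N FLOPs per token for parameter matmuls (fwd + bwd), plus the
    # 12*L*H*T attention term; multiply by the T tokens in a sequence.
    flops_per_token = 6 * num_params + 12 * num_layers * hidden_size * seq_len
    flops_per_fwdbwd = flops_per_token * seq_len
    flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
    return flops_per_iter / dt  # dt = seconds per iteration


# Dividing by a device's peak throughput gives a rough MFU figure, e.g. for an
# A100 with a bf16 peak of 312e12 FLOPs/s:
# mfu = estimate_flops_per_second(...) / 312e12
```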