remove redundancy
parent 16d85cdb3a
commit ebef9a36e3
@@ -82,9 +82,6 @@ class DataParallelBucket(nn.Module):
     def backward(self, input_tensor, output_tensor, output_tensor_grad):
         return self.module.backward(input_tensor, output_tensor, output_tensor_grad)
 
-    def get_flops(self, *args, **kwargs):
-        return self.module.get_flops(*args, **kwargs)
-
     def register_backward_hook(self):
         """
         Registers a backward hook to manually accumulate and synchronize gradients.
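For context, below is a minimal sketch of what "manually accumulate and synchronize gradients" via per-parameter hooks can look like in plain PyTorch. It is not this repository's implementation: the wrapper name GradSyncWrapper, the main_grad buffer, and the explicit synchronize_gradients() call are illustrative assumptions; only the register_backward_hook name mirrors the method shown in the hunk above.

import torch
import torch.distributed as dist
import torch.nn as nn

class GradSyncWrapper(nn.Module):
    """Accumulate gradients into fp32 buffers via hooks, then all-reduce on demand."""

    def __init__(self, module: nn.Module, process_group=None):
        super().__init__()
        self.module = module
        self.process_group = process_group
        self.register_backward_hook()

    def register_backward_hook(self):
        for param in self.module.parameters():
            if param.requires_grad:
                # Persistent fp32 accumulation buffer, one per parameter.
                param.main_grad = torch.zeros_like(param, dtype=torch.float32)
                # The hook fires each time autograd produces this parameter's gradient.
                param.register_hook(self._make_hook(param))

    @staticmethod
    def _make_hook(param):
        def hook(grad):
            param.main_grad.add_(grad.float())  # manual accumulation across micro-batches
            return grad
        return hook

    def synchronize_gradients(self):
        # Called once per optimizer step (e.g. after the last micro-batch):
        # average the accumulated gradients across data-parallel ranks.
        world_size = dist.get_world_size(self.process_group)
        for param in self.module.parameters():
            if param.requires_grad:
                dist.all_reduce(param.main_grad, op=dist.ReduceOp.SUM,
                                group=self.process_group)
                param.main_grad.div_(world_size)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

Deferring the all-reduce to an explicit synchronize_gradients() call, rather than issuing it inside the hook, keeps gradient accumulation over micro-batches communication-free and launches one collective per parameter per optimizer step.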
@@ -201,13 +201,4 @@ class Llama(nn.Module):
         x = self.final_norm(x)
         logits = self.final_proj(x)
 
-        return logits # [batch_size, seq_length, vocab_size]
-
-    # https://github.com/karpathy/nanoGPT/blob/9755682b981a45507f6eb9b11eadef8cb83cebd5/model.py#L289-L303
-    # TODO: Need to check the formula.
-    def get_flops(self, fwdbwd_per_iter, dt, num_params):
-        L, H, T = self.num_layers , self.hidden_size, self.max_position_embeddings
-        flops_per_fwdbwd = 6 * num_params * T + 12* L* H* T ** 2
-        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
-        flops_achieved = flops_per_iter * (1.0/dt) # per second
-        return flops_achieved
+        return logits # [batch_size, seq_length, vocab_size]
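The removed get_flops follows the nanoGPT-style estimate linked above: roughly 6 * num_params FLOPs per token for the dense matmuls over a combined forward and backward pass, plus 12 * L * H * T**2 for attention. A standalone sketch of the same calculation is below; the function name and signature are illustrative, and, as the removed TODO notes, the formula itself is an approximation.

def estimate_flops_per_second(fwdbwd_per_iter, dt, num_params,
                              num_layers, hidden_size, seq_length):
    """Achieved FLOP/s for `fwdbwd_per_iter` forward+backward passes taking `dt` seconds."""
    L, H, T = num_layers, hidden_size, seq_length
    flops_per_fwdbwd = 6 * num_params * T + 12 * L * H * T ** 2  # matmuls + attention
    flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
    return flops_per_iter / dt

Dividing the result by the accelerator's peak throughput (e.g. 312e12 FLOP/s for A100 bf16) gives a model FLOPs utilization (MFU) estimate.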