Merge branch 'main' into async_tp
commit ca1fcec87f
@@ -82,9 +82,6 @@ class DataParallelBucket(nn.Module):
     def backward(self, input_tensor, output_tensor, output_tensor_grad):
         return self.module.backward(input_tensor, output_tensor, output_tensor_grad)
 
-    def get_flops(self, *args, **kwargs):
-        return self.module.get_flops(*args, **kwargs)
-
     def register_backward_hook(self):
         """
         Registers a backward hook to manually accumulate and synchronize gradients.
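The docstring above refers to manually accumulating and synchronizing gradients in the backward pass. A minimal sketch of what such a hook can do, assuming PyTorch >= 2.1 (for Tensor.register_post_accumulate_grad_hook) and an initialized torch.distributed process group; the main_grad buffer and the group argument are illustrative assumptions, not attributes confirmed by this diff:

import torch
import torch.distributed as dist

def register_grad_sync_hooks(module: torch.nn.Module, group=None):
    # Illustrative sketch: accumulate each parameter's gradient into a
    # persistent fp32 buffer, then all-reduce it across data-parallel ranks.
    for param in module.parameters():
        if not param.requires_grad:
            continue
        # Buffer that survives across micro-batches (assumed name: main_grad).
        param.main_grad = torch.zeros_like(param, dtype=torch.float32)

        def hook(p):
            p.main_grad.add_(p.grad)  # manual accumulation
            p.grad = None             # free the per-backward gradient
            dist.all_reduce(p.main_grad, group=group)  # synchronize (sum)

        # Fires once the gradient for this parameter has been accumulated.
        param.register_post_accumulate_grad_hook(hook)

In practice the all-reduce would be deferred and batched rather than issued per parameter per micro-batch, which is presumably what the bucketing in the DataParallelBucket wrapper is for.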
@@ -202,12 +202,3 @@ class Llama(nn.Module):
         logits = self.final_proj(x)
 
         return logits  # [batch_size, seq_length, vocab_size]
-
-    # https://github.com/karpathy/nanoGPT/blob/9755682b981a45507f6eb9b11eadef8cb83cebd5/model.py#L289-L303
-    # TODO: Need to check the formula.
-    def get_flops(self, fwdbwd_per_iter, dt, num_params):
-        L, H, T = self.num_layers, self.hidden_size, self.max_position_embeddings
-        flops_per_fwdbwd = 6 * num_params * T + 12 * L * H * T**2
-        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
-        flops_achieved = flops_per_iter * (1.0 / dt)  # per second
-        return flops_achieved
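For reference, the removed formula follows nanoGPT's estimate_mfu (the linked lines): 6 * num_params * T counts the forward+backward FLOPs of the parameter matmuls over a sequence of length T, and 12 * L * H * T**2 adds the attention score and value terms, with H here being the hidden size (nanoGPT's n_head * head_dim, so this matches its 12*l*h*q*t per-token form). A small worked check with illustrative GPT-2-small-like numbers; the sizes and the A100 peak are assumptions, not values from this repo:

num_params = 124e6       # N, total parameter count (illustrative)
L, H, T = 12, 768, 1024  # num_layers, hidden_size, max_position_embeddings

flops_per_fwdbwd = 6 * num_params * T + 12 * L * H * T**2
fwdbwd_per_iter = 4      # e.g. gradient-accumulation steps per iteration
dt = 0.5                 # measured seconds per iteration
flops_achieved = flops_per_fwdbwd * fwdbwd_per_iter * (1.0 / dt)

peak_flops = 312e12      # assumed A100 bf16 dense peak, FLOPs/s
print(f"{flops_achieved / 1e12:.1f} TFLOP/s achieved, "
      f"MFU = {flops_achieved / peak_flops:.1%}")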