diff --git a/picotron/utils.py b/picotron/utils.py index 730be78..2bf5de2 100644 --- a/picotron/utils.py +++ b/picotron/utils.py @@ -34,13 +34,15 @@ def to_readable_format(num, precision=2): else: return f"{num:.{precision}f}" -# ref: https://github.com/karpathy/nanoGPT/blob/9755682b981a45507f6eb9b11eadef8cb83cebd5/model.py#L289 -def get_mfu(tokens_per_second, num_params, model_config, theoretical_flops = 989 * 10 ** 12): +# ref: +# https://github.com/karpathy/nanoGPT/blob/9755682b981a45507f6eb9b11eadef8cb83cebd5/model.py#L289 +# https://github.com/stanford-cs336/spring2024-lectures/blob/main/lecture_02.py#L950 +def get_mfu(tokens_per_second, num_params, model_config, theoretical_flops = 989.5 * 10 ** 12): num_layers = model_config.num_hidden_layers hidden_dim = model_config.hidden_size seq_len = model_config.max_position_embeddings - flops_per_toke = 6 * num_params + 12 * num_layers * hidden_dim * seq_len - mfu = tokens_per_second * flops_per_toke / theoretical_flops * 100 # percentage + flops_per_token = 6 * num_params + 12 * num_layers * hidden_dim * seq_len + mfu = tokens_per_second * flops_per_token / theoretical_flops * 100 # percentage return mfu def get_num_params(model):