mfu ref/typo

This commit is contained in:
zzhhjjj 2024-11-18 17:57:02 +00:00
parent a2ce795837
commit 16d85cdb3a

View File

@ -34,13 +34,15 @@ def to_readable_format(num, precision=2):
else:
return f"{num:.{precision}f}"
# ref:
# https://github.com/karpathy/nanoGPT/blob/9755682b981a45507f6eb9b11eadef8cb83cebd5/model.py#L289
# https://github.com/stanford-cs336/spring2024-lectures/blob/main/lecture_02.py#L950
def get_mfu(tokens_per_second, num_params, model_config, theoretical_flops = 989.5 * 10 ** 12):
    """Return the model FLOPs utilization (MFU) as a percentage.

    The FLOPs spent per token are estimated as
    ``6 * num_params + 12 * num_layers * hidden_dim * seq_len``
    (see the references above this function), multiplied by the measured
    throughput and divided by the hardware peak.

    Args:
        tokens_per_second: Measured throughput, in tokens per second.
        num_params: Number of model parameters.
        model_config: Object exposing ``num_hidden_layers``, ``hidden_size``
            and ``max_position_embeddings`` attributes.
        theoretical_flops: Peak FLOPs per second of the hardware.
            NOTE(review): the default of 989.5e12 presumably corresponds to
            a specific accelerator's dense peak -- confirm it matches the
            hardware actually in use.

    Returns:
        The utilization in percent (100.0 means the theoretical peak).
    """
    num_layers = model_config.num_hidden_layers
    hidden_dim = model_config.hidden_size
    seq_len = model_config.max_position_embeddings
    # The first term scales with the parameter count; the second term scales
    # with depth, width and sequence length.
    flops_per_token = 6 * num_params + 12 * num_layers * hidden_dim * seq_len
    mfu = tokens_per_second * flops_per_token / theoretical_flops * 100 # percentage
    return mfu
def get_num_params(model):