diff --git a/picotron/data.py b/picotron/data.py
index 959fa6d..922bde7 100644
--- a/picotron/data.py
+++ b/picotron/data.py
@@ -10,8 +10,7 @@ from picotron.utils import print
 import picotron.process_group_manager as pgm
 
 class MicroBatchDataLoader(DataLoader):
-    def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, device, split="train", num_samples=None):
-
+    def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, split="train", num_samples=None, pin_memory=True):
         self.micro_batch_size = micro_batch_size
         self.seq_length = seq_length
         self.grad_acc_steps = grad_acc_steps
@@ -50,7 +49,7 @@ class MicroBatchDataLoader(DataLoader):
             self.tokenized_dataset,
             batch_size=micro_batch_size,
             collate_fn=self.collate_batch,
-            pin_memory=True,
+            pin_memory=pin_memory,
             num_workers=num_workers,
             sampler=self.sampler,
             shuffle=False
@@ -109,14 +108,11 @@ class MicroBatchDataLoader(DataLoader):
         input_ids = batch_input_ids[:, start_idx:end_idx].contiguous()
         target_ids = batch_input_ids[:, start_idx+1:end_idx+1].contiguous()
         position_ids = torch.arange(start_idx, end_idx, dtype=torch.long).unsqueeze(0).expand(batch_size, -1).contiguous()
-        local_attn_mask = torch.tril(torch.ones((self.seq_length_per_gpu, self.seq_length_per_gpu), dtype=torch.bool))
-        attn_mask = local_attn_mask.unsqueeze(0).expand(batch_size, -1, -1).contiguous()
 
         return {
             "input_ids": input_ids,
             "target_ids": target_ids,
             "position_ids": position_ids,
-            "attn_mask": attn_mask,
             "hidden_states": None
         }
 
@@ -131,6 +127,12 @@ class MicroBatchDataLoader(DataLoader):
         try:
             batch = next(self._iterator)
         except StopIteration:
-            self._iterator = None
-            raise StopIteration
+            # Reinitialize the sampler and iterator
+            self.sampler.set_epoch(self.sampler.epoch + 1 if hasattr(self.sampler, 'epoch') else 0)
+            self._iterator = super().__iter__()
+            try:
+                batch = next(self._iterator)
+            except StopIteration:
+                self._iterator = None
+                raise StopIteration
         return batch
\ No newline at end of file
diff --git a/picotron/tensor_parallel/tensor_parallel.py b/picotron/tensor_parallel/tensor_parallel.py
index cc9328d..634c869 100644
--- a/picotron/tensor_parallel/tensor_parallel.py
+++ b/picotron/tensor_parallel/tensor_parallel.py
@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import picotron.process_group_manager as pgm
-from picotron.tensor_parallel.tp_communications import Reduce, Gather, Copy, split_tensor_along_last_dim
+from picotron.tensor_parallel.tp_communications import Reduce, Gather, linear_with_all_reduce, linear_with_async_all_reduce
 
 
 def apply_tensor_parallel(model):
@@ -51,10 +51,26 @@
     return model
 
 
-class ColumnParallelLinear(nn.Module):
+class ColumnParallelLinear(torch.nn.Module):
+    """Column Parallel Linear layer
+    Y = XW + b, where weight matrix W is parallelized along its second dimension. W = [W_1, ..., W_p]
+    This module returns the results of Y_i = XW_i + b_i in the forward method, Y_i is parallelized in the second dimension.
+    Arguments:
+        in_features: first dimension of weight matrix W.
+        out_features: second dimension of weight matrix W.
+        bias: If true, add bias
+        init_method: method to initialize weights
+        gather_output: If true, gather the output from all the partitions. This is used for the last linear layer
+    """
-    def __init__(self, in_features: int, out_features: int, bias: bool, gather_output: bool = False):
-
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = False,
+        gather_output: bool = False,
+        async_all_reduce: bool = True,
+    ) -> None:
         super(ColumnParallelLinear, self).__init__()
 
         self.tp_world_size = pgm.process_group_manager.tp_world_size
 
@@ -65,7 +81,8 @@ class ColumnParallelLinear(nn.Module):
         assert out_features % self.tp_world_size == 0, "Hidden dimension must be divisible by the tensor parallel world size"
         self.output_size_per_partition = out_features // self.tp_world_size
         self.gather_output = gather_output
-
+        self.async_all_reduce = async_all_reduce
+        # Allocate space for the weight and bias
         # Note: torch.nn.functional.linear performs XW^T + b so we exchange the order of dimensions
         self.weight = nn.Parameter(torch.Tensor(self.output_size_per_partition, self.in_features)) # W_i
         if bias:
@@ -95,17 +112,34 @@ class ColumnParallelLinear(nn.Module):
         # Split the model into size of self.output_size_per_partition
         weight_list = torch.split(master_weight, self.output_size_per_partition, dim=0)
         self.weight.data = weight_list[self.tp_rank].contiguous()
-
-    def forward(self, input):
-        input_parallel = Copy.apply(input)
-        # XW_i^T + b, output is Y_i
-        output = F.linear(input_parallel, self.weight, self.bias)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self.async_all_reduce:
+            output = linear_with_async_all_reduce(input, self.weight, self.bias)
+        else:
+            output = linear_with_all_reduce(input, self.weight, self.bias)
         if self.gather_output:
             output = Gather.apply(output)
         return output
 
 
 class RowParallelLinear(nn.Module):
-
+    """Linear layer with row parallelism.
+    Y = XW + b. W is parallelized along its first dimension and X along its second dimension as:
+               -   -
+              | W_1 |
+              | .   |
+          W = | .   |        X = [X_1, ..., X_p]
+              | .   |
+              | W_p |
+               -   -
+    We assume that X is already parallelized. This is the case after ColumnParallelLinear.
+    This module returns the results of Y = sum(X_i * W_i + b_i) in the forward method.
+    Arguments:
+        in_features: first dimension of matrix W.
+        out_features: second dimension of matrix W.
+        bias: If true, add bias
+        init_method: method to initialize weights.
+    """
     def __init__(self, in_features: int, out_features: int, bias: bool):
         super(RowParallelLinear, self).__init__()
diff --git a/picotron/tensor_parallel/tp_communications.py b/picotron/tensor_parallel/tp_communications.py
index f4dbfa3..9e7cf1e 100644
--- a/picotron/tensor_parallel/tp_communications.py
+++ b/picotron/tensor_parallel/tp_communications.py
@@ -1,6 +1,13 @@
 import torch.distributed as dist
 import torch
 import picotron.process_group_manager as pgm
+import torch.nn.functional as F
+
+from typing import Tuple
+
+def merge_first_two_dims(grad_output: torch.Tensor, input_: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Merge the first two dimensions of tensors."""
+    return grad_output.contiguous().view(-1, *grad_output.shape[2:]), input_.contiguous().view(-1, *input_.shape[2:])
 
 def split_tensor_along_last_dim(tensor, num_partitions):
     """Split a tensor along its last dimension into num_partitions chunks."""
@@ -45,7 +52,7 @@ class Gather(torch.autograd.Function):
         chunks = split_tensor_along_last_dim(grad_output, pgm.process_group_manager.tp_world_size)
         return chunks[pgm.process_group_manager.tp_rank].contiguous()
 
-class Copy(torch.autograd.Function):
+class Identity(torch.autograd.Function):
     """Identity in forward pass, all-reduce in backward pass."""
     @staticmethod
     def forward(ctx, input):
@@ -56,4 +63,41 @@ class Copy(torch.autograd.Function):
         if pgm.process_group_manager.tp_world_size == 1:
             return grad_output
         dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
-        return grad_output
\ No newline at end of file
+        return grad_output
+
+def linear_with_all_reduce(input, weight, bias):
+    input_parallel = Identity.apply(input)
+    output = F.linear(input_parallel, weight, bias) # XW_i^T + b, output is Y_i
+    return output
+
+def linear_with_async_all_reduce(input, weight, bias):
+    return LinearWithAsyncAllReduce.apply(input, weight, bias)
+
+class LinearWithAsyncAllReduce(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input_, weight, bias):
+        ctx.save_for_backward(input_, weight)
+        ctx.use_bias = bias is not None
+        output = input_ @ weight.t() + bias if bias is not None else input_ @ weight.t()
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        """
+        The key difference with "linear_with_all_reduce" is that the all-reduce of the input_ gradient happens before
+        the gradients of the weights and bias are computed, instead of after, so we can overlap computation and communication.
+        This is only applicable to Column Parallel Linear.
+
+        Before: grad_output -> grad_input, grad_weight, grad_bias -> grad_input all-reduce
+        Now:    grad_output -> grad_input -> grad_input all-reduce -> grad_weight, grad_bias
+        """
+        input_, weight = ctx.saved_tensors
+        grad_input = grad_output @ weight # (b, s, out_size) @ (out_size, input_size) = (b, s, input_size)
+        # all-reduce input gradient.
+        input_gradient_all_reduce_handle = dist.all_reduce(grad_input, group=pgm.process_group_manager.tp_group, async_op=True)
+        # merge first two dims to allow matrix multiplication
+        grad_output, input_ = merge_first_two_dims(grad_output, input_) # grad_output, input_: (b, s, out_size), (b, s, input_size) -> (b*s, out_size), (b*s, input_size)
+        grad_weight = grad_output.t() @ input_ # (out_size, b*s) @ (b*s, input_size) -> (out_size, input_size)
+        grad_bias = grad_output.sum(0) if ctx.use_bias else None
+        input_gradient_all_reduce_handle.wait()
+        return grad_input, grad_weight, grad_bias
\ No newline at end of file
diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py
new file mode 100644
index 0000000..f865486
--- /dev/null
+++ b/tests/test_dataloader.py
@@ -0,0 +1,213 @@
+"""
+torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 test_dataloader.py
+"""
+from picotron.data import MicroBatchDataLoader
+import torch.distributed as dist
+import os
+import datetime
+from picotron.process_group_manager import setup_process_group_manager
+
+import torch
+from torch.utils.data import DataLoader, DistributedSampler
+import numpy as np
+from functools import partial
+from datasets import Features, Sequence, Value, load_dataset
+from transformers import AutoTokenizer
+
+import picotron.process_group_manager as pgm
+
+# Reference dataloader without the context parallelism split
+class DummyDataLoader(DataLoader):
+    def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, split="train", num_samples=None, pin_memory=True):
+        self.micro_batch_size = micro_batch_size
+        self.seq_length = seq_length
+        self.grad_acc_steps = grad_acc_steps
+        self.global_batch_size = micro_batch_size * grad_acc_steps * pgm.process_group_manager.dp_world_size
+        self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size
+
+        self.seq_length_per_gpu = seq_length // pgm.process_group_manager.cp_world_size
+
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+        self.dataset = load_dataset(dataset_name, split=split)
+        if num_samples:
+            self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
+
+        # Tokenize and chunk the dataset
+        self.tokenized_dataset = self.tokenize_dataset(self.dataset, "text", self.seq_length, num_proc)
+
+        self.sampler = DistributedSampler(
+            self.tokenized_dataset,
+            num_replicas=pgm.process_group_manager.dp_world_size,
+            rank=pgm.process_group_manager.dp_rank,
+            shuffle=False
+        )
+
+        super().__init__(
+            self.tokenized_dataset,
+            batch_size=micro_batch_size,
+            collate_fn=self.collate_batch,
+            pin_memory=pin_memory,
+            num_workers=num_workers,
+            sampler=self.sampler,
+            shuffle=False
+        )
+
+    @staticmethod
+    def tokenizer_group_text(examples, tokenizer, sequence_length):
+        """Tokenize a list of texts and group them in chunks of sequence_length + 1"""
+        tokenized_text_batch = tokenizer.batch_encode_plus(
+            examples,
+            return_attention_mask=False,
+            return_token_type_ids=False,
+            return_tensors='np'
+        )
+        concatenated_tokens = {'input_ids': np.concatenate(tokenized_text_batch['input_ids'])}
+        total_length = len(concatenated_tokens['input_ids'])
+        if total_length >= sequence_length + 1:
+            total_length = ((total_length - 1) // sequence_length) * sequence_length + 1
+        result = {
+            'input_ids': [
+                concatenated_tokens['input_ids'][i : i + sequence_length + 1]
+                for i in range(0, total_length - sequence_length, sequence_length)
+            ]
+        }
+        return result
+
+    def tokenize_dataset(self, dataset, text_column_name, sequence_length, num_proc):
+        """Tokenize the dataset and group texts in chunks of sequence_length + 1"""
+        # Create a partial function with fixed arguments
+        tokenizer_func = partial(
+            self.tokenizer_group_text,
+            tokenizer=self.tokenizer,
+            sequence_length=sequence_length
+        )
+
+        tokenized_dataset = dataset.map(
+            tokenizer_func,
+            input_columns=text_column_name,
+            remove_columns=dataset.column_names,
+            features=Features({
+                "input_ids": Sequence(feature=Value(dtype="int64"), length=sequence_length + 1)
+            }),
+            batched=True,
+            num_proc=num_proc,
+            load_from_cache_file=True,
+            desc=f"Grouping texts in chunks of {sequence_length+1}",
+        )
+
+        return tokenized_dataset
+
+    def collate_batch(self, batch):
+        batch_input_ids = torch.stack([torch.tensor(item['input_ids']) for item in batch])
+        batch_size = batch_input_ids.size(0)
+        input_ids = batch_input_ids[:, :self.seq_length].contiguous()
+        target_ids = batch_input_ids[:, 1:self.seq_length+1].contiguous()
+        position_ids = torch.arange(0, self.seq_length, dtype=torch.long).unsqueeze(0).expand(batch_size, -1).contiguous()
+
+        return {
+            "input_ids": input_ids,
+            "target_ids": target_ids,
+            "position_ids": position_ids,
+            "hidden_states": None
+        }
+
+    def __iter__(self):
+        if self._iterator is None:
+            self._iterator = super().__iter__()
+        return self
+
+    def __next__(self):
+        if self._iterator is None:
+            self._iterator = super().__iter__()
+        try:
+            batch = next(self._iterator)
+        except StopIteration:
+            # Reinitialize the sampler and iterator
+            self.sampler.set_epoch(self.sampler.epoch + 1 if hasattr(self.sampler, 'epoch') else 0)
+            self._iterator = super().__iter__()
+            try:
+                batch = next(self._iterator)
+            except StopIteration:
+                self._iterator = None
+                raise StopIteration
+        return batch
+
+# test the tokens are split correctly in context parallelism
+# TODO: test zigzag behavior
+def test_cp_behavior(TP_SIZE, CP_SIZE, PP_SIZE, DP_SIZE, SEQ_LEN=8):
+    local_rank = int(os.environ["LOCAL_RANK"])
+    global_rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    backend = "nccl"
+
+    assert SEQ_LEN % CP_SIZE == 0, "SEQ_LEN must be divisible by cp_size for Context Parallelism"
+    dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method=f"env://", timeout=datetime.timedelta(minutes=3))
+    setup_process_group_manager(tp_size=TP_SIZE, cp_size=CP_SIZE, pp_size=PP_SIZE, dp_size=DP_SIZE)
+
+    data_loader = MicroBatchDataLoader(
+        micro_batch_size=2,
+        seq_length=SEQ_LEN,
+        dataset_name="roneneldan/TinyStories",
+        tokenizer_name="HuggingFaceTB/SmolLM-135M",
+        grad_acc_steps=1,
+        num_workers=1,
+        num_proc=1,
+        num_samples=10,
+        pin_memory=False
+    )
+
+    ref_data_loader = DummyDataLoader(
+        micro_batch_size=2,
+        seq_length=SEQ_LEN,
+        dataset_name="roneneldan/TinyStories",
+        tokenizer_name="HuggingFaceTB/SmolLM-135M",
+        grad_acc_steps=1,
+        num_workers=1,
+        num_proc=1,
+        num_samples=10,
+        pin_memory=False
+    )
+
+    for i in range(1):
+        ref_batch = next(ref_data_loader)
+        batch = next(data_loader)
+        split_size = ref_batch["input_ids"].shape[1] // pgm.process_group_manager.cp_world_size
+        start_idx = split_size * global_rank
+        end_idx = start_idx + split_size
+        assert torch.equal(ref_batch["input_ids"][:,start_idx:end_idx], batch["input_ids"]), "input_ids are not equal"
+
+# test the infinite loop behavior
+def test_infinite_loop():
+    local_rank = 0
+    global_rank = 0
+    world_size = 1
+    backend = "nccl"
+
+    dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method=f"env://", timeout=datetime.timedelta(minutes=3))
+    setup_process_group_manager(tp_size=1, cp_size=1, pp_size=1, dp_size=1)
+
+    data_loader = MicroBatchDataLoader(
+        micro_batch_size=2,
+        seq_length=256,
+        dataset_name="roneneldan/TinyStories",
+        tokenizer_name="HuggingFaceTB/SmolLM-135M",
+        grad_acc_steps=1,
+        num_workers=1,
+        num_proc=1,
+        num_samples=2,
+    )
+
+    s = set()
+    for i in range(10):
+        batch = next(data_loader)
+        # Convert the nested list to a tuple of tuples
+        batch_tuple = tuple(tuple(x) for x in batch["input_ids"].tolist())
+        if batch_tuple in s:
+            return
+        s.add(batch_tuple)
+    assert False, "DataLoader did not cycle back to already-seen batches"
+
+
+if __name__ == "__main__":
+    # test_infinite_loop()
+    test_cp_behavior(TP_SIZE=1, CP_SIZE=2, PP_SIZE=1, DP_SIZE=1, SEQ_LEN=8)
\ No newline at end of file
diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py
new file mode 100644
index 0000000..65d26c0
--- /dev/null
+++ b/tests/test_tensor_parallel.py
@@ -0,0 +1,75 @@
+"""
+CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 test_tensor_parallel.py
+CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 test_tensor_parallel.py
+"""
+
+from picotron.process_group_manager import setup_process_group_manager
+from picotron.tensor_parallel.tensor_parallel import ColumnParallelLinear, RowParallelLinear
+from picotron.utils import set_all_seed
+import torch
+import os
+import torch.distributed as dist
+import datetime
+import picotron.process_group_manager as pgm
+
+local_rank = int(os.environ["LOCAL_RANK"])
+global_rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+device = torch.device("cuda", local_rank)
+
+dist.init_process_group(rank=global_rank, world_size=world_size, backend="nccl", init_method=f"env://", timeout=datetime.timedelta(minutes=3))
+setup_process_group_manager(tp_size=world_size, cp_size=1, pp_size=1, dp_size=1)
+
+set_all_seed(42)
+
+batch_size, seq_len = 2, 4
+input_size, output_size = 8, 16
+bias = True # linear layer with/without bias
+async_all_reduce = False # async all-reduce or not for column parallel linear layer
+
+# Initialize input tensor
+tensor_shape = (batch_size, seq_len, input_size)
+tensor = torch.randn(tensor_shape, device=device, requires_grad=True)
+column_parallel_tensor = tensor.clone().detach().requires_grad_(True)
+row_parallel_tensor = tensor.clone().chunk(world_size, dim=-1)[local_rank].detach().requires_grad_(True)
+
+# Initialize column/row parallel layers
+column_parallel_linear = ColumnParallelLinear(input_size, output_size, bias=bias, gather_output=True, async_all_reduce=async_all_reduce).to(device)
+row_parallel_linear = RowParallelLinear(input_size, output_size, bias=bias).to(device)
+linear_layer = torch.nn.Linear(input_size, output_size, bias=bias, device=device)
+
+# copy weight and bias from reference linear layer to column/row parallel layers
+column_parallel_linear.weight = torch.nn.Parameter(linear_layer.weight.chunk(world_size, dim=0)[local_rank])
+row_parallel_linear.weight = torch.nn.Parameter(linear_layer.weight.chunk(world_size, dim=1)[local_rank])
+if bias:
+    column_parallel_linear.bias = torch.nn.Parameter(linear_layer.bias.chunk(world_size, dim=0)[local_rank])
+    row_parallel_linear.bias = torch.nn.Parameter(linear_layer.bias)
+
+### forward pass ###
+output_reference = linear_layer(tensor)
+output_column_parallel = column_parallel_linear(column_parallel_tensor)
+output_row_parallel = row_parallel_linear(row_parallel_tensor)
+
+# check forward output consistency
+assert torch.all(torch.eq(output_reference, output_column_parallel)), "Column Parallel Linear is not equal to the reference"
+torch.testing.assert_close(output_reference, output_row_parallel) # not strictly equal. precision issue
+
+### backward pass ###
+output_reference.backward(torch.ones_like(output_reference))
+output_column_parallel.backward(torch.ones_like(output_column_parallel))
+output_row_parallel.backward(torch.ones_like(output_row_parallel))
+
+# check backward weight gradient, bias gradient, and input gradient consistency
+# column parallel linear test
+torch.testing.assert_close(linear_layer.weight.grad.chunk(world_size, dim=0)[local_rank], column_parallel_linear.weight.grad)
+torch.testing.assert_close(tensor.grad, column_parallel_tensor.grad)
+if bias:
+    torch.testing.assert_close(linear_layer.bias.grad.chunk(world_size, dim=0)[local_rank], column_parallel_linear.bias.grad)
+
+# row parallel linear test
+torch.testing.assert_close(linear_layer.weight.grad.chunk(world_size, dim=1)[local_rank], row_parallel_linear.weight.grad)
+torch.testing.assert_close(tensor.grad.chunk(world_size, dim=-1)[local_rank], row_parallel_tensor.grad)
+if bias:
+    torch.testing.assert_close(linear_layer.bias.grad, row_parallel_linear.bias.grad)
+
+print(f"Rank {dist.get_rank()}: All tests passed")
\ No newline at end of file
diff --git a/train.py b/train.py
index bb83b98..e79e66c 100644
--- a/train.py
+++ b/train.py
@@ -1,7 +1,7 @@
 """Training script for LLaMA model.
 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 1 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
-CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
+CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/llama2_7b_benchmark.json
 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 8 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
 CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --config tmp/dummy/360M_131K.json
 #VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
@@ -125,7 +125,7 @@ if __name__ == "__main__":
     if is_wandb_rank and config["logging"]["use_wandb"]:
         wandb.init(
             project="picotron",
-            name=f"{config['logging']['run_name']}_{tokens_per_step}_{pgm.process_group_manager}",
+            name=f"{config['logging']['run_name']}_{to_readable_format(tokens_per_step)}_{pgm.process_group_manager}",
             config={
                 "tensor_parallel_size": pgm.process_group_manager.tp_world_size,
                 "context_parallel_size": pgm.process_group_manager.cp_world_size,
@@ -238,7 +238,8 @@ if __name__ == "__main__":
 
         step_duration = time.time() - step_start_time
         tokens_per_second = tokens_per_step / step_duration
-        mfu = get_mfu(tokens_per_second / world_size, num_params, model_config)
+        tokens_per_second_per_gpu = tokens_per_second / world_size
+        mfu = get_mfu(tokens_per_second_per_gpu, num_params, model_config)
 
         if is_wandb_rank:
             print(
@@ -247,7 +248,7 @@ if __name__ == "__main__":
                 f"Loss: {loss:6.4f} | "
                 f"Global batch size: {to_readable_format(tokens_per_step):>7s} | "
                 f"Tokens/s: {to_readable_format(tokens_per_second):>7s} | "
-                f"Tokens/s/GPU: {to_readable_format(tokens_per_second / world_size):>7s} | "
+                f"Tokens/s/GPU: {to_readable_format(tokens_per_second_per_gpu):>7s} | "
                 f"Tokens: {to_readable_format(trained_tokens):>7s}{('/' + to_readable_format(config['training']['max_tokens'])) if config['training']['max_tokens'] else ''} | "
                 f"MFU: {mfu:5.2f}% | "
                 f"Memory usage: {torch.cuda.memory_reserved() / 1e9:6.2f}GB",
@@ -259,6 +260,8 @@ if __name__ == "__main__":
                 "loss": loss,
                 "tokens_per_step": tokens_per_step,
                 "tokens_per_second": tokens_per_step / step_duration,
+                "mfu": mfu,
+                "tokens_per_second_per_gpu": tokens_per_second_per_gpu,
                 "memory_usage": torch.cuda.memory_reserved() / 1e9,
                 "trained_tokens": trained_tokens
             })