fix stuff to make it CPU compliant

ferdinand.mom 2024-12-18 16:50:36 +00:00
parent b49ddac4b4
commit b647f58289
7 changed files with 66 additions and 6 deletions

View File

@ -1 +1,44 @@
# picotron
![](assets/banière.png)
- The minimalist & most-hackable repository for pre-training Llama-like models with 4D Parallelism (Data, Tensor, Pipeline, Context parallel). It is a rewrite of [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) for educational purposes. The code itself is plain and readable: **train.py, model.py and \[data|tensor|pipeline|context\]_parallel.py are all < 300 LOC**.
- Performance is not yet where we want it to be, but this is under active development.
# Install
```sh
pip install -e .
```
# Quick start
- GPU
```sh
# DP=8
python create_config.py --out_dir tmp --exp_name llama-1B --dp 8 --model_name HuggingFaceTB/SmolLM-1.7B --num_hidden_layers 15 --grad_acc_steps 32 --mbs 4 --seq_len 1024 --hf_token <HF_TOKEN>
# Locally
torchrun --nproc_per_node 8 train.py --config tmp/llama-1B/config.json
# 3D Parallelism
python create_config.py --out_dir tmp --dp 4 --tp 2 --pp 2 --pp_engine 1f1b --exp_name llama-7B --model_name meta-llama/Llama-2-7b-hf --grad_acc_steps 32 --mbs 4 --seq_len 1024 --hf_token <HF_TOKEN>
# Slurm
python submit_slurm_jobs.py --inp_dir tmp/llama-7B --qos high --hf_token <HF_TOKEN>
```
- CPU (expect it to be slow)
```sh
# 3D Parallelism on CPU
python create_config.py --out_dir tmp --exp_name llama-1B-cpu --dp 2 --tp 2 --pp 2 --pp_engine 1f1b --model_name HuggingFaceTB/SmolLM-1.7B --num_hidden_layers 5 --grad_acc_steps 2 --mbs 4 --seq_len 128 --hf_token <HF_TOKEN> --use_cpu
# Locally
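# dp * tp * pp = 2 * 2 * 2, hence --nproc_per_node 8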
torchrun --nproc_per_node 8 train.py --config tmp/llama-1B-cpu/config.json
```
# Acknowledgements
- [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)

BIN
assets/banière.png Normal file (2.9 MiB, binary file not shown)

View File

@ -126,6 +126,7 @@ def create_single_config(
    seq_len: int,
    exp_name: str,
    use_wandb: bool = False,
    use_cpu: bool = False,
    use_fused_adam: bool = False,
    hf_token: str = None
):
@ -156,6 +157,10 @@ def create_single_config(
    config_content['distributed']['dp_size'] = dp
    config_content['distributed']['pp_size'] = pp
    config_content['distributed']['pp_engine'] = pp_engine
    config_content['distributed']['use_cpu'] = use_cpu
    if use_cpu:
        config_content["environment"]["FLASH_ATTEN"] = "0"
        config_content["distributed"]["backend"] = "gloo"
    config_content['logging']['use_wandb'] = use_wandb
    config_content['logging']['run_name'] = exp_name
@ -192,6 +197,7 @@ if __name__ == "__main__":
parser.add_argument("--seq_len", type=int, help="Sequence length", default=1024)
parser.add_argument("--exp_name", type=str, help="Experiment name", default="dummy_exp")
parser.add_argument("--use_wandb", action="store_true", help="Use wandb for logging")
parser.add_argument("--use_cpu", action="store_true", help="Use CPU for training")
parser.add_argument("--use_fused_adam", action="store_true", help="Use fused adam")
parser.add_argument("--hf_token", type=str, help="HF token")
@ -213,6 +219,7 @@ if __name__ == "__main__":
        seq_len=args.seq_len,
        exp_name=args.exp_name,
        use_wandb=args.use_wandb,
        use_cpu=args.use_cpu,
        use_fused_adam=args.use_fused_adam,
        hf_token=args.hf_token
    )
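For context, here is a minimal sketch of how a training entry point could consume the settings written above (the `use_cpu` and `backend` keys mirror this diff; the init helper itself is an illustration, not picotron's actual train.py). NCCL requires GPUs while gloo works on CPU tensors, which is also why the commit forces `FLASH_ATTEN` to "0" on the CPU path.

```python
import json
import os

import torch
import torch.distributed as dist

def init_distributed(config_path: str):
    """Illustrative init: gloo + CPU when use_cpu is set, NCCL + CUDA otherwise."""
    with open(config_path) as f:
        config = json.load(f)
    use_cpu = config["distributed"].get("use_cpu", False)
    backend = config["distributed"].get("backend", "gloo" if use_cpu else "nccl")
    dist.init_process_group(backend=backend)  # rank/world size come from torchrun env vars
    if use_cpu:
        device = torch.device("cpu")
    else:
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
    return config, device
```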

View File

@ -68,6 +68,7 @@ class BucketManager:
            grad_type (torch.dtype, optional): Data type of gradients, defaults to torch.float32.
        """
        self.params = list(params)  # Convert parameter generator to a list.
        self.device = torch.device("cuda") if self.params[0].is_cuda else torch.device("cpu")
        self.buckets = []  # List of buckets.
        self.process_group = process_group
        self.process_group_size = dist.get_world_size(group=self.process_group)
@ -116,7 +117,7 @@ class BucketManager:
        # Create tensors for storing gradients and initialize Bucket objects.
        for i in range(len(bucket_sizes)):
            self.grad_data_list.append(torch.zeros(bucket_sizes[i], dtype=self.grad_type, device='cuda'))
            self.grad_data_list.append(torch.zeros(bucket_sizes[i], dtype=self.grad_type, device=self.device))
            self.buckets.append(Bucket(buckets_to_params[i], self.grad_data_list[i], self.process_group))
        # Create gradient views for each parameter.
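As a standalone illustration of the same idea (a hypothetical helper, not the repository's Bucket/BucketManager API), the flat gradient buffer is simply allocated on whatever device the parameters already live on, and per-parameter views are carved out of it:

```python
import torch

def make_flat_grad_buffer(params, grad_dtype=torch.float32):
    """Allocate one flat gradient buffer on the params' device and create per-param views."""
    params = list(params)
    # Follow the parameters' placement instead of hard-coding 'cuda'.
    device = torch.device("cuda") if params[0].is_cuda else torch.device("cpu")
    flat = torch.zeros(sum(p.numel() for p in params), dtype=grad_dtype, device=device)
    views, offset = {}, 0
    for p in params:
        views[p] = flat[offset:offset + p.numel()].view_as(p)  # gradient view into the flat buffer
        offset += p.numel()
    return flat, views
```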

View File

@ -1,6 +1,12 @@
from setuptools import setup, find_packages

def read_requirements():
    with open('requirements.txt') as req:
        return [line.strip() for line in req if line.strip() and not line.startswith('#')]

setup(
    name="picotron", # Name of the package
    name="picotron",
    version='0.1.0',
    packages=find_packages(), # Automatically find packages in the current directory
    packages=find_packages(),
    install_requires=read_requirements(),
)
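A quick, self-contained illustration of what the new read_requirements filter does (the requirements.txt content below is made up for the example; the real list ships with the repository):

```python
import io

# Hypothetical requirements.txt contents, for illustration only.
example = "# core deps\ntorch\n\ndatasets\n"

# Same filtering as read_requirements(): drop blank lines and comments, strip whitespace.
deps = [line.strip() for line in io.StringIO(example)
        if line.strip() and not line.startswith('#')]
print(deps)  # ['torch', 'datasets']
```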

View File

@ -30,7 +30,7 @@
"dataset": {
"name": "roneneldan/TinyStories",
"num_workers": 0,
"num_proc": 4
"num_proc": 1
},
"checkpoint": {
"save_dir": "ckpt",

View File

@ -159,6 +159,8 @@ if __name__ == "__main__":
    dist.barrier()
    print(f"rank {pgm.process_group_manager.global_rank}: Initializing model meta device", is_print_rank=is_wandb_rank)
    start_time = time.time()
    with init_model_with_dematerialized_weights():
@ -209,7 +211,8 @@ if __name__ == "__main__":
    def _all_reduce_loss_across_dp_cp_ranks(loss, device):
        reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
        if pgm.process_group_manager.pp_is_last_stage:
            dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.cp_dp_group)
            dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.cp_dp_group)
            reduced_loss /= pgm.process_group_manager.cp_dp_world_size
        return reduced_loss.item()

    while config["training"]["max_tokens"] is None or trained_tokens < config["training"]["max_tokens"]:
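The AVG-to-SUM change is the backend-compatibility part of this commit: dist.ReduceOp.AVG is not supported by the gloo backend used on the CPU path, so averaging has to be expressed as a SUM all-reduce followed by a division by the group size. A standalone sketch of the same pattern (generic group handle, not picotron's process_group_manager):

```python
import torch
import torch.distributed as dist

def all_reduce_mean(value: float, device: torch.device, group=None) -> float:
    """Average a scalar across ranks via SUM + divide, which works on both gloo and NCCL."""
    t = torch.tensor([value], dtype=torch.float32, device=device)
    dist.all_reduce(t, op=dist.ReduceOp.SUM, group=group)  # every rank now holds the sum
    t /= dist.get_world_size(group=group)                  # turn the sum into a mean
    return t.item()
```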