fix stuff to make it CPU compliant
This commit is contained in:
parent b49ddac4b4
commit b647f58289
README.md (45 lines changed)
@@ -1 +1,44 @@
 # picotron
+
+- The minimalist & most-hackable repository for pre-training Llama-like models with 4D Parallelism (Data, Tensor, Pipeline, Context parallel). It is a rewrite of [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) for educational purposes. The code itself is plain and readable: **train.py, model.py and \[data|tensor|pipeline|context\]_parallel.py are all < 300 LOC**.
+
+- Performance is not yet optimal, but this is under active development.
+
+# Install
+
+```
+pip install -e .
+```
+
+# Quick start
+
+- GPU
+
+```sh
+# DP=8
+python create_config.py --out_dir tmp --exp_name llama-1B --dp 8 --model_name HuggingFaceTB/SmolLM-1.7B --num_hidden_layers 15 --grad_acc_steps 32 --mbs 4 --seq_len 1024 --hf_token <HF_TOKEN>
+
+# Locally
+torchrun --nproc_per_node 8 train.py --config tmp/llama-1B/config.json
+
+# 3D Parallelism
+python create_config.py --out_dir tmp --dp 4 --tp 2 --pp 2 --pp_engine 1f1b --exp_name llama-7B --model_name meta-llama/Llama-2-7b-hf --grad_acc_steps 32 --mbs 4 --seq_len 1024 --hf_token <HF_TOKEN>
+
+# Slurm
+python submit_slurm_jobs.py --inp_dir tmp/llama-7B --qos high --hf_token <HF_TOKEN>
+```
+
+- CPU (expect it to be slow)
+
+```sh
+# 3D Parallelism on CPU
+python create_config.py --out_dir tmp --exp_name llama-1B-cpu --dp 2 --tp 2 --pp 2 --pp_engine 1f1b --model_name HuggingFaceTB/SmolLM-1.7B --num_hidden_layers 5 --grad_acc_steps 2 --mbs 4 --seq_len 128 --hf_token <HF_TOKEN> --use_cpu
+
+# Locally
+torchrun --nproc_per_node 8 train.py --config tmp/llama-1B-cpu/config.json
+```
+
+# Acknowledgements
+
+- [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)
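Note: the `--use_cpu` flag only changes the generated config; a quick way to confirm it took effect before launching torchrun is to inspect the JSON. Below is a minimal sketch that assumes the `distributed` and `environment` keys written by create_config.py further down in this diff; other parts of the schema may differ.

```python
import json

# Sketch only: sanity-check the CPU config generated by create_config.py.
with open("tmp/llama-1B-cpu/config.json") as f:
    cfg = json.load(f)

print(cfg["distributed"]["backend"])      # expected "gloo" when --use_cpu was passed
print(cfg["environment"]["FLASH_ATTEN"])  # expected "0": FlashAttention kernels are CUDA-only
print(cfg["distributed"]["dp_size"], cfg["distributed"]["pp_size"])  # 2 and 2 in the CPU example
```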
BIN assets/banière.png (new file, 2.9 MiB)
Binary file not shown.
create_config.py
@@ -126,6 +126,7 @@ def create_single_config(
     seq_len: int,
     exp_name: str,
     use_wandb: bool = False,
+    use_cpu: bool = False,
     use_fused_adam: bool = False,
     hf_token: str = None
 ):
@@ -156,6 +157,10 @@ def create_single_config(
     config_content['distributed']['dp_size'] = dp
     config_content['distributed']['pp_size'] = pp
     config_content['distributed']['pp_engine'] = pp_engine
+    config_content['distributed']['use_cpu'] = use_cpu
+    if use_cpu:
+        config_content["environment"]["FLASH_ATTEN"] = "0"
+        config_content["distributed"]["backend"] = "gloo"
 
     config_content['logging']['use_wandb'] = use_wandb
     config_content['logging']['run_name'] = exp_name
@@ -192,6 +197,7 @@ if __name__ == "__main__":
     parser.add_argument("--seq_len", type=int, help="Sequence length", default=1024)
     parser.add_argument("--exp_name", type=str, help="Experiment name", default="dummy_exp")
     parser.add_argument("--use_wandb", action="store_true", help="Use wandb for logging")
+    parser.add_argument("--use_cpu", action="store_true", help="Use CPU for training")
     parser.add_argument("--use_fused_adam", action="store_true", help="Use fused adam")
     parser.add_argument("--hf_token", type=str, help="HF token")
 
@@ -213,6 +219,7 @@ if __name__ == "__main__":
         seq_len=args.seq_len,
         exp_name=args.exp_name,
         use_wandb=args.use_wandb,
+        use_cpu=args.use_cpu,
         use_fused_adam=args.use_fused_adam,
         hf_token=args.hf_token
     )
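For context: `backend` is switched to `"gloo"` because the default `nccl` backend requires CUDA devices, while gloo also runs on CPU. A generic sketch of that selection, under the assumption that picotron initializes torch.distributed from this config value (this is not picotron's actual init code):

```python
import torch
import torch.distributed as dist

def init_distributed(use_cpu: bool) -> None:
    # nccl needs CUDA; gloo is the CPU-capable backend, matching the
    # `backend = "gloo"` value written into the config above.
    backend = "gloo" if use_cpu or not torch.cuda.is_available() else "nccl"
    # torchrun (as used in the README quick start) sets RANK/WORLD_SIZE/MASTER_ADDR,
    # so the default env:// initialization needs no extra arguments here.
    dist.init_process_group(backend=backend)
```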
@@ -68,6 +68,7 @@ class BucketManager:
         grad_type (torch.dtype, optional): Data type of gradients, defaults to torch.float32.
         """
         self.params = list(params) # Convert parameter generator to a list.
+        self.device = torch.device("cuda") if self.params[0].is_cuda else torch.device("cpu")
         self.buckets = [] # List of buckets.
         self.process_group = process_group
         self.process_group_size = dist.get_world_size(group=self.process_group)
@@ -116,7 +117,7 @@ class BucketManager:
 
         # Create tensors for storing gradients and initialize Bucket objects.
         for i in range(len(bucket_sizes)):
-            self.grad_data_list.append(torch.zeros(bucket_sizes[i], dtype=self.grad_type, device='cuda'))
+            self.grad_data_list.append(torch.zeros(bucket_sizes[i], dtype=self.grad_type, device=self.device))
             self.buckets.append(Bucket(buckets_to_params[i], self.grad_data_list[i], self.process_group))
 
         # Create gradient views for each parameter.
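The BucketManager change above is the usual device-agnostic allocation pattern: derive the device from the parameters instead of hard-coding `'cuda'`. A small standalone illustration of the same idea (not picotron's API):

```python
import torch

def make_grad_buffer(params, numel, dtype=torch.float32):
    # Same check as the diff: fall back to CPU when the parameters are not on CUDA,
    # so gradient buckets can be allocated on CPU-only machines.
    device = torch.device("cuda") if params[0].is_cuda else torch.device("cpu")
    return torch.zeros(numel, dtype=dtype, device=device)

p = torch.nn.Parameter(torch.randn(10))  # CPU parameter
print(make_grad_buffer([p], 10).device)  # -> cpu
```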
setup.py (10 lines changed)
@@ -1,6 +1,12 @@
 from setuptools import setup, find_packages
+
+def read_requirements():
+    with open('requirements.txt') as req:
+        return [line.strip() for line in req if line.strip() and not line.startswith('#')]
+
 setup(
-    name="picotron", # Name of the package
+    name="picotron",
     version='0.1.0',
-    packages=find_packages(), # Automatically find packages in the current directory
+    packages=find_packages(),
+    install_requires=read_requirements(),
 )
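With this change, `pip install -e .` pulls dependencies from requirements.txt. A tiny illustration of the filtering done by `read_requirements()`; the package names here are made up:

```python
# Hypothetical requirements.txt contents, shown as the lines the file object would yield.
lines = ["torch>=2.1\n", "\n", "# a comment, skipped\n", "datasets\n"]
reqs = [line.strip() for line in lines if line.strip() and not line.startswith('#')]
print(reqs)  # ['torch>=2.1', 'datasets']
```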
@@ -30,7 +30,7 @@
     "dataset": {
         "name": "roneneldan/TinyStories",
         "num_workers": 0,
-        "num_proc": 4
+        "num_proc": 1
     },
     "checkpoint": {
         "save_dir": "ckpt",
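Lowering `num_proc` to 1 disables multiprocessing during dataset preprocessing, which is the safer default on small CPU machines. A hedged sketch of how such a knob is typically consumed, assuming it feeds `datasets.Dataset.map` (picotron's own data pipeline may wire it differently):

```python
from datasets import Dataset

# Standalone illustration; the dataset content here is made up.
ds = Dataset.from_dict({"text": ["once upon a time", "the end"]})
ds = ds.map(lambda ex: {"n_chars": len(ex["text"])}, num_proc=1)  # num_proc=1: single process
print(ds[0])
```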
train.py (5 lines changed)
@@ -159,6 +159,8 @@ if __name__ == "__main__":
 
     dist.barrier()
 
+    print(f"rank {pgm.process_group_manager.global_rank}: Initializing model meta device", is_print_rank=is_wandb_rank)
+
     start_time = time.time()
 
     with init_model_with_dematerialized_weights():
@@ -209,7 +211,8 @@
     def _all_reduce_loss_across_dp_cp_ranks(loss, device):
         reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
         if pgm.process_group_manager.pp_is_last_stage:
-            dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.cp_dp_group)
+            dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.cp_dp_group)
+            reduced_loss /= pgm.process_group_manager.cp_dp_world_size
         return reduced_loss.item()
 
     while config["training"]["max_tokens"] is None or trained_tokens < config["training"]["max_tokens"]:
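The loss-reduction change above is needed because `dist.ReduceOp.AVG` is only implemented for the NCCL backend; with gloo (the CPU path) the portable equivalent is a SUM all-reduce followed by a division by the group size, which is exactly what the diff does. A minimal sketch of that pattern:

```python
import torch
import torch.distributed as dist

def all_reduce_mean(t: torch.Tensor, group=None) -> torch.Tensor:
    # Portable mean across ranks: SUM then divide by the group size.
    # Gives the same result as ReduceOp.AVG but works on both gloo and nccl.
    dist.all_reduce(t, op=dist.ReduceOp.SUM, group=group)
    t /= dist.get_world_size(group=group)
    return t
```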