diff --git a/convert_hf_to_picotron.py b/convert_hf_to_picotron.py
deleted file mode 100644
index 2d64239..0000000
--- a/convert_hf_to_picotron.py
+++ /dev/null
@@ -1,143 +0,0 @@
-"""
-torchrun --nproc_per_node=1 convert_hf_to_picotron.py --save_path smollm.pth
-"""
-import os
-import argparse
-from tqdm import tqdm
-import torch, torch.distributed as dist
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from utils import set_all_seed
-import lovely_tensors as lt; lt.monkey_patch()
-
-from model import Llama
-from distributed.process_group_manager import setup_process_group_manager
-
-def sanity_check_weights(model, model_hf, picotron_to_hf):
-
-    total, fail = 0, 0
-
-    state_dict = model.state_dict()
-    state_dict_hf = model_hf.state_dict()
-
-    for name, name_hf in picotron_to_hf.items():
-
-        param_hf = state_dict_hf[name_hf]
-        param = state_dict[name]
-
-        total += 1
-        try:
-            torch.testing.assert_close(param_hf, param, rtol=1e-10, atol=1e-10)
-        except AssertionError as e:
-            print(f"{name_hf} and {name} are not equal")
-            fail += 1
-
-    if fail == 0:
-        print("All parameters are equal")
-    else:
-        raise AssertionError(f"{fail}/{total} parameters are not equal")
-
-def sanity_check_generation(model, model_hf, model_name, prompt, max_new_tokens):
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-
-    input_ids_hf = tokenizer.encode(prompt, return_tensors="pt").to(device=model_hf.device)
-    input_ids = input_ids_hf.clone().to(device=model_hf.device)
-
-    for _ in range(max_new_tokens):
-        # picotron model
-        seq_len = input_ids.shape[1]
-        position_index = torch.arange(seq_len).view(1, -1).to(device=model_hf.device)
-
-        logits = model(input_ids=input_ids, position_ids=position_index)
-        next_token = torch.argmax(logits, dim=-1)
-        input_ids = torch.cat([input_ids, next_token[:, -1].unsqueeze(-1)], dim=-1)
-
-        # HF model
-        logits_hf = model_hf(input_ids_hf).logits
-        next_token_hf = torch.argmax(logits_hf[:, -1, :], dim=-1)
-        input_ids_hf = torch.cat([input_ids_hf, next_token_hf.unsqueeze(0)], dim=-1)
-
-        # Assert logits are equal
-        torch.testing.assert_close(logits, logits_hf, atol=1e-4, rtol=1e-4)
-
-    print("Input prompt:\n", prompt)
-    print("Reference model output:\n", tokenizer.decode(input_ids_hf[0], skip_special_tokens=True))
-    print("picotron model output:\n", tokenizer.decode(input_ids[0], skip_special_tokens=True))
-
-def get_weights_mapping(model_hf, to_hf):
-
-    hf_to_picotron = {}
-
-    hf_to_picotron["model.embed_tokens.weight"] = "embedding.weight"
-    hf_to_picotron["model.norm.weight"] = "final_norm.weight"
-    hf_to_picotron["lm_head.weight"] = "final_proj.weight"
-
-    for i in range(model_hf.config.num_hidden_layers):
-        # Attention
-        hf_to_picotron[f"model.layers.{i}.self_attn.q_proj.weight"] = f"decoder_layers.{i}.attention.q_proj.weight"
-        hf_to_picotron[f"model.layers.{i}.self_attn.k_proj.weight"] = f"decoder_layers.{i}.attention.k_proj.weight"
-        hf_to_picotron[f"model.layers.{i}.self_attn.v_proj.weight"] = f"decoder_layers.{i}.attention.v_proj.weight"
-        hf_to_picotron[f"model.layers.{i}.self_attn.o_proj.weight"] = f"decoder_layers.{i}.attention.o_proj.weight"
-        # MLP
-        hf_to_picotron[f"model.layers.{i}.mlp.gate_proj.weight"] = f"decoder_layers.{i}.mlp.gate_proj.weight"
-        hf_to_picotron[f"model.layers.{i}.mlp.up_proj.weight"] = f"decoder_layers.{i}.mlp.up_proj.weight"
-        hf_to_picotron[f"model.layers.{i}.mlp.down_proj.weight"] = f"decoder_layers.{i}.mlp.down_proj.weight"
-
-        hf_to_picotron[f"model.layers.{i}.input_layernorm.weight"] = f"decoder_layers.{i}.norm_attn.weight"
-        hf_to_picotron[f"model.layers.{i}.post_attention_layernorm.weight"] = f"decoder_layers.{i}.norm_mlp.weight"
-
-    # check that every mapped key exists in the reference model
-    for key in hf_to_picotron:
-        assert key in model_hf.state_dict(), f"{key} not found in reference model"
-
-    if to_hf:
-        # Mapping from picotron to hf
-        picotron_to_hf = {v: k for k, v in hf_to_picotron.items()}
-        return picotron_to_hf
-
-    return hf_to_picotron
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Convert HF llama weights to picotron")
-    parser.add_argument("--save_path", type=str, default="smollm.pth")
-    parser.add_argument("--model_name", type=str, default="HuggingFaceTB/SmolLM-360M-Instruct")
-    parser.add_argument("--prompt", type=str, default="My name is")
-    parser.add_argument("--max_new_tokens", type=int, default=50)
-
-    args = parser.parse_args()
-
-    local_rank, world_size = int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"])
-
-    #TODO: add gloo backend for generation
-    dist.init_process_group(backend="nccl")
-    torch.cuda.set_device(local_rank)
-    device = torch.device("cuda", local_rank)
-    setup_process_group_manager(tp_size=1, pp_size=1, dp_size=1, cp_size=1)
-    set_all_seed(seed=42)
-
-    model_hf = AutoModelForCausalLM.from_pretrained(args.model_name).to(device)
-
-    model = Llama(config=model_hf.config, device=device)
-    picotron_to_hf = get_weights_mapping(model_hf, to_hf=True)
-
-    ref_state_dict = model_hf.state_dict()
-
-    for name, param in tqdm(
-        model.named_parameters(),
-        total=len(list(model.named_parameters())),
-        desc="Converting",
-    ):
-        if name in picotron_to_hf:
-            ref_name = picotron_to_hf[name]
-            ref_param = ref_state_dict[ref_name]
-            param.data.copy_(ref_param)
-
-    torch.save(model.state_dict(), args.save_path)
-
-    new_model = Llama(config=model_hf.config, device=device)
-    new_model.load_state_dict(torch.load(args.save_path))
-
-    print("Sanity check weight ...")
-    sanity_check_weights(new_model, model_hf, picotron_to_hf)
-    print("Sanity check generation ...")
-    sanity_check_generation(new_model, model_hf, args.model_name, args.prompt, args.max_new_tokens)
-    print("Conversion successful")
\ No newline at end of file
diff --git a/convert_picotron_to_hf.py b/convert_picotron_to_hf.py
deleted file mode 100644
index e69de29..0000000
diff --git a/generate.py b/generate.py
deleted file mode 100644
index a589754..0000000
--- a/generate.py
+++ /dev/null
@@ -1,126 +0,0 @@
-#VERBOSE=0 torchrun --nproc_per_node 3 generate.py --pp_size 3 --load_path smollm.pth
-import os
-import argparse
-import torch, torch.distributed as dist
-from transformers import AutoTokenizer, AutoConfig
-
-from utils import set_all_seed
-import distributed.process_group_manager as pgm
-from distributed.process_group_manager import setup_process_group_manager
-from parallel.pipeline_parallel import PipelineParallel
-from distributed.distributed_primtives import pipeline_communicate
-from model import Llama
-
-def run_one_inference_step(model, batch, device, config) -> torch.Tensor:
-    if pgm.process_group_manager.pp_world_size == 1:
-        return model.forward(input_ids=batch["input_ids"], position_ids=batch["position_ids"], hidden_states=batch["hidden_states"])
-
-    batch_size = batch["input_ids"].shape[0]
-    seq_len = batch["input_ids"].shape[1]
-    tensor_shapes = (batch_size, seq_len, config.hidden_size)
-
-    # Preallocate memory for output logits.
-    logits = None
-    if pgm.process_group_manager.pp_is_last_stage:
-        logits = torch.empty((batch_size, seq_len, int(config.vocab_size)), dtype=torch.float32, device=device)
-
-    recv_buffer = pipeline_communicate(operation="recv_forward", shapes=tensor_shapes, dtype=torch.float32, device=device)
-
-    batch["hidden_states"] = None if pgm.process_group_manager.pp_is_first_stage else recv_buffer
-
-    output_tensor = model.forward(input_ids=batch["input_ids"], position_ids=batch["position_ids"], hidden_states=batch["hidden_states"])
-
-    # Send output to the next stage.
-    pipeline_communicate(operation="send_forward", tensor=output_tensor, dtype=torch.float32, device=device)
-
-    # Copy logits.
-    if pgm.process_group_manager.pp_is_last_stage:
-        logits = output_tensor
-
-    dist.barrier()
-
-    return logits
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--load_path", type=str)
-    parser.add_argument("--pp_size", type=int, default=1)
-    parser.add_argument("--max_tokens", type=int, default=32)
-    args = parser.parse_args()
-
-    local_rank, world_size = int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"])
-
-    #TODO(fmom): add gloo backend for generation
-    dist.init_process_group(backend="nccl")
-    torch.cuda.set_device(local_rank)
-    device = torch.device("cuda", local_rank)
-    setup_process_group_manager(tp_size=1, pp_size=args.pp_size, dp_size=1, cp_size=1)
-    set_all_seed(seed=42)
-
-    load2name = {
-        "smollm.pth": "HuggingFaceTB/SmolLM-360M-Instruct",
-        "llama1b.pth": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-        "llama3-B.pth": "meta-llama/Meta-Llama-3-8B",
-    }
-
-    config = AutoConfig.from_pretrained(load2name[args.load_path])
-
-    base_model = Llama(config=config, device=device)
-    base_model.load_state_dict(torch.load(args.load_path, map_location="cpu"))
-    model = PipelineParallel(base_model, config).to(device)
-
-    del base_model
-    model.eval()
-
-    # Tokenize the input
-    prompts = [
-        "My name is",
-        "How old are you ?",
-        "What is your favorite color?",
-    ]
-
-    tokenizer = AutoTokenizer.from_pretrained(load2name[args.load_path])
-    tokenizer.padding_side = "left"
-    tokenizer.pad_token = tokenizer.eos_token
-
-    tokenized_prompts = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
-
-    for _ in range(args.max_tokens):
-
-        # Create the batch
-        seq_len = tokenized_prompts["input_ids"].shape[1]
-        position_ids = torch.arange(seq_len).view(1, -1)
-
-        batch_prompts = {
-            "input_ids": tokenized_prompts["input_ids"].to(device=device),
-            "target_ids": None,
-            "position_ids": position_ids.to(device=device),
-            "attn_mask": tokenized_prompts["attention_mask"].to(dtype=torch.bool, device=device),
-            "hidden_states": None,
-        }
-
-        logits = run_one_inference_step(model, batch_prompts, device, config)
-
-        # Sample new token
-        if pgm.process_group_manager.pp_is_last_stage:
-            assert logits is not None
-            next_token = torch.argmax(logits[:, -1], dim=-1)
-            tokenized_prompts["input_ids"] = torch.cat([tokenized_prompts["input_ids"], next_token.unsqueeze(-1)], dim=-1)
-            tokenized_prompts["attention_mask"] = torch.cat([tokenized_prompts["attention_mask"], torch.ones((tokenized_prompts["attention_mask"].shape[0], 1), dtype=torch.int64, device=device)], dim=-1)
-        else:
-            tokenized_prompts["input_ids"] = torch.zeros((tokenized_prompts["input_ids"].shape[0], tokenized_prompts["input_ids"].shape[1] + 1), dtype=torch.int64, device=device)
-            tokenized_prompts["attention_mask"] = torch.zeros((tokenized_prompts["attention_mask"].shape[0], tokenized_prompts["attention_mask"].shape[1] + 1), dtype=torch.int64, device=device)
-
-        dist.broadcast(tokenized_prompts["input_ids"], src=pgm.process_group_manager.pp_last_rank)
-        dist.broadcast(tokenized_prompts["attention_mask"], src=pgm.process_group_manager.pp_last_rank)
-
-    # Get only the new generated tokens
-    if pgm.process_group_manager.pp_is_last_stage:
-        for i, prompt in enumerate(prompts):
-            tokenized_outputs = tokenized_prompts["input_ids"][i, tokenized_prompts["input_ids"].shape[1] - args.max_tokens:]
-            outputs = tokenizer.decode(tokenized_outputs)
-
-            print(f"Input: {prompt}")
-            print(f"Output: {outputs}")
-            print("------")
-
\ No newline at end of file
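
For context, the heart of the removed convert_hf_to_picotron.py is a name mapping from Hugging Face parameters (model.layers.{i}.self_attn.q_proj.weight, ...) to picotron parameters (decoder_layers.{i}.attention.q_proj.weight, ...) plus a loop that copies tensors across that mapping. Below is a minimal sketch of that copy-by-mapping idea, using toy nn.Linear modules instead of the real Llama classes; the helper name copy_by_mapping and the toy models are illustrative assumptions, not part of the repository.

import torch
import torch.nn as nn

def copy_by_mapping(src_model: nn.Module, dst_model: nn.Module, src_to_dst: dict) -> None:
    """Copy parameters from src_model into dst_model following a src-name -> dst-name mapping."""
    src_state = src_model.state_dict()
    dst_state = dst_model.state_dict()
    for src_name, dst_name in src_to_dst.items():
        assert src_state[src_name].shape == dst_state[dst_name].shape, (src_name, dst_name)
        # state_dict() tensors share storage with the module parameters,
        # so an in-place copy_ updates dst_model directly.
        dst_state[dst_name].copy_(src_state[src_name])

# Toy example: two single-layer "models" whose only difference is the parameter name.
src = nn.Sequential(nn.Linear(8, 8, bias=False))  # parameter key: "0.weight"
dst = nn.Linear(8, 8, bias=False)                 # parameter key: "weight"
copy_by_mapping(src, dst, {"0.weight": "weight"})
torch.testing.assert_close(dst.weight, src[0].weight)  # same check idea as sanity_check_weights
print("weights match")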
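
The removed generate.py keeps all pipeline stages in lockstep during greedy decoding: only the last stage holds real logits and samples the next token, the other stages allocate zero placeholders of the matching shape, and the updated input_ids and attention_mask are then broadcast from pp_last_rank. The sketch below reduces that broadcast pattern to plain torch.distributed; the gloo backend, the file name in the launch comment, and broadcasting only a single next_token tensor are simplifying assumptions, not the script's exact behavior.

# Launch (hypothetical file name): torchrun --nproc_per_node=2 pp_broadcast_sketch.py
import torch
import torch.distributed as dist

def main() -> None:
    dist.init_process_group(backend="gloo")  # gloo so the sketch also runs on CPU-only machines
    rank, world_size = dist.get_rank(), dist.get_world_size()
    last_rank = world_size - 1  # stands in for pgm.process_group_manager.pp_last_rank

    batch_size, seq_len, vocab_size = 3, 5, 100
    if rank == last_rank:
        # Only the last stage has logits; it picks the next token greedily.
        logits = torch.randn(batch_size, seq_len, vocab_size)
        next_token = torch.argmax(logits[:, -1], dim=-1)
    else:
        # Other stages allocate a placeholder with the same shape and dtype.
        next_token = torch.zeros(batch_size, dtype=torch.int64)

    # After the broadcast every rank holds the token chosen by the last stage,
    # so all stages can extend their copy of input_ids for the next decoding step.
    dist.broadcast(next_token, src=last_rank)
    print(f"rank {rank}: {next_token.tolist()}")
    dist.destroy_process_group()

if __name__ == "__main__":
    main()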