remove unnecessary files
parent 41f49bb15f
commit e7b4722160
convert_hf_to_picotron.py
@@ -1,143 +0,0 @@
"""
torchrun --nproc_per_node=1 convert_hf_to_picotron.py --save_path smollm.pth
"""
import os
import argparse
from tqdm import tqdm
import torch, torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import set_all_seed
import lovely_tensors as lt; lt.monkey_patch()

from model import Llama
from distributed.process_group_manager import setup_process_group_manager

def sanity_check_weights(model, model_hf, picotron_to_hf):
    total, fail = 0, 0

    state_dict = model.state_dict()
    state_dict_hf = model_hf.state_dict()

    for name, name_hf in picotron_to_hf.items():
        param_hf = state_dict_hf[name_hf]
        param = state_dict[name]

        total += 1
        try:
            torch.testing.assert_close(param_hf, param, rtol=1e-10, atol=1e-10)
        except AssertionError:
            print(f"{name_hf} and {name} are not equal")
            fail += 1

    if fail == 0:
        print("All parameters are equal")
    else:
        raise AssertionError(f"{fail}/{total} parameters are not equal")

def sanity_check_generation(model, model_hf, model_name, prompt, max_new_tokens):
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    input_ids_hf = tokenizer.encode(prompt, return_tensors="pt").to(device=model_hf.device)
    input_ids = input_ids_hf.clone().to(device=model_hf.device)

    for _ in range(max_new_tokens):
        # picotron model
        seq_len = input_ids.shape[1]
        position_index = torch.arange(seq_len).view(1, -1).to(device=model_hf.device)

        logits = model(input_ids=input_ids, position_ids=position_index)
        next_token = torch.argmax(logits, dim=-1)
        input_ids = torch.cat([input_ids, next_token[:, -1].unsqueeze(-1)], dim=-1)

        # HF model
        logits_hf = model_hf(input_ids_hf).logits
        next_token_hf = torch.argmax(logits_hf[:, -1, :], dim=-1)
        input_ids_hf = torch.cat([input_ids_hf, next_token_hf.unsqueeze(0)], dim=-1)

        # Assert logits are equal
        torch.testing.assert_close(logits, logits_hf, atol=1e-4, rtol=1e-4)

    print("Input prompt:\n", prompt)
    print("Reference model output:\n", tokenizer.decode(input_ids_hf[0], skip_special_tokens=True))
    print("picotron model output:\n", tokenizer.decode(input_ids[0], skip_special_tokens=True))

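# Note: sanity_check_weights compares parameters at rtol=atol=1e-10 because they are
# copied verbatim, while sanity_check_generation compares logits at the looser
# rtol=atol=1e-4, leaving room for small floating-point differences between the two
# forward passes.
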
def get_weights_mapping(model_hf, to_hf):
    hf_to_picotron = {}

    hf_to_picotron["model.embed_tokens.weight"] = "embedding.weight"
    hf_to_picotron["model.norm.weight"] = "final_norm.weight"
    hf_to_picotron["lm_head.weight"] = "final_proj.weight"

    for i in range(model_hf.config.num_hidden_layers):
        # Attention
        hf_to_picotron[f"model.layers.{i}.self_attn.q_proj.weight"] = f"decoder_layers.{i}.attention.q_proj.weight"
        hf_to_picotron[f"model.layers.{i}.self_attn.k_proj.weight"] = f"decoder_layers.{i}.attention.k_proj.weight"
        hf_to_picotron[f"model.layers.{i}.self_attn.v_proj.weight"] = f"decoder_layers.{i}.attention.v_proj.weight"
        hf_to_picotron[f"model.layers.{i}.self_attn.o_proj.weight"] = f"decoder_layers.{i}.attention.o_proj.weight"
        # MLP
        hf_to_picotron[f"model.layers.{i}.mlp.gate_proj.weight"] = f"decoder_layers.{i}.mlp.gate_proj.weight"
        hf_to_picotron[f"model.layers.{i}.mlp.up_proj.weight"] = f"decoder_layers.{i}.mlp.up_proj.weight"
        hf_to_picotron[f"model.layers.{i}.mlp.down_proj.weight"] = f"decoder_layers.{i}.mlp.down_proj.weight"
        # LayerNorms
        hf_to_picotron[f"model.layers.{i}.input_layernorm.weight"] = f"decoder_layers.{i}.norm_attn.weight"
        hf_to_picotron[f"model.layers.{i}.post_attention_layernorm.weight"] = f"decoder_layers.{i}.norm_mlp.weight"

    # Check that every mapped key exists in the reference model
    for key in hf_to_picotron:
        assert key in model_hf.state_dict(), f"{key} not found in reference model"

    if to_hf:
        # Mapping from picotron to hf
        picotron_to_hf = {v: k for k, v in hf_to_picotron.items()}
        return picotron_to_hf

    return hf_to_picotron

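# For layer 0, the mapping above produces pairs such as:
#   "model.layers.0.self_attn.q_proj.weight" -> "decoder_layers.0.attention.q_proj.weight"
#   "model.layers.0.input_layernorm.weight"  -> "decoder_layers.0.norm_attn.weight"
# With to_hf=True the inverted dict (picotron name -> HF name) is returned, which is
# what the conversion loop below iterates over.
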
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert HF llama weights to picotron")
    parser.add_argument("--save_path", type=str, default="smollm.pth")
    parser.add_argument("--model_name", type=str, default="HuggingFaceTB/SmolLM-360M-Instruct")
    parser.add_argument("--prompt", type=str, default="My name is")
    parser.add_argument("--max_new_tokens", type=int, default=50)

    args = parser.parse_args()

    local_rank, world_size = int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"])

    # TODO: add gloo backend for generation
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    setup_process_group_manager(tp_size=1, pp_size=1, dp_size=1, cp_size=1)
    set_all_seed(seed=42)

    model_hf = AutoModelForCausalLM.from_pretrained(args.model_name).to(device)

    model = Llama(config=model_hf.config, device=device)
    picotron_to_hf = get_weights_mapping(model_hf, to_hf=True)

    ref_state_dict = model_hf.state_dict()

    # Copy every mapped HF parameter into the picotron model
    for name, param in tqdm(
        model.named_parameters(),
        total=len(list(model.named_parameters())),
        desc="Converting",
    ):
        if name in picotron_to_hf:
            ref_name = picotron_to_hf[name]
            ref_param = ref_state_dict[ref_name]
            param.data.copy_(ref_param)

    torch.save(model.state_dict(), args.save_path)

    # Reload the saved checkpoint and verify it matches the HF reference
    new_model = Llama(config=model_hf.config, device=device)
    new_model.load_state_dict(torch.load(args.save_path))

    print("Sanity check weight ...")
    sanity_check_weights(new_model, model_hf, picotron_to_hf)
    print("Sanity check generation ...")
    sanity_check_generation(new_model, model_hf, args.model_name, args.prompt, args.max_new_tokens)
    print("Conversion successful")
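# Converting a different checkpoint only requires swapping --model_name and
# --save_path; a hypothetical invocation matching the llama1b.pth entry that
# generate.py expects:
#
#   torchrun --nproc_per_node=1 convert_hf_to_picotron.py \
#       --model_name TinyLlama/TinyLlama-1.1B-Chat-v1.0 --save_path llama1b.pth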
generate.py
@@ -1,126 +0,0 @@
# VERBOSE=0 torchrun --nproc_per_node 3 generate.py --pp_size 3 --load_path smollm.pth
import os
import argparse
import torch, torch.distributed as dist
from transformers import AutoTokenizer, AutoConfig

from utils import set_all_seed
import distributed.process_group_manager as pgm
from distributed.process_group_manager import setup_process_group_manager
from parallel.pipeline_parallel import PipelineParallel
from distributed.distributed_primtives import pipeline_communicate
from model import Llama

def run_one_inference_step(model, batch, device, config) -> torch.Tensor:
    if pgm.process_group_manager.pp_world_size == 1:
        return model.forward(input_ids=batch["input_ids"], position_ids=batch["position_ids"], hidden_states=batch["hidden_states"])

    batch_size = batch["input_ids"].shape[0]
    seq_len = batch["input_ids"].shape[1]
    tensor_shapes = (batch_size, seq_len, config.hidden_size)

    # Only the last pipeline stage produces logits; every other stage returns None.
    logits = None

    # Receive the activations from the previous stage (the first stage receives nothing).
    recv_buffer = pipeline_communicate(operation="recv_forward", shapes=tensor_shapes, dtype=torch.float32, device=device)
    batch["hidden_states"] = None if pgm.process_group_manager.pp_is_first_stage else recv_buffer

    output_tensor = model.forward(input_ids=batch["input_ids"], position_ids=batch["position_ids"], hidden_states=batch["hidden_states"])

    # Send the output activations to the next stage.
    pipeline_communicate(operation="send_forward", tensor=output_tensor, dtype=torch.float32, device=device)

    # On the last stage the output tensor already holds the logits.
    if pgm.process_group_manager.pp_is_last_stage:
        logits = output_tensor

    dist.barrier()

    return logits

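# Data flow per pipeline stage for a single forward pass (pp_world_size > 1):
#   first stage : embeds input_ids (hidden_states=None)    -> send_forward
#   middle stage: recv_forward -> forward on hidden_states -> send_forward
#   last stage  : recv_forward -> forward                  -> returns logits
# All other stages return None; dist.barrier() keeps the stages in step.
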
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--load_path", type=str)
    parser.add_argument("--pp_size", type=int, default=1)
    parser.add_argument("--max_tokens", type=int, default=32)
    args = parser.parse_args()

    local_rank, world_size = int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"])

    # TODO(fmom): add gloo backend for generation
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    setup_process_group_manager(tp_size=1, pp_size=args.pp_size, dp_size=1, cp_size=1)
    set_all_seed(seed=42)

    load2name = {
        "smollm.pth": "HuggingFaceTB/SmolLM-360M-Instruct",
        "llama1b.pth": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "llama3-B.pth": "meta-llama/Meta-Llama-3-8B",
    }

    config = AutoConfig.from_pretrained(load2name[args.load_path])

    base_model = Llama(config=config, device=device)
    base_model.load_state_dict(torch.load(args.load_path, map_location="cpu"))
    model = PipelineParallel(base_model, config).to(device)

    del base_model
    model.eval()

    # Tokenize the input
    prompts = [
        "My name is",
        "How old are you?",
        "What is your favorite color?",
    ]

    tokenizer = AutoTokenizer.from_pretrained(load2name[args.load_path])
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token

    tokenized_prompts = tokenizer(prompts, return_tensors="pt", padding=True).to(device)

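    # Left padding (set above) keeps each prompt's latest token at position -1, so the
    # greedy step below can read logits[:, -1] for the whole batch at once.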
    for _ in range(args.max_tokens):
        # Create the batch
        seq_len = tokenized_prompts["input_ids"].shape[1]
        position_ids = torch.arange(seq_len).view(1, -1)

        batch_prompts = {
            "input_ids": tokenized_prompts["input_ids"].to(device=device),
            "target_ids": None,
            "position_ids": position_ids.to(device=device),
            "attn_mask": tokenized_prompts["attention_mask"].to(dtype=torch.bool, device=device),
            "hidden_states": None,
        }

        logits = run_one_inference_step(model, batch_prompts, device, config)

        # Sample the next token (greedy) on the last stage; the other stages allocate
        # empty buffers of the new length to receive the broadcast below.
        if pgm.process_group_manager.pp_is_last_stage:
            assert logits is not None
            next_token = torch.argmax(logits[:, -1], dim=-1)
            tokenized_prompts["input_ids"] = torch.cat([tokenized_prompts["input_ids"], next_token.unsqueeze(-1)], dim=-1)
            tokenized_prompts["attention_mask"] = torch.cat([tokenized_prompts["attention_mask"], torch.ones((tokenized_prompts["attention_mask"].shape[0], 1), dtype=torch.int64, device=device)], dim=-1)
        else:
            tokenized_prompts["input_ids"] = torch.zeros((tokenized_prompts["input_ids"].shape[0], tokenized_prompts["input_ids"].shape[1] + 1), dtype=torch.int64, device=device)
            tokenized_prompts["attention_mask"] = torch.zeros((tokenized_prompts["attention_mask"].shape[0], tokenized_prompts["attention_mask"].shape[1] + 1), dtype=torch.int64, device=device)

        # Keep every pipeline stage in sync with the tokens chosen on the last stage.
        dist.broadcast(tokenized_prompts["input_ids"], src=pgm.process_group_manager.pp_last_rank)
        dist.broadcast(tokenized_prompts["attention_mask"], src=pgm.process_group_manager.pp_last_rank)

    # Decode only the newly generated tokens
    if pgm.process_group_manager.pp_is_last_stage:
        for i, prompt in enumerate(prompts):
            tokenized_outputs = tokenized_prompts["input_ids"][i, tokenized_prompts["input_ids"].shape[1] - args.max_tokens:]
            outputs = tokenizer.decode(tokenized_outputs)

            print(f"Input: {prompt}")
            print(f"Output: {outputs}")
            print("------")

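# Only the last pipeline stage decodes and prints; the other ranks produce no output.
# A hypothetical single-GPU run without pipeline parallelism, reusing the same
# converted checkpoint:
#   torchrun --nproc_per_node 1 generate.py --pp_size 1 --load_path smollm.pth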