picotron/generate.py
#VERBOSE=0 torchrun --nproc_per_node 3 generate.py --pp_size 3 --load_path smollm.pth
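#
# Pipeline-parallel greedy decoding: every rank builds the same batch at each step, each
# pipeline stage runs its shard of the model and passes activations to the next stage,
# the last stage picks the argmax token, and the new token is broadcast back to all ranks
# before the next step.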
import os
import argparse
import torch, torch.distributed as dist
from transformers import AutoTokenizer, AutoConfig
from utils import set_all_seed
import distributed.process_group_manager as pgm
from distributed.process_group_manager import setup_process_group_manager
from parallel.pipeline_parallel import PipelineParallel
from distributed.distributed_primtives import pipeline_communicate
from model import Llama

def run_one_inference_step(model, batch, device, config) -> torch.Tensor:
    """Run one forward pass through the pipeline stages; returns logits on the last stage, None elsewhere."""
    if pgm.process_group_manager.pp_world_size == 1:
        return model.forward(input_ids=batch["input_ids"], position_ids=batch["position_ids"], hidden_states=batch["hidden_states"])

    batch_size = batch["input_ids"].shape[0]
    seq_len = batch["input_ids"].shape[1]
    tensor_shapes = (batch_size, seq_len, config.hidden_size)

    # Preallocate memory for output logits.
    logits = None
    if pgm.process_group_manager.pp_is_last_stage:
        logits = torch.empty((batch_size, seq_len, int(config.vocab_size)), dtype=torch.float32, device=device)

    # Receive activations from the previous stage; the first stage feeds input_ids instead.
    recv_buffer = pipeline_communicate(operation="recv_forward", shapes=tensor_shapes, dtype=torch.float32, device=device)
    batch["hidden_states"] = None if pgm.process_group_manager.pp_is_first_stage else recv_buffer

    output_tensor = model.forward(input_ids=batch["input_ids"], position_ids=batch["position_ids"], hidden_states=batch["hidden_states"])

    # Send output to the next stage.
    pipeline_communicate(operation="send_forward", tensor=output_tensor, dtype=torch.float32, device=device)

    # Copy logits.
    if pgm.process_group_manager.pp_is_last_stage:
        logits = output_tensor

    dist.barrier()
    return logits
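
# pipeline_communicate is expected to be a no-op for "recv_forward" on the first stage and
# for "send_forward" on the last stage, which lets the same code path run on every rank.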

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--load_path", type=str)
    parser.add_argument("--pp_size", type=int, default=1)
    parser.add_argument("--max_tokens", type=int, default=32)
    args = parser.parse_args()

    local_rank, world_size = int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"])

    #TODO(fmom): add gloo backend for generation
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    setup_process_group_manager(tp_size=1, pp_size=args.pp_size, dp_size=1, cp_size=1)
    set_all_seed(seed=42)

    # Map checkpoint filenames to Hugging Face model names (used for config and tokenizer).
    load2name = {
        "smollm.pth": "HuggingFaceTB/SmolLM-360M-Instruct",
        "llama1b.pth": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "llama3-B.pth": "meta-llama/Meta-Llama-3-8B",
    }

    config = AutoConfig.from_pretrained(load2name[args.load_path])
    base_model = Llama(config=config, device=device)
    base_model.load_state_dict(torch.load(args.load_path, map_location="cpu"))

    model = PipelineParallel(base_model, config).to(device)
    del base_model
    model.eval()
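
    # The full state dict is loaded on every rank above; PipelineParallel is assumed to
    # keep only the layers that belong to this pipeline stage.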

    # Tokenize the input
    prompts = [
        "My name is",
        "How old are you ?",
        "What is your favorite color?",
    ]

    tokenizer = AutoTokenizer.from_pretrained(load2name[args.load_path])
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    tokenized_prompts = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
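
    # With left padding, the newest token of every prompt sits at position -1, so the
    # greedy step below can read logits[:, -1] uniformly across the batch.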

    for _ in range(args.max_tokens):
        # Create the batch
        seq_len = tokenized_prompts["input_ids"].shape[1]
        position_ids = torch.arange(seq_len).view(1, -1)
        batch_prompts = {
            "input_ids": tokenized_prompts["input_ids"].to(device=device),
            "target_ids": None,
            "position_ids": position_ids.to(device=device),
            "attn_mask": tokenized_prompts["attention_mask"].to(dtype=torch.bool, device=device),
            "hidden_states": None,
        }

        logits = run_one_inference_step(model, batch_prompts, device, config)

        # Sample new token
        if pgm.process_group_manager.pp_is_last_stage:
            assert logits is not None
            # Greedy decoding: take the argmax over the vocabulary at the last position.
            next_token = torch.argmax(logits[:, -1], dim=-1)
            tokenized_prompts["input_ids"] = torch.cat([tokenized_prompts["input_ids"], next_token.unsqueeze(-1)], dim=-1)
            tokenized_prompts["attention_mask"] = torch.cat([tokenized_prompts["attention_mask"], torch.ones((tokenized_prompts["attention_mask"].shape[0], 1), dtype=torch.int64, device=device)], dim=-1)
        else:
            # Other ranks allocate placeholders of the new length; the broadcasts below fill them in.
            tokenized_prompts["input_ids"] = torch.zeros((tokenized_prompts["input_ids"].shape[0], tokenized_prompts["input_ids"].shape[1] + 1), dtype=torch.int64, device=device)
            tokenized_prompts["attention_mask"] = torch.zeros((tokenized_prompts["attention_mask"].shape[0], tokenized_prompts["attention_mask"].shape[1] + 1), dtype=torch.int64, device=device)

        # Broadcast the updated sequences from the last pipeline stage so all ranks stay in sync.
        dist.broadcast(tokenized_prompts["input_ids"], src=pgm.process_group_manager.pp_last_rank)
        dist.broadcast(tokenized_prompts["attention_mask"], src=pgm.process_group_manager.pp_last_rank)

    # Get only the new generated tokens
    if pgm.process_group_manager.pp_is_last_stage:
        for i, prompt in enumerate(prompts):
            tokenized_outputs = tokenized_prompts["input_ids"][i, tokenized_prompts["input_ids"].shape[1] - args.max_tokens:]
            outputs = tokenizer.decode(tokenized_outputs)
            print(f"Input: {prompt}")
            print(f"Output: {outputs}")
            print("------")