Add an option to use dummy model weights (#33)

This commit is contained in:
Woosuk Kwon 2023-04-08 23:36:12 -07:00 committed by GitHub
parent c267b1a02c
commit ee88a7e5f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 36 additions and 8 deletions

View File

@ -29,6 +29,7 @@ def main(args: argparse.Namespace):
server = Server(
model=args.model,
model_path=args.model_path,
use_dummy_weights=args.use_dummy_weights,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
block_size=args.block_size,

View File

@ -47,6 +47,7 @@ class FastAPIFrontend:
self.server = remote_server_class.remote(
model=model,
model_path=model_path,
use_dummy_weights=False,
pipeline_parallel_size=pipeline_parallel_size,
tensor_parallel_size=tensor_parallel_size,
block_size=block_size,

View File

@ -16,6 +16,7 @@ class Server:
self,
model: str,
model_path: str,
use_dummy_weights: bool,
pipeline_parallel_size: int,
tensor_parallel_size: int,
block_size: int,
@ -66,6 +67,7 @@ class Server:
dtype=dtype,
seed=seed,
model_path=model_path,
use_dummy_weights=use_dummy_weights,
max_num_batched_tokens=max_num_batched_tokens,
)
self.controllers.append(controller)
@ -179,4 +181,5 @@ def add_server_arguments(parser: argparse.ArgumentParser):
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens')
parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
return parser

View File

@ -286,3 +286,7 @@ class LlamaForCausalLM(nn.Module):
np.save(f, param.cpu().detach().numpy())
return path
def initialize_dummy_weights(self) -> None:
for param in self.state_dict().values():
param.data.uniform_(-0.1, 0.1)

View File

@ -28,18 +28,29 @@ def get_model(
model_name: str,
dtype: Union[torch.dtype, str],
path: str,
use_dummy_weights: bool,
) -> nn.Module:
torch_dtype = get_torch_dtype(dtype)
torch.set_default_dtype(torch_dtype)
config = AutoConfig.from_pretrained(model_name)
for model_class_name, model_class in _MODELS.items():
if model_class_name in model_name:
# Download model weights if it's not cached.
weights_dir = model_class.get_weights(model_name, path=path)
# Create a model instance.
model = model_class(config)
# Load the weights from the cached or downloaded files.
model.load_weights(weights_dir)
if use_dummy_weights:
# Create a model instance.
# The weights will be initialized as empty tensors.
model = model_class(config)
model = model.cuda()
# NOTE(woosuk): For precise performance evaluation, we assign
# random values to the weights.
model.initialize_dummy_weights()
else:
# Download model weights if it's not cached.
weights_dir = model_class.get_weights(model_name, path=path)
# Create a model instance.
model = model_class(config)
# Load the weights from the cached or downloaded files.
model.load_weights(weights_dir)
model = model.cuda()
return model.eval(), torch_dtype
raise ValueError(f'Unsupported model name: {model_name}')

View File

@ -324,3 +324,7 @@ class OPTForCausalLM(nn.Module):
np.save(f, param.cpu().detach().numpy())
return path
def initialize_dummy_weights(self) -> None:
for param in self.state_dict().values():
param.data.uniform_(-0.1, 0.1)

View File

@ -27,6 +27,7 @@ class Controller:
dtype: str,
seed: int,
model_path: str,
use_dummy_weights: bool,
max_num_batched_tokens: int,
) -> None:
self.stage_id = stage_id
@ -58,6 +59,7 @@ class Controller:
tensor_parallel_size=tensor_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
model_path=model_path,
use_dummy_weights=use_dummy_weights,
max_num_batched_tokens=max_num_batched_tokens,
)
self.workers.append(worker)

View File

@ -29,6 +29,7 @@ class Worker:
rank: int,
world_size: int,
model_path: str,
use_dummy_weights: bool,
max_num_batched_tokens: int,
tensor_parallel_size: int = 1,
pipeline_parallel_size: int = 1,
@ -43,8 +44,8 @@ class Worker:
set_random_seed(seed)
# Initialize the model.
self.model, self.dtype = get_model(model_name, dtype=dtype, path=model_path)
self.model = self.model.cuda()
self.model, self.dtype = get_model(
model_name, dtype=dtype, path=model_path, use_dummy_weights=use_dummy_weights)
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
initialize_all_reduce_launcher(

View File

@ -22,6 +22,7 @@ def main(args: argparse.Namespace):
server = Server(
model=args.model,
model_path=args.model_path,
use_dummy_weights=args.use_dummy_weights,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
block_size=args.block_size,