Add an option to use dummy model weights (#33)
commit ee88a7e5f3
parent c267b1a02c
@@ -29,6 +29,7 @@ def main(args: argparse.Namespace):
     server = Server(
         model=args.model,
         model_path=args.model_path,
+        use_dummy_weights=args.use_dummy_weights,
         pipeline_parallel_size=args.pipeline_parallel_size,
         tensor_parallel_size=args.tensor_parallel_size,
         block_size=args.block_size,
@@ -47,6 +47,7 @@ class FastAPIFrontend:
         self.server = remote_server_class.remote(
             model=model,
             model_path=model_path,
+            use_dummy_weights=False,
             pipeline_parallel_size=pipeline_parallel_size,
             tensor_parallel_size=tensor_parallel_size,
             block_size=block_size,
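Note that the FastAPI frontend pins use_dummy_weights=False rather than exposing the new flag; presumably a user-facing endpoint should never serve randomly initialized weights, while the benchmark entry points (the two main() hunks in this diff) do plumb the CLI flag through.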
@@ -16,6 +16,7 @@ class Server:
         self,
         model: str,
         model_path: str,
+        use_dummy_weights: bool,
         pipeline_parallel_size: int,
         tensor_parallel_size: int,
         block_size: int,
@@ -66,6 +67,7 @@ class Server:
                 dtype=dtype,
                 seed=seed,
                 model_path=model_path,
+                use_dummy_weights=use_dummy_weights,
                 max_num_batched_tokens=max_num_batched_tokens,
             )
             self.controllers.append(controller)
@@ -179,4 +181,5 @@ def add_server_arguments(parser: argparse.ArgumentParser):
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
    parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens')
+    parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
     return parser
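A quick standalone sketch of the argparse behavior the new flag relies on (not part of this diff): action='store_true' gives the option a False default and flips it to True only when the flag appears on the command line, and the hyphenated name surfaces as an underscored attribute.

import argparse

parser = argparse.ArgumentParser()
# Same definition as the hunk above: defaults to False, takes no value.
parser.add_argument('--use-dummy-weights', action='store_true',
                    help='use dummy values for model weights')

args = parser.parse_args(['--use-dummy-weights'])
assert args.use_dummy_weights is True  # '--use-dummy-weights' -> args.use_dummy_weights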
@@ -286,3 +286,7 @@ class LlamaForCausalLM(nn.Module):
                     np.save(f, param.cpu().detach().numpy())
 
         return path
+
+    def initialize_dummy_weights(self) -> None:
+        for param in self.state_dict().values():
+            param.data.uniform_(-0.1, 0.1)
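The new initialize_dummy_weights method overwrites every parameter in place with values drawn uniformly from [-0.1, 0.1]; per the NOTE in the get_model hunk below, random values (rather than whatever an empty tensor happens to contain) are assigned so that performance evaluation stays precise. A standalone sketch of the same idea, using a plain nn.Linear as a stand-in for the real model classes:

import torch.nn as nn

model = nn.Linear(4, 4)
# state_dict() returns the live tensors, and uniform_ mutates them in
# place, so the module's own weights change without any reload.
for param in model.state_dict().values():
    param.data.uniform_(-0.1, 0.1)
assert all(p.abs().max() <= 0.1 for p in model.parameters())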
@@ -28,18 +28,29 @@ def get_model(
     model_name: str,
     dtype: Union[torch.dtype, str],
     path: str,
+    use_dummy_weights: bool,
 ) -> nn.Module:
     torch_dtype = get_torch_dtype(dtype)
     torch.set_default_dtype(torch_dtype)
     config = AutoConfig.from_pretrained(model_name)
     for model_class_name, model_class in _MODELS.items():
         if model_class_name in model_name:
-            # Download model weights if it's not cached.
-            weights_dir = model_class.get_weights(model_name, path=path)
-            # Create a model instance.
-            model = model_class(config)
-            # Load the weights from the cached or downloaded files.
-            model.load_weights(weights_dir)
+            if use_dummy_weights:
+                # Create a model instance.
+                # The weights will be initialized as empty tensors.
+                model = model_class(config)
+                model = model.cuda()
+                # NOTE(woosuk): For precise performance evaluation, we assign
+                # random values to the weights.
+                model.initialize_dummy_weights()
+            else:
+                # Download model weights if it's not cached.
+                weights_dir = model_class.get_weights(model_name, path=path)
+                # Create a model instance.
+                model = model_class(config)
+                # Load the weights from the cached or downloaded files.
+                model.load_weights(weights_dir)
+                model = model.cuda()
             return model.eval(), torch_dtype
     raise ValueError(f'Unsupported model name: {model_name}')
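The core of the change is this branch: with use_dummy_weights=True, get_model never calls get_weights or load_weights, so nothing is downloaded or read from disk; the model is built from the config alone, moved to the GPU, and then randomized. A condensed, self-contained sketch of the control flow (DummyModel is an illustrative stand-in for the registered model classes, and the real checkpoint path is elided):

import torch.nn as nn

class DummyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def initialize_dummy_weights(self) -> None:
        for param in self.state_dict().values():
            param.data.uniform_(-0.1, 0.1)

def get_model(use_dummy_weights: bool) -> nn.Module:
    model = DummyModel()
    if use_dummy_weights:
        # Skip the download/load path entirely; random weights suffice
        # for performance evaluation.
        model.initialize_dummy_weights()
    else:
        # Real path: fetch the checkpoint and call model.load_weights(...).
        raise NotImplementedError('checkpoint loading elided in this sketch')
    return model.eval()

model = get_model(use_dummy_weights=True)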
@@ -324,3 +324,7 @@ class OPTForCausalLM(nn.Module):
                     np.save(f, param.cpu().detach().numpy())
 
         return path
+
+    def initialize_dummy_weights(self) -> None:
+        for param in self.state_dict().values():
+            param.data.uniform_(-0.1, 0.1)
@@ -27,6 +27,7 @@ class Controller:
         dtype: str,
         seed: int,
         model_path: str,
+        use_dummy_weights: bool,
         max_num_batched_tokens: int,
     ) -> None:
         self.stage_id = stage_id
@@ -58,6 +59,7 @@ class Controller:
                 tensor_parallel_size=tensor_parallel_size,
                 pipeline_parallel_size=pipeline_parallel_size,
                 model_path=model_path,
+                use_dummy_weights=use_dummy_weights,
                 max_num_batched_tokens=max_num_batched_tokens,
             )
             self.workers.append(worker)
@@ -29,6 +29,7 @@ class Worker:
         rank: int,
         world_size: int,
         model_path: str,
+        use_dummy_weights: bool,
         max_num_batched_tokens: int,
         tensor_parallel_size: int = 1,
         pipeline_parallel_size: int = 1,
@@ -43,8 +44,8 @@ class Worker:
         set_random_seed(seed)
 
         # Initialize the model.
-        self.model, self.dtype = get_model(model_name, dtype=dtype, path=model_path)
-        self.model = self.model.cuda()
+        self.model, self.dtype = get_model(
+            model_name, dtype=dtype, path=model_path, use_dummy_weights=use_dummy_weights)
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
         initialize_all_reduce_launcher(
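Note that this hunk is behavioral as well as cosmetic: the explicit self.model = self.model.cuda() disappears because, as the get_model hunk above shows, the model now comes back already on the GPU in both the dummy and the real path, so dummy initialization runs directly on device memory.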
@@ -22,6 +22,7 @@ def main(args: argparse.Namespace):
     server = Server(
         model=args.model,
         model_path=args.model_path,
+        use_dummy_weights=args.use_dummy_weights,
         pipeline_parallel_size=args.pipeline_parallel_size,
         tensor_parallel_size=args.tensor_parallel_size,
         block_size=args.block_size,
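End to end, the flag now flows CLI -> main() -> Server -> Controller -> Worker -> get_model, so a benchmark can run against a server that never downloads or loads a real checkpoint, e.g. python simple_server.py --use-dummy-weights (script name illustrative; this diff view does not show file paths).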