From 1ce13335732c8b28d4a76118821b391f6b219b7c Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 23 Feb 2023 21:31:39 +0000 Subject: [PATCH] Set default dtype to half --- cacheflow/models/model_utils.py | 21 +++++++++++++++++++-- cacheflow/worker/controller.py | 2 ++ cacheflow/worker/worker.py | 3 ++- 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/cacheflow/models/model_utils.py b/cacheflow/models/model_utils.py index 522630d7..80b7c010 100644 --- a/cacheflow/models/model_utils.py +++ b/cacheflow/models/model_utils.py @@ -1,3 +1,6 @@ +from typing import Union + +import torch import torch.nn as nn from cacheflow.models.opt import OPTForCausalLM @@ -6,9 +9,23 @@ MODEL_CLASSES = { 'opt': OPTForCausalLM, } +STR_DTYPE_TO_TORCH_DTYPE = { + 'half': torch.half, + 'float': torch.float, + 'float16': torch.float16, + 'float32': torch.float32, +} -def get_model(model_name: str) -> nn.Module: + +def get_model( + model_name: str, + dtype: Union[torch.dtype, str], +) -> nn.Module: + if isinstance(dtype, str): + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()] + else: + torch_dtype = dtype for model_class, model in MODEL_CLASSES.items(): if model_class in model_name: - return model.from_pretrained(model_name) + return model.from_pretrained(model_name, torch_dtype=torch_dtype) raise ValueError(f'Invalid model name: {model_name}') diff --git a/cacheflow/worker/controller.py b/cacheflow/worker/controller.py index d6d3b783..f75804f1 100644 --- a/cacheflow/worker/controller.py +++ b/cacheflow/worker/controller.py @@ -14,6 +14,7 @@ class Controller: block_size: int, num_gpu_blocks: int, num_cpu_blocks: int, + dtype: str = 'half', ) -> None: self.node_id = node_id self.num_workers = num_workers @@ -35,6 +36,7 @@ class Controller: block_size=block_size, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, + dtype=dtype, ) self.workers.append(worker) diff --git a/cacheflow/worker/worker.py b/cacheflow/worker/worker.py index 3a0600c2..9b5f4661 100644 --- a/cacheflow/worker/worker.py +++ b/cacheflow/worker/worker.py @@ -17,6 +17,7 @@ class Worker: block_size: int, num_gpu_blocks: int, num_cpu_blocks: int, + dtype: str, ) -> None: self.worker_id = worker_id self.gpu_id = gpu_id @@ -26,7 +27,7 @@ class Worker: # Initialize the model. # FIXME(woosuk): This is a hack. - self.model = get_model(model_name).to(device=gpu_id) + self.model = get_model(model_name, dtype=dtype).to(device=self.device) self.num_layers = self.model.config.num_hidden_layers self.num_heads = self.model.config.num_attention_heads self.head_size = self.model.config.hidden_size // self.num_heads