[torch.compile] remove reset (#7975)
parent 4289cad37f
commit a7f65c2be9

@@ -5,6 +5,10 @@ import tempfile
 
 import depyf
 
+# disable custom dispatcher, let Dynamo takes over
+# all the control
+os.environ['VLLM_DYNAMO_USE_CUSTOM_DISPATCHER'] = "0"
+
 temp_dir = tempfile.mkdtemp()
 with depyf.prepare_debug(temp_dir):
     cur_dir = os.path.dirname(__file__)
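
For reference, the test relies on depyf to dump what Dynamo produces. The snippet below is a minimal sketch, independent of vLLM, of how depyf.prepare_debug writes the Dynamo-transformed source that the test later globs for; toy_model and the dump directory are illustrative assumptions, not part of this commit.

import glob
import os
import tempfile

import depyf
import torch


def toy_model(x: torch.Tensor) -> torch.Tensor:
    # a trivial function so torch.compile has something to trace
    return (x * 2 + 1).relu()


compiled = torch.compile(toy_model)

dump_dir = tempfile.mkdtemp()
# while the context is active, depyf writes the Dynamo-transformed source
# (e.g. __transformed_code*.py) into dump_dir
with depyf.prepare_debug(dump_dir):
    compiled(torch.randn(4))

print(sorted(glob.glob(os.path.join(dump_dir, "__transformed_code*.py"))))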

@@ -16,19 +20,36 @@ with depyf.prepare_debug(temp_dir):
 compiled_code = sorted(
     glob.glob(os.path.join(temp_dir, "__transformed_code*.py")))
-full_code = glob.glob(os.path.join(temp_dir, "full_code*.py"))[0]
 # we should only trigger Dynamo compilation three times:
-# one for the profiling phase (and the compiled artifact will be discarded)
+# one for the profiling phase without kv cache
 # one for the prefill phase with symbolic shapes
 # one for the decode phase with symbolic shapes
 # and later calls should not trigger Dynamo compilation again.
 # NOTE: it might still trigger XLA compilation.
 
 # check we have three compiled code
+# this is the assumption when we use the custom dispatcher
 assert len(compiled_code) == 3
 
-# check the first compilation is discarded
-with open(full_code) as f:
-    full_code_content = f.read()
-    profile_function = compiled_code[0].split(".")[0]
-    assert profile_function not in full_code_content
+# check all the compilations are as expected
+compiled_fn = sorted(
+    glob.glob(os.path.join(temp_dir, "__compiled_fn*Captured*.py")))
+
+# the first compilation is the profiling phase,
+# it should not have any kv cache
+with open(compiled_fn[0]) as f:
+    content = f.read()
+    assert "kv_caches" not in content
+
+# the second compilation is the prefill phase,
+# it should have kv cache and the flash_attention op
+with open(compiled_fn[1]) as f:
+    content = f.read()
+    assert "kv_caches" in content and "torch.ops.xla.flash_attention" in content
+
+# the third compilation is the decode phase,
+# it should have kv cache and the paged_attention op
+with open(compiled_fn[2]) as f:
+    content = f.read()
+    assert "kv_caches" in content and "torch.ops.xla.paged_attention" in content
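
The comments above pin the expectation to exactly three Dynamo compilations (profiling, prefill with symbolic shapes, decode with symbolic shapes). As a standalone illustration of how such a count can be observed, the sketch below uses a counting backend; counting_backend and f are hypothetical names for this example only, not part of vLLM or this test.

from typing import Callable, List

import torch

compile_count = 0


def counting_backend(gm: torch.fx.GraphModule,
                     example_inputs: List[torch.Tensor]) -> Callable:
    # count how many graphs Dynamo hands to the backend, then run them eagerly
    global compile_count
    compile_count += 1
    return gm.forward


@torch.compile(backend=counting_backend, dynamic=True)
def f(x: torch.Tensor) -> torch.Tensor:
    return x.sin() + x.cos()


f(torch.randn(4))
f(torch.randn(8))     # new shape; the symbolic-shape graph is reused
print(compile_count)  # typically 1: the first compilation already uses symbolic shapes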

@@ -1123,10 +1123,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 device=self.device)
         self.execute_model(model_input, kv_caches, intermediate_tensors)
         torch.cuda.synchronize()
-
-        # reset and discard the guard and compiled bytecode for profiling runs
-        torch._dynamo.reset()
-
         return
 
     def remove_all_loras(self):
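
For context on the removed lines: torch._dynamo.reset() globally discards Dynamo's cached guards and compiled bytecode, so anything compiled during the profiling run would have to be recompiled on the first real forward pass. A minimal sketch of that effect, assuming only a toy function g:

import torch
import torch._dynamo as dynamo


@torch.compile
def g(x: torch.Tensor) -> torch.Tensor:
    return x @ x.t()


x = torch.randn(8, 8)
g(x)            # first call: Dynamo compiles g and caches guards + bytecode
g(x)            # cache hit: no recompilation
dynamo.reset()  # what the removed lines did: drop all guards and compiled bytecode
g(x)            # compiles g again from scratch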

@@ -143,10 +143,6 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                              block_size_bytes)
         num_cpu_blocks = (num_cpu_blocks // 8) * 8  # Round down to 8.
-
-        # reset and discard the guard and compiled bytecode for profiling runs
-        torch._dynamo.reset()
-
         return num_tpu_blocks, num_cpu_blocks
 
     def initialize_cache(
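
The surrounding code sizes the CPU swap space in KV-cache blocks and rounds the count down to a multiple of 8. A small worked example with made-up numbers (not vLLM defaults):

# illustrative sizes only, not vLLM defaults
swap_space_bytes = 4 * 1024**3      # 4 GiB of host swap space
block_size_bytes = 3 * 1024**2      # assume ~3 MiB per KV-cache block

num_cpu_blocks = int(swap_space_bytes // block_size_bytes)  # 1365
num_cpu_blocks = (num_cpu_blocks // 8) * 8                  # round down to a multiple of 8
print(num_cpu_blocks)                                       # 1360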