[v1] reduce graph capture time for piecewise cudagraph (#10059)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2024-11-05 18:19:50 -08:00 committed by GitHub
parent 0c63c34f72
commit c4cacbaa7f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1,7 +1,9 @@
 import copy
 import dataclasses
 import operator
+from contextlib import ExitStack
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+from unittest.mock import patch

 import torch
 import torch.fx as fx
@@ -503,17 +505,29 @@ class PiecewiseBackend:
             entry.input_addresses = input_addresses
             cudagraph = torch.cuda.CUDAGraph()

-            # mind-exploding: carefully manage the reference and memory.
-            with torch.cuda.graph(cudagraph, pool=self.graph_pool):
-                # `output` is managed by pytorch's cudagraph pool
-                output = entry.runnable(*args)
-                if self.is_last_graph:
-                    # by converting it to weak ref,
-                    # the original `output` will immediately be released
-                    # to save memory. It is only safe to do this for
-                    # the last graph, because the output of the last graph
-                    # will not be used by any other cuda graph.
-                    output = weak_ref_tensors(output)
+            with ExitStack() as stack:
+                if not self.is_first_graph:
+                    # during every model forward, we will capture
+                    # many pieces of cudagraphs (roughly one per layer).
+                    # running gc again and again across layers will
+                    # make the cudagraph capture very slow.
+                    # therefore, we only run gc for the first graph,
+                    # and disable gc for the rest of the graphs.
+                    stack.enter_context(patch("gc.collect", lambda: None))
+                    stack.enter_context(
+                        patch("torch.cuda.empty_cache", lambda: None))
+
+                # mind-exploding: carefully manage the reference and memory.
+                with torch.cuda.graph(cudagraph, pool=self.graph_pool):
+                    # `output` is managed by pytorch's cudagraph pool
+                    output = entry.runnable(*args)
+                    if self.is_last_graph:
+                        # by converting it to weak ref,
+                        # the original `output` will immediately be released
+                        # to save memory. It is only safe to do this for
+                        # the last graph, because the output of the last graph
+                        # will not be used by any other cuda graph.
+                        output = weak_ref_tensors(output)

             # here we always use weak ref for the output
             # to save memory