[v1] reduce graph capture time for piecewise cudagraph (#10059)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
0c63c34f72
commit
c4cacbaa7f
@ -1,7 +1,9 @@
|
||||
import copy
|
||||
import dataclasses
|
||||
import operator
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.fx as fx
|
||||
@ -503,17 +505,29 @@ class PiecewiseBackend:
|
||||
entry.input_addresses = input_addresses
|
||||
cudagraph = torch.cuda.CUDAGraph()
|
||||
|
||||
# mind-exploding: carefully manage the reference and memory.
|
||||
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
|
||||
# `output` is managed by pytorch's cudagraph pool
|
||||
output = entry.runnable(*args)
|
||||
if self.is_last_graph:
|
||||
# by converting it to weak ref,
|
||||
# the original `output` will immediately be released
|
||||
# to save memory. It is only safe to do this for
|
||||
# the last graph, because the output of the last graph
|
||||
# will not be used by any other cuda graph.
|
||||
output = weak_ref_tensors(output)
|
||||
with ExitStack() as stack:
|
||||
if not self.is_first_graph:
|
||||
# during every model forward, we will capture
|
||||
# many pieces of cudagraphs (roughly one per layer).
|
||||
# running gc again and again across layers will
|
||||
# make the cudagraph capture very slow.
|
||||
# therefore, we only run gc for the first graph,
|
||||
# and disable gc for the rest of the graphs.
|
||||
stack.enter_context(patch("gc.collect", lambda: None))
|
||||
stack.enter_context(
|
||||
patch("torch.cuda.empty_cache", lambda: None))
|
||||
|
||||
# mind-exploding: carefully manage the reference and memory.
|
||||
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
|
||||
# `output` is managed by pytorch's cudagraph pool
|
||||
output = entry.runnable(*args)
|
||||
if self.is_last_graph:
|
||||
# by converting it to weak ref,
|
||||
# the original `output` will immediately be released
|
||||
# to save memory. It is only safe to do this for
|
||||
# the last graph, because the output of the last graph
|
||||
# will not be used by any other cuda graph.
|
||||
output = weak_ref_tensors(output)
|
||||
|
||||
# here we always use weak ref for the output
|
||||
# to save memory
|
||||
|
||||
Loading…
Reference in New Issue
Block a user