[v1] reduce graph capture time for piecewise cudagraph (#10059)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
0c63c34f72
commit
c4cacbaa7f
@ -1,7 +1,9 @@
|
|||||||
import copy
|
import copy
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import operator
|
import operator
|
||||||
|
from contextlib import ExitStack
|
||||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.fx as fx
|
import torch.fx as fx
|
||||||
@ -503,17 +505,29 @@ class PiecewiseBackend:
|
|||||||
entry.input_addresses = input_addresses
|
entry.input_addresses = input_addresses
|
||||||
cudagraph = torch.cuda.CUDAGraph()
|
cudagraph = torch.cuda.CUDAGraph()
|
||||||
|
|
||||||
# mind-exploding: carefully manage the reference and memory.
|
with ExitStack() as stack:
|
||||||
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
|
if not self.is_first_graph:
|
||||||
# `output` is managed by pytorch's cudagraph pool
|
# during every model forward, we will capture
|
||||||
output = entry.runnable(*args)
|
# many pieces of cudagraphs (roughly one per layer).
|
||||||
if self.is_last_graph:
|
# running gc again and again across layers will
|
||||||
# by converting it to weak ref,
|
# make the cudagraph capture very slow.
|
||||||
# the original `output` will immediately be released
|
# therefore, we only run gc for the first graph,
|
||||||
# to save memory. It is only safe to do this for
|
# and disable gc for the rest of the graphs.
|
||||||
# the last graph, because the output of the last graph
|
stack.enter_context(patch("gc.collect", lambda: None))
|
||||||
# will not be used by any other cuda graph.
|
stack.enter_context(
|
||||||
output = weak_ref_tensors(output)
|
patch("torch.cuda.empty_cache", lambda: None))
|
||||||
|
|
||||||
|
# mind-exploding: carefully manage the reference and memory.
|
||||||
|
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
|
||||||
|
# `output` is managed by pytorch's cudagraph pool
|
||||||
|
output = entry.runnable(*args)
|
||||||
|
if self.is_last_graph:
|
||||||
|
# by converting it to weak ref,
|
||||||
|
# the original `output` will immediately be released
|
||||||
|
# to save memory. It is only safe to do this for
|
||||||
|
# the last graph, because the output of the last graph
|
||||||
|
# will not be used by any other cuda graph.
|
||||||
|
output = weak_ref_tensors(output)
|
||||||
|
|
||||||
# here we always use weak ref for the output
|
# here we always use weak ref for the output
|
||||||
# to save memory
|
# to save memory
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user