[v1] reduce graph capture time for piecewise cudagraph (#10059)

Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao 2024-11-05 18:19:50 -08:00 committed by GitHub
parent 0c63c34f72
commit c4cacbaa7f

@@ -1,7 +1,9 @@
import copy
import dataclasses
import operator
from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from unittest.mock import patch
import torch
import torch.fx as fx
@@ -503,6 +505,18 @@ class PiecewiseBackend:
entry.input_addresses = input_addresses
cudagraph = torch.cuda.CUDAGraph()

with ExitStack() as stack:
    if not self.is_first_graph:
        # during every model forward, we will capture
        # many pieces of cudagraphs (roughly one per layer).
        # running gc again and again across layers will
        # make the cudagraph capture very slow.
        # therefore, we only run gc for the first graph,
        # and disable gc for the rest of the graphs.
        stack.enter_context(patch("gc.collect", lambda: None))
        stack.enter_context(
            patch("torch.cuda.empty_cache", lambda: None))

    # mind-exploding: carefully manage the reference and memory.
    with torch.cuda.graph(cudagraph, pool=self.graph_pool):
        # `output` is managed by pytorch's cudagraph pool
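For context, the two calls being patched out are the gc.collect() and torch.cuda.empty_cache() that torch.cuda.graph() performs before every capture; skipping them for all but the first piece avoids paying that cleanup cost once per layer. Below is a minimal standalone sketch of the same pattern outside vLLM; the toy layers, shapes, warm-up loop, and shared pool setup are illustrative assumptions, not code from this commit.

```python
from contextlib import ExitStack
from unittest.mock import patch

import torch

# Build a few toy "pieces" to capture, roughly one graph per layer.
layers = [torch.nn.Linear(256, 256, device="cuda") for _ in range(8)]
x = torch.randn(32, 256, device="cuda")

# Warm up on a side stream so lazy cuBLAS/cuDNN initialization happens
# outside of graph capture (the usual precaution from the PyTorch docs).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    tmp = x
    for layer in layers:
        tmp = layer(tmp)
torch.cuda.current_stream().wait_stream(s)

pool = torch.cuda.graph_pool_handle()  # one memory pool shared by all pieces
graphs, outputs = [], []

for i, layer in enumerate(layers):
    g = torch.cuda.CUDAGraph()
    with ExitStack() as stack:
        if i > 0:
            # torch.cuda.graph() runs gc.collect() and
            # torch.cuda.empty_cache() before every capture; turn them
            # into no-ops for all but the first piece to keep capture fast.
            stack.enter_context(patch("gc.collect", lambda: None))
            stack.enter_context(
                patch("torch.cuda.empty_cache", lambda: None))
        with torch.cuda.graph(g, pool=pool):
            y = layer(x)  # output tensor lives in the shared graph pool
    graphs.append(g)
    outputs.append(y)  # keep outputs alive so the pool memory is not reused
    x = y  # chain the pieces, as in per-layer piecewise capture

# Replaying the pieces in order re-runs the whole chain.
for g in graphs:
    g.replay()
```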