status);
}
+// Guards cuDNN benchmarking calls against running during CUDA graph capture.
+//
+// Queries the capture status of the current stream and raises (via
+// TORCH_CHECK) unless that status is CaptureStatus::None.
+//
+// version_specific: name of the benchmarking entry point being guarded
+//                   (e.g. "cudnnFind"); it is spliced into the error message.
+inline void errorIfCapturingCudnnBenchmark(const std::string& version_specific) {
+  auto status = currentStreamCaptureStatus();
+  TORCH_CHECK(status == CaptureStatus::None,
+              "Current cudaStreamCaptureStatus: ",
+              status,
+              "\nCapturing ",
+              version_specific,
+              // Leading space keeps the name and the verb apart
+              // ("cudnnFind is prohibited", not "cudnnFindis prohibited").
+              " is prohibited. Possible causes of this error:\n"
+              "1. No warmup iterations occurred before capture.\n"
+              "2. The convolutions you're trying to capture use dynamic shapes, "
+              "in which case capturing them is generally prohibited.");
+}
+
} // namespace cuda
} // namespace at
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Config.h>
+#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <ATen/cuda/Exceptions.h>
#include <ATen/native/cudnn/ConvShared.h>
} else {
size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
Workspace ws(max_ws_size);
+ at::cuda::errorIfCapturingCudnnBenchmark("cudnnFind");
AT_CUDNN_CHECK_WITH_SHAPES(cudnnFindConvolutionForwardAlgorithmEx(
args.handle,
args.idesc.desc(), args.input.data_ptr(),
} else {
size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
Workspace ws(max_ws_size);
+ at::cuda::errorIfCapturingCudnnBenchmark("cudnnFind");
AT_CUDNN_CHECK_WITH_SHAPES(cudnnFindConvolutionBackwardDataAlgorithmEx(
args.handle,
args.wdesc.desc(), args.weight.data_ptr(),
} else {
size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
Workspace ws(max_ws_size);
+ at::cuda::errorIfCapturingCudnnBenchmark("cudnnFind");
AT_CUDNN_CHECK_WITH_SHAPES(cudnnFindConvolutionBackwardFilterAlgorithmEx(
args.handle,
args.idesc.desc(), args.input.data_ptr(),
} else {
// It's ok to capture cudaMallocs, as long as we never cudaFree those
// addresses before replay.
+ // Capturing cudaMalloc behaves nicely: it gives the graph new VA,
+ // but is ignored (won't leakily allocate new memory) in replays.
at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeRelaxed};
return cudaMalloc(p, size);
}
Stream
Event
+Graphs (prototype)
+------------------
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
+
+ graph_pool_handle
+ CUDAGraph
+ graph
+ make_graphed_callables
+
Memory management
-----------------
.. autosummary::
BC note: Using grads on the default stream
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-In prior versions of Pytorch (1.9 and earlier), the autograd engine always synced
+In prior versions of PyTorch (1.9 and earlier), the autograd engine always synced
the default stream with all backward ops, so the following pattern::
with torch.cuda.stream(s):
use grads
was safe as long as ``use grads`` happened on the default stream.
-In present Pytorch, that pattern is no longer safe. If ``backward()``
+In present PyTorch, that pattern is no longer safe. If ``backward()``
and ``use grads`` are in different stream contexts, you must sync the streams::
with torch.cuda.stream(s):
If you use :class:`~torch.nn.parallel.DistributedDataParallel`, you could use
`torch.distributed.launch` utility to launch your program, see :ref:`distributed-launch`.
+
+.. _cuda-graph-semantics:
+
+CUDA Graphs
+-----------
+
+A CUDA graph is a record of the work (mostly kernels and their arguments) that a
+CUDA stream and its dependent streams perform.
+For general principles and details on the underlying CUDA API, see
+`Getting Started with CUDA Graphs`_ and the
+`Graphs section`_ of the CUDA C Programming Guide.
+
+PyTorch supports the construction of CUDA graphs using `stream capture`_, which puts a
+CUDA stream in *capture mode*. CUDA work issued to a capturing stream doesn't actually
+run on the GPU. Instead, the work is recorded in a graph.
+
+After capture, the graph can be *launched* to run the GPU work as many times as needed.
+Each replay runs the same kernels with the same arguments. For pointer arguments this
+means the same memory addresses are used.
+By filling input memory with new data (e.g., from a new batch) before each replay,
+you can rerun the same work on new data.
+
+Why CUDA Graphs?
+^^^^^^^^^^^^^^^^
+
+Replaying a graph sacrifices the dynamic flexibility of typical eager execution in exchange for
+**greatly reduced CPU overhead**. A graph's arguments and kernels are fixed, so a graph replay
+skips all layers of argument setup and kernel dispatch, including Python, C++, and CUDA driver
+overheads. Under the hood, a replay submits the entire graph's work to the GPU with
+a single call to `cudaGraphLaunch`_. Kernels in a replay also execute slightly faster
+on the GPU, but eliding CPU overhead is the main benefit.
+
+You should try CUDA graphs if all or part of your network is graph-safe (usually this means
+static shapes and static control flow, but see the other :ref:`constraints<capture-constraints>`)
+and you suspect its runtime is at least somewhat CPU-limited.
+
+.. _Getting Started with CUDA Graphs:
+ https://developer.nvidia.com/blog/cuda-graphs/
+.. _Graphs section:
+ https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cuda-graphs
+.. _stream capture:
+ https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture
+.. _cudaGraphLaunch:
+ https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g1accfe1da0c605a577c22d9751a09597
+
+PyTorch API
+^^^^^^^^^^^
+
+.. warning::
+ This API is a prototype and may change in future releases.
+
+PyTorch exposes graphs via a raw :class:`torch.cuda.CUDAGraph` class
+and two convenience wrappers,
+:class:`torch.cuda.graph` and
+:func:`torch.cuda.make_graphed_callables`.
+
+:class:`torch.cuda.graph` is a simple, versatile context manager that
+captures CUDA work in its context.
+Before capture, warm up the workload to be captured by running
+a few eager iterations. Warmup must occur on a side stream.
+Because the graph reads from and writes to the same memory addresses in every
+replay, you must maintain long-lived references to tensors that hold
+input and output data during capture.
+To run the graph on new input data, copy new data to the capture's input tensor(s),
+replay the graph, then read the new output from the capture's output tensor(s).
+Example::
+
+ g = torch.cuda.CUDAGraph()
+
+ # Placeholder input used for capture
+ static_input = torch.empty((5,), device="cuda")
+
+ # Warmup before capture
+ s = torch.cuda.Stream()
+ s.wait_stream(torch.cuda.current_stream())
+ with torch.cuda.stream(s):
+ for _ in range(3):
+ static_output = static_input * 2
+ torch.cuda.current_stream().wait_stream(s)
+
+ # Captures the graph
+ # To allow capture, automatically sets a side stream as the current stream in the context
+ with torch.cuda.graph(g):
+ static_output = static_input * 2
+
+ # Fills the graph's input memory with new data to compute on
+ static_input.copy_(torch.full((5,), 3, device="cuda"))
+ g.replay()
+ # static_output holds the results
+ print(static_output) # full of 3 * 2 = 6
+
+ # Fills the graph's input memory with more data to compute on
+ static_input.copy_(torch.full((5,), 4, device="cuda"))
+ g.replay()
+ print(static_output) # full of 4 * 2 = 8
+
+See
+:ref:`Whole-network capture<whole-network-capture>`,
+:ref:`Usage with torch.cuda.amp<graphs-with-amp>`, and
+:ref:`Usage with multiple streams<multistream-capture>`
+for realistic and advanced patterns.
+
+:func:`~torch.cuda.make_graphed_callables` is more sophisticated.
+:func:`~torch.cuda.make_graphed_callables` accepts Python functions and
+:class:`torch.nn.Module`\s. For each passed function or Module,
+it creates separate graphs of the forward-pass and backward-pass work. See
+:ref:`Partial-network capture<partial-network-capture>`.
+
+.. _capture-constraints:
+
+Constraints
+~~~~~~~~~~~
+
+A set of ops is *capturable* if it doesn't violate any of the following constraints.
+
+Constraints apply to all work in a
+:class:`torch.cuda.graph` context and all work in the forward and backward passes
+of any callable you pass to :func:`torch.cuda.make_graphed_callables`.
+
+Violating any of these will likely cause a runtime error:
+
+* Capture must occur on a non-default stream. (This is only a concern if you use the raw
+ :meth:`CUDAGraph.capture_begin<torch.cuda.CUDAGraph.capture_begin>` and
+ :meth:`CUDAGraph.capture_end<torch.cuda.CUDAGraph.capture_end>` calls.
+ :class:`~torch.cuda.graph` and
+ :func:`~torch.cuda.make_graphed_callables` set a side stream for you.)
+* Ops that synchronize the CPU with the GPU (e.g., ``.item()`` calls) are prohibited.
+* CUDA RNG ops are allowed, but must use default generators. For example, explicitly constructing a
+ new :class:`torch.Generator` instance and passing it as the ``generator`` argument to an RNG function
+ is prohibited.
+
+Violating any of these will likely cause silent numerical errors or undefined behavior:
+
+* Within a process, only one capture may be underway at a time.
+* No non-captured CUDA work may run in this process (on any thread) while capture is underway.
+* CPU work is not captured. If the captured ops include CPU work, that work will be elided during replay.
+* Every replay reads from and writes to the same (virtual) memory addresses.
+* Dynamic control flow (based on CPU or GPU data) is prohibited.
+* Dynamic shapes are prohibited. The graph assumes every tensor in the captured op sequence
+ has the same size and layout in every replay.
+* Using multiple streams in a capture is allowed, but there are :ref:`restrictions<multistream-capture>`.
+
+Non-constraints
+~~~~~~~~~~~~~~~
+
+* Once captured, the graph may be replayed on any stream.
+
+.. _whole-network-capture:
+
+Whole-network capture
+^^^^^^^^^^^^^^^^^^^^^^
+
+If your entire network is capturable, you can capture and replay an entire iteration::
+
+ N, D_in, H, D_out = 640, 4096, 2048, 1024
+ model = torch.nn.Sequential(torch.nn.Linear(D_in, H),
+ torch.nn.Dropout(p=0.2),
+ torch.nn.Linear(H, D_out),
+ torch.nn.Dropout(p=0.1)).cuda()
+ loss_fn = torch.nn.MSELoss()
+ optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+
+ # Placeholders used for capture
+ static_input = torch.randn(N, D_in, device='cuda')
+ static_target = torch.randn(N, D_out, device='cuda')
+
+ # warmup
+ # Uses static_input and static_target here for convenience,
+ # but in a real setting, because the warmup includes optimizer.step()
+ # you must use a few batches of real data.
+ s = torch.cuda.Stream()
+ s.wait_stream(torch.cuda.current_stream())
+ with torch.cuda.stream(s):
+ for i in range(3):
+ optimizer.zero_grad(set_to_none=True)
+ y_pred = model(static_input)
+ loss = loss_fn(y_pred, static_target)
+ loss.backward()
+ optimizer.step()
+ torch.cuda.current_stream().wait_stream(s)
+
+ # capture
+ g = torch.cuda.CUDAGraph()
+ # Sets grads to None before capture, so backward() will create
+ # .grad attributes with allocations from the graph's private pool
+ optimizer.zero_grad(set_to_none=True)
+ with torch.cuda.graph(g):
+ static_y_pred = model(static_input)
+ static_loss = loss_fn(static_y_pred, static_target)
+ static_loss.backward()
+ optimizer.step()
+
+ real_inputs = [torch.rand_like(static_input) for _ in range(10)]
+ real_targets = [torch.rand_like(static_target) for _ in range(10)]
+
+ for data, target in zip(real_inputs, real_targets):
+ # Fills the graph's input memory with new data to compute on
+ static_input.copy_(data)
+ static_target.copy_(target)
+ # replay() includes forward, backward, and step.
+ # You don't even need to call optimizer.zero_grad() between iterations
+ # because the captured backward refills static .grad tensors in place.
+ g.replay()
+ # Params have been updated. static_y_pred, static_loss, and .grad
+ # attributes hold values from computing on this iteration's data.
+
+.. _partial-network-capture:
+
+Partial-network capture
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
+If some of your network is unsafe to capture (e.g., due to dynamic control flow,
+dynamic shapes, CPU syncs, or essential CPU-side logic), you can run the unsafe
+part(s) eagerly and use :func:`torch.cuda.make_graphed_callables` to graph only
+the capture-safe part(s).
+
+By default, callables returned by :func:`~torch.cuda.make_graphed_callables`
+are autograd-aware, and can be used in the training loop as direct replacements
+for the functions or :class:`nn.Module<torch.nn.Module>`\ s you passed.
+
+:func:`~torch.cuda.make_graphed_callables` internally creates
+:class:`~torch.cuda.CUDAGraph` objects, runs warmup iterations, and maintains
+static inputs and outputs as needed. Therefore (unlike with
+:class:`torch.cuda.graph`) you don't need to handle those manually.
+
+In the following example, data-dependent dynamic control flow means the
+network isn't capturable end-to-end, but
+:func:`~torch.cuda.make_graphed_callables`
+lets us capture and run graph-safe sections as graphs regardless::
+
+ N, D_in, H, D_out = 640, 4096, 2048, 1024
+
+ module1 = torch.nn.Linear(D_in, H).cuda()
+ module2 = torch.nn.Linear(H, D_out).cuda()
+ module3 = torch.nn.Linear(H, D_out).cuda()
+
+ loss_fn = torch.nn.MSELoss()
+ optimizer = torch.optim.SGD(chain(module1.parameters(),
+ module2.parameters(),
+ module3.parameters()),
+ lr=0.1)
+
+ # Sample inputs used for capture
+ # requires_grad state of sample inputs must match
+ # requires_grad state of real inputs each callable will see.
+ x = torch.randn(N, D_in, device='cuda')
+ h = torch.randn(N, H, device='cuda', requires_grad=True)
+
+ module1 = torch.cuda.make_graphed_callables(module1, (x,))
+ module2 = torch.cuda.make_graphed_callables(module2, (h,))
+ module3 = torch.cuda.make_graphed_callables(module3, (h,))
+
+ real_inputs = [torch.rand_like(x) for _ in range(10)]
+ real_targets = [torch.randn(N, D_out, device="cuda") for _ in range(10)]
+
+ for data, target in zip(real_inputs, real_targets):
+ optimizer.zero_grad(set_to_none=True)
+
+ tmp = module1(data) # forward ops run as a graph
+
+ if tmp.sum().item() > 0:
+ tmp = module2(tmp) # forward ops run as a graph
+ else:
+ tmp = module3(tmp) # forward ops run as a graph
+
+ loss = loss_fn(tmp, target)
+ # module2's or module3's (whichever was chosen) backward ops,
+ # as well as module1's backward ops, run as graphs
+ loss.backward()
+ optimizer.step()
+
+.. _graphs-with-amp:
+
+Usage with torch.cuda.amp
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
+For typical optimizers, :meth:`GradScaler.step<torch.cuda.amp.GradScaler.step>` syncs
+the CPU with the GPU, which is prohibited during capture. To avoid errors, either use
+:ref:`partial-network capture<partial-network-capture>`, or (if forward, loss,
+and backward are capture-safe) capture forward, loss, and backward but not the
+optimizer step::
+
+ # warmup
+ # In a real setting, use a few batches of real data.
+ s = torch.cuda.Stream()
+ s.wait_stream(torch.cuda.current_stream())
+ with torch.cuda.stream(s):
+ for i in range(3):
+ optimizer.zero_grad(set_to_none=True)
+ with torch.cuda.amp.autocast():
+ y_pred = model(static_input)
+ loss = loss_fn(y_pred, static_target)
+ scaler.scale(loss).backward()
+ scaler.step(optimizer)
+ scaler.update()
+ torch.cuda.current_stream().wait_stream(s)
+
+ # capture
+ g = torch.cuda.CUDAGraph()
+ optimizer.zero_grad(set_to_none=True)
+ with torch.cuda.graph(g):
+ with torch.cuda.amp.autocast():
+ static_y_pred = model(static_input)
+ static_loss = loss_fn(static_y_pred, static_target)
+ scaler.scale(static_loss).backward()
+ # don't capture scaler.step(optimizer) or scaler.update()
+
+ real_inputs = [torch.rand_like(static_input) for _ in range(10)]
+ real_targets = [torch.rand_like(static_target) for _ in range(10)]
+
+ for data, target in zip(real_inputs, real_targets):
+ static_input.copy_(data)
+ static_target.copy_(target)
+ g.replay()
+ # Runs scaler.step and scaler.update eagerly
+ scaler.step(optimizer)
+ scaler.update()
+
+.. _multistream-capture:
+
+Usage with multiple streams
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Capture mode automatically propagates to any streams that sync with a capturing stream.
+Within capture, you may expose parallelism by issuing calls to different streams,
+but the overall stream dependency DAG must branch out from the
+initial capturing stream after capture begins and rejoin the initial stream
+before capture ends::
+
+ with torch.cuda.graph(g):
+ # at context manager entrance, torch.cuda.current_stream()
+ # is the initial capturing stream
+
+ # INCORRECT (does not branch out from or rejoin initial stream)
+ with torch.cuda.stream(s):
+ cuda_work()
+
+ # CORRECT:
+ # branches out from initial stream
+ s.wait_stream(torch.cuda.current_stream())
+ with torch.cuda.stream(s):
+ cuda_work()
+ # rejoins initial stream before capture ends
+ torch.cuda.current_stream().wait_stream(s)
+
+.. note::
+
+ To avoid confusion for power users looking at replays in Nsight Systems or nvprof:
+ Unlike eager execution, the graph interprets a nontrivial stream DAG in capture
+ as a hint, not a command. During replay, the graph may reorganize independent ops
+ onto different streams or enqueue them in a different order (while respecting your
+ original DAG's overall dependencies).
+
+Usage with DistributedDataParallel
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+NCCL < 2.9.6
+~~~~~~~~~~~~
+
+NCCL versions earlier than 2.9.6 don't allow collectives to be captured.
+You must use :ref:`partial-network capture<partial-network-capture>`,
+which defers allreduces to happen outside graphed sections of backward.
+
+Call :func:`~torch.cuda.make_graphed_callables` on graphable network sections
+*before* wrapping the network with DDP.
+
+NCCL >= 2.9.6
+~~~~~~~~~~~~~
+
+NCCL versions 2.9.6 or later allow collectives in the graph.
+Approaches that capture an :ref:`entire backward pass<whole-network-capture>`
+are a viable option, but need three setup steps.
+
+1. Disable DDP's internal async error handling::
+
+ os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
+ torch.distributed.init_process_group(...)
+
+2. Before full-backward capture, DDP must be constructed in a side-stream context::
+
+ with torch.cuda.stream(s):
+ model = DistributedDataParallel(model)
+
+3. Your warmup must run at least 11 DDP-enabled eager iterations before capture.
+
+.. _graph-memory-management:
+
+Graph memory management
+^^^^^^^^^^^^^^^^^^^^^^^
+
+A captured graph acts on the same virtual addresses every time it replays.
+If PyTorch frees the memory, a later replay can hit an illegal memory access.
+If PyTorch reassigns the memory to new tensors, the replay can corrupt the values
+seen by those tensors. Therefore, the virtual addresses used by the graph must be
+reserved for the graph across replays. The PyTorch caching allocator achieves this
+by detecting when capture is underway and satisfying the capture's allocations
+from a graph-private memory pool. The private pool stays alive until its
+:class:`~torch.cuda.CUDAGraph` object and all tensors created during capture
+go out of scope.
+
+Private pools are maintained automatically. By default, the allocator creates a
+separate private pool for each capture. If you capture multiple graphs,
+this conservative approach ensures graph replays never corrupt each other's values,
+but sometimes needlessly wastes memory.
+
+To economize the memory stashed in private pools, :class:`torch.cuda.graph`
+and :func:`torch.cuda.make_graphed_callables` optionally allow different
+captures to share the same private pool.
+It's safe for a set of graphs to share a private pool if you know they'll always
+be replayed in the same order they were captured,
+and never be replayed concurrently.
+
+Sharing memory across captures with torch.cuda.graph
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+:class:`torch.cuda.graph`'s ``pool`` argument is a hint to use a particular private pool,
+and can be used to share memory across graphs as shown::
+
+ g1 = torch.cuda.CUDAGraph()
+ g2 = torch.cuda.CUDAGraph()
+
+ # (create static inputs for g1 and g2, run warmups of their workloads...)
+
+ # Captures g1
+ with torch.cuda.graph(g1):
+ static_out_1 = g1_workload(static_in_1)
+
+ # Captures g2, hinting that g2 may share a memory pool with g1
+ with torch.cuda.graph(g2, pool=g1.pool()):
+ static_out_2 = g2_workload(static_in_2)
+
+ static_in_1.copy_(real_data_1)
+ static_in_2.copy_(real_data_2)
+ g1.replay()
+ g2.replay()
+
+Sharing memory across captures with torch.cuda.make_graphed_callables
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+With :func:`torch.cuda.make_graphed_callables`, if you want to graph several
+callables and you know they'll always run in the same order (and never concurrently),
+pass them as a tuple in the same order they'll run in the live workload, and
+:func:`~torch.cuda.make_graphed_callables` will capture their graphs using a shared
+private pool.
+
+If, in the live workload, your callables will run in an order that occasionally changes,
+or if they'll run concurrently, passing them as a tuple to a single invocation of
+:func:`~torch.cuda.make_graphed_callables` is not allowed. Instead, you must call
+:func:`~torch.cuda.make_graphed_callables` separately for each one.
with torch.cuda.stream(s):
a = torch.full((1000,), 1, device="cuda")
- g = torch.cuda._Graph()
+ g = torch.cuda.CUDAGraph()
torch.cuda.empty_cache()
g.capture_begin()
b = a
with torch.cuda.stream(stream):
torch.cuda.manual_seed(5)
- g = torch.cuda._Graph()
+ g = torch.cuda.CUDAGraph()
torch.cuda.empty_cache()
g.capture_begin()
graph_out = graph_in
with torch.cuda.stream(stream):
torch.cuda.manual_seed(5)
- g = torch.cuda._Graph()
+ g = torch.cuda.CUDAGraph()
torch.cuda.empty_cache()
if (module == "torch"):
g.capture_begin()
s = torch.cuda.Stream()
for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"):
- g0 = torch.cuda._Graph()
- g1 = torch.cuda._Graph()
+ g0 = torch.cuda.CUDAGraph()
+ g1 = torch.cuda.CUDAGraph()
a = torch.ones((size,), device="cuda")
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
- g0_args = (torch.cuda._graph_pool_handle(),) if share_mem == "via graph_pool_handle()" else ()
+ g0_args = (torch.cuda.graph_pool_handle(),) if share_mem == "via graph_pool_handle()" else ()
g0.capture_begin(*g0_args)
b = a.clone()
for _ in range(5):
s = torch.cuda.Stream()
for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"):
- g0 = torch.cuda._Graph()
- g1 = torch.cuda._Graph()
+ g0 = torch.cuda.CUDAGraph()
+ g1 = torch.cuda.CUDAGraph()
s0 = torch.cuda.Stream()
s1 = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
- g0_args = (torch.cuda._graph_pool_handle(),) if share_mem == "via graph_pool_handle()" else ()
+ g0_args = (torch.cuda.graph_pool_handle(),) if share_mem == "via graph_pool_handle()" else ()
g0.capture_begin(*g0_args)
b = a.clone()
for _ in range(5):
for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"):
a = torch.ones((size,), device="cuda")
- g0 = torch.cuda._Graph()
- g1 = torch.cuda._Graph()
- g2 = torch.cuda._Graph()
+ g0 = torch.cuda.CUDAGraph()
+ g1 = torch.cuda.CUDAGraph()
+ g2 = torch.cuda.CUDAGraph()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
- g0_args = (torch.cuda._graph_pool_handle(),) if share_mem == "via graph_pool_handle()" else ()
+ g0_args = (torch.cuda.graph_pool_handle(),) if share_mem == "via graph_pool_handle()" else ()
g0.capture_begin(*g0_args)
b = a.clone()
c = b + 1
delta_active_blocks = 1 # We only check the large pool, which isn't affected by rng offset holder
delta_active_bytes = numel * elem
- g = torch.cuda._Graph()
+ g = torch.cuda.CUDAGraph()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
# Allocation stat estimates assume input is created on the same stream as capture_begin()
s0 = torch.cuda.Stream()
s1 = torch.cuda.Stream()
s2 = torch.cuda.Stream()
- g = torch.cuda._Graph()
+ g = torch.cuda.CUDAGraph()
torch.cuda.synchronize()
with torch.cuda.stream(s0):
y = model(x)
- g = torch.cuda._Graph()
+ g = torch.cuda.CUDAGraph()
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
torch.cuda.empty_cache()
scaler = torch.cuda.amp.GradScaler(init_scale=4.)
- g = torch.cuda._Graph()
+ g = torch.cuda.CUDAGraph()
s = torch.cuda.Stream()
weight = torch.ones((100,), device="cuda", requires_grad=True)
static_input = torch.ones_like(weight)
static_grad = torch.ones_like(weight)
- s.wait_stream(torch.cuda.current_stream())
- with torch.cuda.stream(s):
- # warmup
- loss = (weight.half() * static_input).sum()
- scaler.scale(loss).backward()
- opt.zero_grad(set_to_none=True)
- # capture
- g.capture_begin()
+ # warmup
+ loss = (weight.half() * static_input).sum()
+ scaler.scale(loss).backward()
+ opt.zero_grad(set_to_none=True)
+
+ # capture
+ with torch.cuda.graph(g):
loss = (weight.half() * static_input).sum()
scaler.scale(loss).backward()
- g.capture_end()
- torch.cuda.current_stream().wait_stream(s)
input_vals = [5, 20000, 5, 40000]
# If the scale gets updated properly, these are the scale, growth tracker,
self.assertEqual(scaler._scale, scale)
self.assertEqual(scaler._growth_tracker, growth_tracker)
+ @unittest.skipIf((not TEST_CUDA) or
+ TEST_WITH_ROCM or
+ int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
+ def test_graph_make_graphed_callables(self):
+ # Trains two identically-initialized models for the same 10 batches, one built from
+ # callables returned by torch.cuda.make_graphed_callables and one left untouched,
+ # then asserts their parameters and eval-mode outputs are equal.
+ torch.manual_seed(5)
+ torch.cuda.manual_seed(5)
+
+ N, D_in, H, D_out = 640, 4096, 2048, 1024
+
+ # Builds two structurally identical two-section models.
+ models = []
+ for _ in range(2):
+ model_section1 = torch.nn.Sequential(torch.nn.Linear(D_in, H),
+ torch.nn.Dropout(p=0.1)).cuda()
+ model_section2 = torch.nn.Sequential(torch.nn.Linear(H, D_out),
+ torch.nn.Dropout(p=0.2)).cuda()
+ models.append(torch.nn.Sequential(model_section1, model_section2))
+
+ model_graphed = models[0]
+ model_control = models[1]
+
+ # Starts both models from the same weights so final parameters are comparable.
+ model_graphed.load_state_dict(model_control.state_dict())
+
+ opt_graphed = torch.optim.SGD(model_graphed.parameters(), lr=0.1)
+ opt_control = torch.optim.SGD(model_control.parameters(), lr=0.1)
+
+ # Sample arguments handed to make_graphed_callables, one tuple per callable.
+ x = torch.randn(N, D_in, device='cuda')
+ h = torch.randn(N, H, device='cuda', requires_grad=True)
+ y_pred = torch.randn(N, D_out, device='cuda', requires_grad=True)
+ y = torch.randn(N, D_out, device='cuda')
+
+ loss_fn_control = torch.nn.functional.mse_loss
+ relu_control = torch.nn.functional.relu
+
+ # This is a good stress test. It graphs four callables: two Modules and two python functions.
+ model_graphed[0], model_graphed[1], relu_graphed, loss_fn_graphed = \
+ torch.cuda.make_graphed_callables((model_graphed[0], model_graphed[1], relu_control, loss_fn_control),
+ ((x,), (h,), (y_pred,), (y_pred, y)))
+
+ real_inputs = [torch.rand_like(x) for _ in range(10)]
+ real_targets = [torch.rand_like(y) for _ in range(10)]
+
+ # Runs the same training loop twice: first with the graphed pieces, then with the controls.
+ for m, opt, relu, loss_fn in zip((model_graphed, model_control),
+ (opt_graphed, opt_control),
+ (relu_graphed, relu_control),
+ (loss_fn_graphed, loss_fn_control)):
+ # Resets RNG states before iterations for graphed and ungraphed models,
+ # so dropout math should be bitwise identical for both.
+ torch.manual_seed(5)
+ torch.cuda.manual_seed(5)
+ for data, target in zip(real_inputs, real_targets):
+ opt.zero_grad(set_to_none=True)
+ y_pred = m(data)
+ y_pred = relu(y_pred)
+ loss = loss_fn(y_pred, target)
+ loss.backward()
+ opt.step()
+
+ # After identical training, every parameter pair must match.
+ for p, pc in zip(model_graphed.parameters(), model_control.parameters()):
+ self.assertEqual(p, pc)
+
+ # We graphed the models in training mode. Eval should still run ungraphed.
+ model_graphed.eval()
+ model_control.eval()
+ self.assertEqual(model_graphed(real_inputs[0]), model_control(real_inputs[0]))
+
def test_batch_norm_gather_stats(self):
input = torch.randn(1, 3, 3, 3, device='cuda')
mean, invstd = torch.batch_norm_gather_stats(
def ipc_handle(self) -> bytes: ...
# Defined in torch/csrc/cuda/Graph.cpp
-class _CudaGraphBase:
- ...
+class _CUDAGraph:
+ # Type stub for the extension class; declares signatures only.
+ # `pool` is an optional pair of ints, the same shape that pool() and
+ # _graph_pool_handle() return.
+ # NOTE(review): presumably this pair identifies a memory pool -- confirm against the binding.
+ def capture_begin(self,
+ pool: Optional[Tuple[_int, _int]]=...) -> None: ...
+ def capture_end(self) -> None: ...
+ def replay(self) -> None: ...
+ def reset(self) -> None: ...
+ # Returns a pair of ints, matching the type accepted by capture_begin's `pool`.
+ def pool(self) -> Tuple[_int, _int]: ...
def _graph_pool_handle() -> Tuple[_int, _int]: ...
auto torch_C_m = py::handle(module).cast<py::module>();
torch_C_m
- .def("_graph_pool_handle", &::at::cuda::graph_pool_handle);
+ .def("_graph_pool_handle",
+ &::at::cuda::graph_pool_handle);
- shared_ptr_class_<::at::cuda::CUDAGraph>(torch_C_m, "_CudaGraphBase")
+ shared_ptr_class_<::at::cuda::CUDAGraph>
+ (torch_C_m,
+ "_CUDAGraph")
.def(py::init<>())
// I'm not sure this is the correct order of all the arguments. Pybind11 docs
// aren't clear. But it works.
.def("capture_begin",
&::at::cuda::CUDAGraph::capture_begin,
py::call_guard<py::gil_scoped_release>(),
- R"(``capture_begin`` begins Cuda graph capture on the current stream.)",
py::arg("pool") = c10::cuda::MempoolId_t{0, 0})
.def("capture_end",
&::at::cuda::CUDAGraph::capture_end,
- py::call_guard<py::gil_scoped_release>(),
- R"(``capture_end`` ends Cuda graph capture on the current stream.
- After ``capture_end``, ``replay`` may be called on this instance.)")
+ py::call_guard<py::gil_scoped_release>())
.def("replay",
&::at::cuda::CUDAGraph::replay,
- py::call_guard<py::gil_scoped_release>(),
- R"(``replay`` replays the Cuda graph captured by this instance.)")
- // reset is called in __del__ on the Python side
- // (see class Graph in torch/cuda/streams.py for reasons and caveats)
+ py::call_guard<py::gil_scoped_release>())
.def("reset",
&::at::cuda::CUDAGraph::reset,
- py::call_guard<py::gil_scoped_release>(),
- R"(``reset`` deletes the graph currently held by this instance.)")
+ py::call_guard<py::gil_scoped_release>())
.def("pool",
&::at::cuda::CUDAGraph::pool,
- py::call_guard<py::gil_scoped_release>(),
- R"(``pool`` retrieves the id of this graph's memory pool.
- This id can optionally be passed to another graph's capture_begin,
- which hints that other graph may share the same memory pool.)");
+ py::call_guard<py::gil_scoped_release>());
}
import threading
from typing import List, Optional, Tuple, Union, Any
from ._utils import _get_device_index, _dummy_type
-from .streams import Stream, Event, _Graph, _graph_pool_handle
+from .graphs import CUDAGraph, graph_pool_handle, graph, make_graphed_callables
+from .streams import Stream, Event
from .. import device as _device
import torch._C
--- /dev/null
+import gc
+import torch
+
+from ._utils import _dummy_type
+
+
+if not hasattr(torch._C, '_CudaStreamBase'):
+    # Fallback for a torch._C that has no `_CudaStreamBase` attribute (presumably a build
+    # without CUDA support): install placeholder types under the two names imported just
+    # below, so those imports and the `CUDAGraph` subclass definition still succeed.
+    # NOTE(review): the probe is `_CudaStreamBase`, not `_CUDAGraph` itself -- this assumes
+    # the graph bindings are always registered together with the stream bindings; confirm.
+    torch._C.__dict__['_CUDAGraph'] = _dummy_type('_CUDAGraph')
+    torch._C.__dict__['_graph_pool_handle'] = _dummy_type('_graph_pool_handle')
+
+from torch._C import _CUDAGraph # noqa: F401
+from torch._C import _graph_pool_handle
+
+
+# Python shim helps Sphinx process docstrings more reliably.
+def graph_pool_handle():
+ r"""
+ Returns an opaque token representing the id of a graph memory pool.
+ See :ref:`Graph memory management<graph-memory-management>`.
+
+ .. warning::
+ This API is a prototype and may change in future releases.
+ """
+ return _graph_pool_handle()
+
+
+# Python shim helps Sphinx process docstrings more reliably.
+class CUDAGraph(torch._C._CUDAGraph):
+ r"""
+ Wrapper around a CUDA graph.
+
+ .. warning::
+ This API is a prototype and may change in future releases.
+ """
+ def __new__(cls):
+ return super(CUDAGraph, cls).__new__(cls)
+
+ def __init__(self):
+ super(CUDAGraph, self).__init__()
+
+ def capture_begin(self, pool=None):
+ r"""
+ Begins capturing CUDA work on the current stream.
+
+ Typically, you shouldn't call ``capture_begin`` yourself.
+ Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
+ which call ``capture_begin`` internally.
+
+ Arguments:
+ pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
+ :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
+ with the indicated pool. See :ref:`Graph memory management<graph-memory-management>`.
+ """
+ # I'm not sure if pybind11 converts a None arg to the default defined on the C++ side,
+ # so I'm not taking any chances.
+ if pool is None:
+ super(CUDAGraph, self).capture_begin()
+ else:
+ super(CUDAGraph, self).capture_begin(pool)
+
+ def capture_end(self):
+ r"""
+ Ends CUDA graph capture on the current stream.
+ After ``capture_end``, ``replay`` may be called on this instance.
+
+ Typically, you shouldn't call ``capture_end`` yourself.
+ Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
+ which call ``capture_end`` internally.
+ """
+ super(CUDAGraph, self).capture_end()
+
+ def replay(self):
+ r"""
+ Replays the CUDA work captured by this graph.
+ """
+ super(CUDAGraph, self).replay()
+
+ def reset(self):
+ r"""
+ Deletes the graph currently held by this instance.
+ """
+ super(CUDAGraph, self).reset()
+
+ def pool(self):
+ r"""
+ Returns an opaque token representing the id of this graph's memory pool.
+ This id can optionally be passed to another graph's ``capture_begin``,
+ which hints the other graph may share the same memory pool.
+ """
+ return super(CUDAGraph, self).pool()
+
+
+class graph(object):
+ r"""
+ Context-manager that captures CUDA work into a :class:`torch.cuda.CUDAGraph`
+ object for later replay.
+
+ See :ref:`CUDA Graphs <cuda-graph-semantics>` for a general introduction,
+ detailed use, and constraints.
+
+ Arguments:
+ cuda_graph (torch.cuda.CUDAGraph): Graph object used for capture.
+ pool (optional): Opaque token (returned by a call to :func:`~torch.cuda.graph_pool_handle()` or
+ :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) hinting this graph's capture
+ may share memory from the specified pool. See :ref:`Graph memory management<graph-memory-management>`.
+ stream (torch.cuda.Stream, optional): If supplied, will be set as the current stream in the context.
+ If not supplied, ``graph`` sets its own internal side stream as the current stream in the context.
+
+ .. note::
+ For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture
+ used an explicit ``stream`` argument, you should pass the same ``stream`` argument to this capture.
+
+ .. warning::
+ This API is a prototype and may change in future releases.
+ """
+ default_capture_stream = None
+
+ def __init__(self,
+ cuda_graph,
+ pool=None,
+ stream=None):
+ # Lazy-init of default_capture_stream helps avoid circular-import errors.
+ # Not thread safe, but graphs already have the general (explicitly documented)
+ # restriction that only one capture may be underway at a time in the process.
+ if self.__class__.default_capture_stream is None:
+ self.__class__.default_capture_stream = torch.cuda.Stream()
+
+ self.pool = () if pool is None else (pool,)
+ self.capture_stream = stream if stream is not None else self.__class__.default_capture_stream
+ assert self.capture_stream is not None
+ self.stream_ctx = torch.cuda.stream(self.capture_stream)
+ self.cuda_graph = cuda_graph
+
+ def __enter__(self):
+ # Free as much memory as we can for the graph
+ torch.cuda.synchronize()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # Stackoverflow seems comfortable with this pattern
+ # https://stackoverflow.com/questions/26635684/calling-enter-and-exit-manually#39172487
+ self.stream_ctx.__enter__()
+
+ self.cuda_graph.capture_begin(*self.pool)
+
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.cuda_graph.capture_end()
+ self.stream_ctx.__exit__(exc_type, exc_value, traceback)
+ # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__()
+
+
+def make_graphed_callables(callables, sample_args):
+    r"""
+    Accepts callables (functions or :class:`nn.Module<torch.nn.Module>`\ s)
+    and returns graphed versions.
+
+    Each graphed callable's forward pass runs its source callable's
+    forward CUDA work as a CUDA graph inside a single autograd node.
+
+    The graphed callable's forward pass also appends
+    a backward node to the autograd graph. During backward, this node runs the
+    callable's backward work as a CUDA graph.
+
+    Therefore, each graphed callable should be a drop-in replacement for its source callable
+    in an autograd-enabled training loop.
+
+    See :ref:`Partial-network capture<partial-network-capture>` for detailed use and constraints.
+
+    If you pass a tuple of several callables, their captures will use the same memory pool.
+    See :ref:`Graph memory management<graph-memory-management>` for when this is appropriate.
+
+    Arguments:
+        callables (torch.nn.Module or Python function, or tuple of these): Callable or callables to graph.
+            See :ref:`Graph memory management<graph-memory-management>` for when passing a tuple of callables
+            is appropriate.  If you pass a tuple of callables, their order in the tuple must be the same order
+            they'll run in the live workload.
+        sample_args (tuple of Tensors, or tuple of tuples of Tensors): Sample args for each callable.
+            If a single callable was passed, ``sample_args`` must be a single tuple of argument Tensors.
+            If a tuple of callables was passed, ``sample_args`` must be a tuple of tuples of argument Tensors.
+
+    Returns:
+        A single graphed callable if ``callables`` was not a tuple, otherwise a tuple of graphed callables
+        in the same order as ``callables``. For a :class:`torch.nn.Module` input, the returned object is that
+        same module with its ``forward`` attribute replaced; for any other callable, a new function is returned.
+
+    .. note::
+        The ``requires_grad`` state of each Tensor in ``sample_args`` must match the state
+        that's expected for the corresponding real input in the training loop.
+
+    .. warning::
+        This API is a prototype and may change in future releases.
+
+    .. warning::
+        ``sample_args`` for each callable must be a tuple of Tensors. Other types and keyword args
+        are not allowed.
+
+    .. warning::
+        Returned callables do not support higher order differentiation (e.g., double backward).
+
+    .. warning::
+        In any :class:`~torch.nn.Module` passed to :func:`~make_graphed_callables`, only parameters
+        may be trainable. Buffers must have ``requires_grad=False``.
+
+    .. warning::
+        After you pass a :class:`torch.nn.Module` through :func:`~make_graphed_callables`,
+        you may not add or remove any of that Module's parameters or buffers.
+
+    .. warning::
+        :class:`torch.nn.Module`\s passed to :func:`~torch.cuda.make_graphed_callables` must not have module hooks
+        registered on them at the time they are passed. However, registering hooks on modules *after* passing them
+        through :func:`~torch.cuda.make_graphed_callables` is allowed.
+
+    .. warning::
+        When running a graphed callable, you must pass its arguments in the same order and format
+        they appeared in that callable's ``sample_args``.
+
+    .. warning::
+        All Tensor outputs of graphed callables must require grad.
+    """
+    just_one_callable = False
+
+    # Normalize the single-callable form to one-element tuples so the code below
+    # only ever deals with tuples; the flag lets us unwrap the result at the end.
+    if not isinstance(callables, tuple):
+        just_one_callable = True
+        callables = (callables,)
+        sample_args = (sample_args,)
+
+    # Up-front validation of the documented restrictions (hooks, trainable buffers, non-Tensor args).
+    for c, args in zip(callables, sample_args):
+        if isinstance(c, torch.nn.Module):
+            assert len(c._backward_hooks) == 0 and len(c._forward_hooks) == 0 and len(c._forward_pre_hooks) == 0, \
+                "Modules must not have hooks registered at the time they are passed. However, registering hooks " + \
+                "on modules after passing them through make_graphed_callables is allowed."
+            assert all(b.requires_grad is False for b in c.buffers()), "In any :class:`~torch.nn.Module` passed to " + \
+                ":func:`~make_graphed_callables`, only parameters may be trainable. All buffers must have " + \
+                "``requires_grad=False``."
+        assert all(isinstance(arg, torch.Tensor) for arg in args), "In the prototype API, sample_args " + \
+            "for each callable must be a tuple of Tensors. Other types and keyword args are not allowed."
+
+
+    # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
+    # passes to forward (ie, its sample_args) AND the module's parameter attributes.
+    per_callable_len_user_args = [len(args) for args in sample_args]
+    per_callable_module_params = [tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
+                                  for c in callables]
+    per_callable_static_input_surfaces = [sample_args[i] + per_callable_module_params[i]
+                                          for i in range(len(callables))]
+
+    # One forward graph and one backward graph per callable.
+    fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]
+    bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]
+
+    # Single pool token handed to every capture below.
+    mempool = graph_pool_handle()
+
+    # Warmup
+    # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
+    # from ending up in any captures.
+    torch.cuda.synchronize()
+    with torch.cuda.stream(torch.cuda.Stream()):
+        for func, args, static_input_surface in zip(callables,
+                                                    sample_args,
+                                                    per_callable_static_input_surfaces):
+            # Three eager forward+backward passes per callable; results are discarded.
+            for _ in range(3):
+                outputs = func(*args)
+                outputs = (outputs,) if isinstance(outputs, torch.Tensor) else outputs
+                grad_inputs = torch.autograd.grad(outputs=outputs,
+                                                  inputs=tuple(i for i in static_input_surface if i.requires_grad),
+                                                  grad_outputs=tuple(torch.empty_like(o) for o in outputs),
+                                                  only_inputs=True,
+                                                  allow_unused=False)
+            del outputs, grad_inputs
+    torch.cuda.synchronize()
+
+    # All captures here share a mempool. To avoid replays corrupting each other's memory,
+    # the safest approach is to capture all passes in the same order they'll run:
+    # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.
+
+    # Capture forward graphs
+    per_callable_static_outputs = []
+    per_callable_output_was_tensor = []
+    for func, args, fwd_graph in zip(callables,
+                                     sample_args,
+                                     fwd_graphs):
+        with torch.cuda.graph(fwd_graph, pool=mempool):
+            outputs = func(*args)
+
+        # Assumes model output is a tensor or tuple of tensors
+        # A bare Tensor is wrapped in a 1-tuple; the flag records that so the
+        # graphed callable can unwrap it again when returning to the user.
+        if isinstance(outputs, torch.Tensor):
+            per_callable_output_was_tensor.append(True)
+            outputs = (outputs,)
+        else:
+            per_callable_output_was_tensor.append(False)
+
+        per_callable_static_outputs.append(outputs)
+
+    # Capture backward graphs in reverse order
+    per_callable_static_grad_outputs = []
+    per_callable_static_grad_inputs = []
+    for static_input_surface, static_outputs, bwd_graph, module_params in \
+            zip(reversed(per_callable_static_input_surfaces),
+                reversed(per_callable_static_outputs),
+                reversed(bwd_graphs),
+                reversed(per_callable_module_params)):
+
+        # For now, assumes all static_outputs require grad
+        assert all(o.requires_grad for o in static_outputs), "Outputs of graphed callables must require grad."
+        # Placeholder incoming-gradient tensors; Graphed.backward copies the real grads into these before replay.
+        static_grad_outputs = tuple(torch.empty_like(o) for o in static_outputs)
+
+        with torch.cuda.graph(bwd_graph, pool=mempool):
+            grad_inputs = torch.autograd.grad(outputs=static_outputs,
+                                              inputs=tuple(i for i in static_input_surface if i.requires_grad),
+                                              grad_outputs=static_grad_outputs,
+                                              only_inputs=True,
+                                              allow_unused=False)
+
+        # Constructs a tuple suitable for returning from Graphed.backward:
+        # Pads out the actually-needed grads with Nones in gradient slots for inputs that don't require grad.
+        # I couldn't think of a slick one-liner for this pattern.
+        static_grad_inputs = []
+        grad_idx = 0
+        for arg in static_input_surface:
+            if arg.requires_grad:
+                static_grad_inputs.append(grad_inputs[grad_idx])
+                grad_idx += 1
+            else:
+                static_grad_inputs.append(None)  # type: ignore[arg-type]
+        static_grad_inputs = tuple(static_grad_inputs)  # type: ignore[assignment]
+
+        per_callable_static_grad_outputs.append(static_grad_outputs)
+        per_callable_static_grad_inputs.append(static_grad_inputs)
+
+    # Reverses the most recent two lists
+    per_callable_static_grad_outputs = list(reversed(per_callable_static_grad_outputs))
+    per_callable_static_grad_inputs = list(reversed(per_callable_static_grad_inputs))
+    # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.
+
+    # Factory that binds one callable's graphs and static tensors into a fresh
+    # autograd.Function subclass and returns a plain function wrapping it.
+    def make_graphed_autograd_function(fwd_graph,
+                                       bwd_graph,
+                                       module_params,
+                                       len_user_args,
+                                       output_was_tensor,
+                                       static_input_surface,
+                                       static_outputs,
+                                       static_grad_outputs,
+                                       static_grad_inputs):
+        class Graphed(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, *inputs):
+                # At this stage, only the user args may (potentially) be new tensors.
+                # Copy each one into its static input tensor unless it already is that memory.
+                for i in range(len_user_args):
+                    if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
+                        static_input_surface[i].copy_(inputs[i])
+                fwd_graph.replay()
+                assert isinstance(static_outputs, tuple)
+                # Hand back detached tensors for the static outputs captured above.
+                return tuple(o.detach() for o in static_outputs)
+
+            @staticmethod
+            @torch.autograd.function.once_differentiable
+            def backward(ctx, *grads):
+                for g, grad in zip(static_grad_outputs, grads):
+                    if g is None:
+                        assert grad is None
+                    else:
+                        # don't copy if autograd gods have been kind and the
+                        # incoming grad is already in the right place
+                        if g.data_ptr() != grad.data_ptr():
+                            g.copy_(grad)
+                bwd_graph.replay()
+
+                # Input args that didn't require grad expect a None gradient.
+                assert isinstance(static_grad_inputs, tuple)
+                return tuple(b.detach() if b is not None else b for b in static_grad_inputs)
+
+        def functionalized(*user_args):
+            # Runs the autograd function with inputs == all inputs to the graph that might require grad
+            # (explicit user args + module parameters)
+            # Assumes module params didn't change since capture.
+            out = Graphed.apply(*(user_args + module_params))
+            # Unwrap the 1-tuple if the source callable returned a bare Tensor.
+            return out[0] if output_was_tensor else out
+
+        return functionalized
+
+    # Put together the final graphed callables
+    ret = []
+    for i, func in enumerate(callables):
+        graphed = make_graphed_autograd_function(fwd_graphs[i],
+                                                 bwd_graphs[i],
+                                                 per_callable_module_params[i],
+                                                 per_callable_len_user_args[i],
+                                                 per_callable_output_was_tensor[i],
+                                                 per_callable_static_input_surfaces[i],
+                                                 per_callable_static_outputs[i],
+                                                 per_callable_static_grad_outputs[i],
+                                                 per_callable_static_grad_inputs[i])
+
+        if isinstance(func, torch.nn.Module):
+            # The factory takes everything as parameters so each new_fwd closes over
+            # its own module, training state, graphed function, and original forward.
+            def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
+                def new_fwd(*user_args):
+                    # If the module's training-or-eval state matches what we graphed,
+                    # run the graph, otherwise run the original forward method
+                    if func.training == graph_training_state:
+                        return graphed(*user_args)
+                    else:
+                        return orig_fwd(*user_args)
+                return new_fwd
+            # The module is patched in place and the module object itself is returned.
+            func.forward = make_graphed_forward(func, func.training, graphed, func.forward)  # type: ignore[assignment]
+            ret.append(func)
+        else:
+            ret.append(graphed)
+
+    # Mirror the input form: bare callable in, bare callable out.
+    if just_one_callable:
+        return ret[0]
+
+    return tuple(ret)
# Define dummy base classes
torch._C.__dict__['_CudaStreamBase'] = _dummy_type('_CudaStreamBase')
torch._C.__dict__['_CudaEventBase'] = _dummy_type('_CudaEventBase')
- torch._C.__dict__['_CudaGraphBase'] = _dummy_type('_CudaGraphBase')
- torch._C.__dict__['_graph_pool_handle'] = _dummy_type('_graph_pool_handle')
class Stream(torch._C._CudaStreamBase):
r"""Wrapper around a CUDA stream.
return '<torch.cuda.Event {0:#x}>'.format(self._as_parameter_.value)
else:
return '<torch.cuda.Event uninitialized>'
-
-_Graph = torch._C._CudaGraphBase
-_graph_pool_handle = torch._C._graph_pool_handle