[CUDA graphs] Prototype API and documentation (#63269)
authorMichael Carilli <mcarilli@gmail.com>
Tue, 31 Aug 2021 20:29:39 +0000 (13:29 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 31 Aug 2021 20:34:23 +0000 (13:34 -0700)
Summary:
RFC: https://github.com/pytorch/pytorch/issues/61880

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63269

Reviewed By: mruberry

Differential Revision: D30596643

Pulled By: ngimel

fbshipit-source-id: b1f8061406364b667e2c2d4d30fbce1f0d8456be

aten/src/ATen/cuda/CUDAGraphsUtils.cuh
aten/src/ATen/native/cudnn/Conv_v7.cpp
c10/cuda/CUDACachingAllocator.cpp
docs/source/cuda.rst
docs/source/notes/cuda.rst
test/test_cuda.py
torch/_C/__init__.pyi.in
torch/csrc/cuda/Graph.cpp
torch/cuda/__init__.py
torch/cuda/graphs.py [new file with mode: 0644]
torch/cuda/streams.py

index c25ba88..9d42ed7 100644 (file)
@@ -42,5 +42,18 @@ inline void assertNotCapturing(std::string attempt) {
               status);
 }
 
+inline void errorIfCapturingCudnnBenchmark(std::string version_specific) {
+  auto status = currentStreamCaptureStatus();
+  TORCH_CHECK(status == CaptureStatus::None,
+              "Current cudaStreamCaptureStatus: ",
+              status,
+              "\nCapturing ",
+              version_specific,
+              "is prohibited. Possible causes of this error:\n"
+              "1. No warmup iterations occurred before capture.\n"
+              "2. The convolutions you're trying to capture use dynamic shapes, "
+              "in which case capturing them is generally prohibited.");
+}
+
 } // namespace cuda
 } // namespace at
index 7d16f0a..27863d0 100644 (file)
@@ -11,6 +11,7 @@
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
 #include <ATen/Config.h>
+#include <ATen/cuda/CUDAGraphsUtils.cuh>
 #include <ATen/cuda/Exceptions.h>
 #include <ATen/native/cudnn/ConvShared.h>
 
@@ -292,6 +293,7 @@ struct algorithm_search<cudnnConvolutionFwdAlgoPerf_t> {
     } else {
       size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
       Workspace ws(max_ws_size);
+      at::cuda::errorIfCapturingCudnnBenchmark("cudnnFind");
       AT_CUDNN_CHECK_WITH_SHAPES(cudnnFindConvolutionForwardAlgorithmEx(
           args.handle,
           args.idesc.desc(), args.input.data_ptr(),
@@ -362,6 +364,7 @@ struct algorithm_search<cudnnConvolutionBwdDataAlgoPerf_t> {
     } else {
       size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
       Workspace ws(max_ws_size);
+      at::cuda::errorIfCapturingCudnnBenchmark("cudnnFind");
       AT_CUDNN_CHECK_WITH_SHAPES(cudnnFindConvolutionBackwardDataAlgorithmEx(
           args.handle,
           args.wdesc.desc(), args.weight.data_ptr(),
@@ -434,6 +437,7 @@ struct algorithm_search<cudnnConvolutionBwdFilterAlgoPerf_t> {
     } else {
       size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
       Workspace ws(max_ws_size);
+      at::cuda::errorIfCapturingCudnnBenchmark("cudnnFind");
       AT_CUDNN_CHECK_WITH_SHAPES(cudnnFindConvolutionBackwardFilterAlgorithmEx(
           args.handle,
           args.idesc.desc(), args.input.data_ptr(),
index 0553753..659fea3 100644 (file)
@@ -308,6 +308,8 @@ cudaError_t cudaMallocMaybeCapturing(void** p, size_t size) {
   } else {
     // It's ok to capture cudaMallocs, as long as we never cudaFree those
     // addresses before replay.
+    // Capturing cudaMalloc behaves nicely: it gives the graph new virtual addresses,
+    // but the allocation is ignored during replays (so replays won't leak new memory).
     at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeRelaxed};
     return cudaMalloc(p, size);
   }
index d4783c8..7502933 100644 (file)
@@ -71,6 +71,17 @@ Streams and events
     Stream
     Event
 
+Graphs (prototype)
+------------------
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    graph_pool_handle
+    CUDAGraph
+    graph
+    make_graphed_callables
+
 Memory management
 -----------------
 .. autosummary::
index 264017f..5d7c0ea 100644 (file)
@@ -262,7 +262,7 @@ have the same stream-semantics relationship as any group of ops::
 BC note: Using grads on the default stream
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-In prior versions of Pytorch (1.9 and earlier), the autograd engine always synced
+In prior versions of PyTorch (1.9 and earlier), the autograd engine always synced
 the default stream with all backward ops, so the following pattern::
 
     with torch.cuda.stream(s):
@@ -270,7 +270,7 @@ the default stream with all backward ops, so the following pattern::
     use grads
 
 was safe as long as ``use grads`` happened on the default stream.
-In present Pytorch, that pattern is no longer safe. If ``backward()``
+In present PyTorch, that pattern is no longer safe. If ``backward()``
 and ``use grads`` are in different stream contexts, you must sync the streams::
 
     with torch.cuda.stream(s):
@@ -513,3 +513,452 @@ by GIL of Python interpreter.
 
 If you use :class:`~torch.nn.parallel.DistributedDataParallel`, you could use
 `torch.distributed.launch` utility to launch your program, see :ref:`distributed-launch`.
+
+.. _cuda-graph-semantics:
+
+CUDA Graphs
+-----------
+
+A CUDA graph is a record of the work (mostly kernels and their arguments) that a
+CUDA stream and its dependent streams perform.
+For general principles and details on the underlying CUDA API, see
+`Getting Started with CUDA Graphs`_ and the
+`Graphs section`_ of the CUDA C Programming Guide.
+
+PyTorch supports the construction of CUDA graphs using `stream capture`_, which puts a
+CUDA stream in *capture mode*. CUDA work issued to a capturing stream doesn't actually
+run on the GPU. Instead, the work is recorded in a graph.
+
+After capture, the graph can be *launched* to run the GPU work as many times as needed.
+Each replay runs the same kernels with the same arguments. For pointer arguments this
+means the same memory addresses are used.
+By filling input memory with new data (e.g., from a new batch) before each replay,
+you can rerun the same work on new data.
+
+Why CUDA Graphs?
+^^^^^^^^^^^^^^^^
+
+Replaying a graph sacrifices the dynamic flexibility of typical eager execution in exchange for
+**greatly reduced CPU overhead**. A graph's arguments and kernels are fixed, so a graph replay
+skips all layers of argument setup and kernel dispatch, including Python, C++, and CUDA driver
+overheads. Under the hood, a replay submits the entire graph's work to the GPU with
+a single call to `cudaGraphLaunch`_.  Kernels in a replay also execute slightly faster
+on the GPU, but eliding CPU overhead is the main benefit.
+
+You should try CUDA graphs if all or part of your network is graph-safe (usually this means
+static shapes and static control flow, but see the other :ref:`constraints<capture-constraints>`)
+and you suspect its runtime is at least somewhat CPU-limited.
+
+.. _Getting Started with CUDA Graphs:
+    https://developer.nvidia.com/blog/cuda-graphs/
+.. _Graphs section:
+    https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cuda-graphs
+.. _stream capture:
+    https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture
+.. _cudaGraphLaunch:
+    https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g1accfe1da0c605a577c22d9751a09597
+
+PyTorch API
+^^^^^^^^^^^
+
+.. warning::
+    This API is a prototype and may change in future releases.
+
+PyTorch exposes graphs via a raw :class:`torch.cuda.CUDAGraph` class
+and two convenience wrappers,
+:class:`torch.cuda.graph` and
+:class:`torch.cuda.make_graphed_callables`.
+
+:class:`torch.cuda.graph` is a simple, versatile context manager that
+captures CUDA work in its context.
+Before capture, warm up the workload to be captured by running
+a few eager iterations. Warmup must occur on a side stream.
+Because the graph reads from and writes to the same memory addresses in every
+replay, you must maintain long-lived references to tensors that hold
+input and output data during capture.
+To run the graph on new input data, copy new data to the capture's input tensor(s),
+replay the graph, then read the new output from the capture's output tensor(s).
+Example::
+
+    g = torch.cuda.CUDAGraph()
+
+    # Placeholder input used for capture
+    static_input = torch.empty((5,), device="cuda")
+
+    # Warmup before capture
+    s = torch.cuda.Stream()
+    s.wait_stream(torch.cuda.current_stream())
+    with torch.cuda.stream(s):
+        for _ in range(3):
+            static_output = static_input * 2
+    torch.cuda.current_stream().wait_stream(s)
+
+    # Captures the graph
+    # To allow capture, torch.cuda.graph automatically sets a side stream
+    # as the current stream in the context
+    with torch.cuda.graph(g):
+        static_output = static_input * 2
+
+    # Fills the graph's input memory with new data to compute on
+    static_input.copy_(torch.full((5,), 3, device="cuda"))
+    g.replay()
+    # static_output holds the results
+    print(static_output)  # full of 3 * 2 = 6
+
+    # Fills the graph's input memory with more data to compute on
+    static_input.copy_(torch.full((5,), 4, device="cuda"))
+    g.replay()
+    print(static_output)  # full of 4 * 2 = 8
+
+See
+:ref:`Whole-network capture<whole-network-capture>`,
+:ref:`Usage with torch.cuda.amp<graphs-with-amp>`, and
+:ref:`Usage with multiple streams<multistream-capture>`
+for realistic and advanced patterns.
+
+:class:`~torch.cuda.make_graphed_callables` is more sophisticated.
+:class:`~torch.cuda.make_graphed_callables` accepts Python functions and
+:class:`torch.nn.Module`\s. For each passed function or Module,
+it creates separate graphs of the forward-pass and backward-pass work. See
+:ref:`Partial-network capture<partial-network-capture>`.
+
+.. _capture-constraints:
+
+Constraints
+~~~~~~~~~~~
+
+A set of ops is *capturable* if it doesn't violate any of the following constraints.
+
+Constraints apply to all work in a
+:class:`torch.cuda.graph` context and all work in the forward and backward passes
+of any callable you pass to :func:`torch.cuda.make_graphed_callables`.
+
+Violating any of these will likely cause a runtime error:
+
+* Capture must occur on a non-default stream. (This is only a concern if you use the raw
+  :meth:`CUDAGraph.capture_begin<torch.cuda.CUDAGraph.capture_begin>` and
+  :meth:`CUDAGraph.capture_end<torch.cuda.CUDAGraph.capture_end>` calls.
+  :class:`~torch.cuda.graph` and
+  :func:`~torch.cuda.make_graphed_callables` set a side stream for you.)
+* Ops that synchronize the CPU with the GPU (e.g., ``.item()`` calls) are prohibited
+  (see the sketch after these lists).
+* CUDA RNG ops are allowed, but must use default generators. For example, explicitly constructing a
+  new :class:`torch.Generator` instance and passing it as the ``generator`` argument to an RNG function
+  is prohibited.
+
+Violating any of these will likely cause silent numerical errors or undefined behavior:
+
+* Within a process, only one capture may be underway at a time.
+* No non-captured CUDA work may run in this process (on any thread) while capture is underway.
+* CPU work is not captured. If the captured ops include CPU work, that work will be elided during replay.
+* Every replay reads from and writes to the same (virtual) memory addresses.
+* Dynamic control flow (based on CPU or GPU data) is prohibited.
+* Dynamic shapes are prohibited. The graph assumes every tensor in the captured op sequence
+  has the same size and layout in every replay.
+* Using multiple streams in a capture is allowed, but there are :ref:`restrictions<multistream-capture>`.
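+
+As a quick sketch of the synchronization and RNG constraints (not part of the original
+recipe; it assumes a fresh ``g`` and ``static_input`` set up as in the
+:class:`torch.cuda.graph` example above)::
+
+    with torch.cuda.graph(g):
+        y = static_input * 2
+        # y.item()  # prohibited: .item() syncs the CPU with the GPU
+        # gen = torch.Generator(device="cuda")
+        # torch.randn(5, device="cuda", generator=gen)  # prohibited: explicit generator
+        noise = torch.randn_like(y)  # allowed: CUDA RNG with the default generator
+        static_output = y + noise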
+
+Non-constraints
+~~~~~~~~~~~~~~~
+
+* Once captured, the graph may be replayed on any stream.
+
+.. _whole-network-capture:
+
+Whole-network capture
+^^^^^^^^^^^^^^^^^^^^^^
+
+If your entire network is capturable, you can capture and replay an entire iteration::
+
+    N, D_in, H, D_out = 640, 4096, 2048, 1024
+    model = torch.nn.Sequential(torch.nn.Linear(D_in, H),
+                                torch.nn.Dropout(p=0.2),
+                                torch.nn.Linear(H, D_out),
+                                torch.nn.Dropout(p=0.1)).cuda()
+    loss_fn = torch.nn.MSELoss()
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+
+    # Placeholders used for capture
+    static_input = torch.randn(N, D_in, device='cuda')
+    static_target = torch.randn(N, D_out, device='cuda')
+
+    # warmup
+    # Uses static_input and static_target here for convenience,
+    # but in a real setting, because the warmup includes optimizer.step(),
+    # you must use a few batches of real data.
+    s = torch.cuda.Stream()
+    s.wait_stream(torch.cuda.current_stream())
+    with torch.cuda.stream(s):
+        for i in range(3):
+            optimizer.zero_grad(set_to_none=True)
+            y_pred = model(static_input)
+            loss = loss_fn(y_pred, static_target)
+            loss.backward()
+            optimizer.step()
+    torch.cuda.current_stream().wait_stream(s)
+
+    # capture
+    g = torch.cuda.CUDAGraph()
+    # Sets grads to None before capture, so backward() will create
+    # .grad attributes with allocations from the graph's private pool
+    optimizer.zero_grad(set_to_none=True)
+    with torch.cuda.graph(g):
+        static_y_pred = model(static_input)
+        static_loss = loss_fn(static_y_pred, static_target)
+        static_loss.backward()
+        optimizer.step()
+
+    real_inputs = [torch.rand_like(static_input) for _ in range(10)]
+    real_targets = [torch.rand_like(static_target) for _ in range(10)]
+
+    for data, target in zip(real_inputs, real_targets):
+        # Fills the graph's input memory with new data to compute on
+        static_input.copy_(data)
+        static_target.copy_(target)
+        # replay() includes forward, backward, and step.
+        # You don't even need to call optimizer.zero_grad() between iterations
+        # because the captured backward refills static .grad tensors in place.
+        g.replay()
+        # Params have been updated. static_y_pred, static_loss, and .grad
+        # attributes hold values from computing on this iteration's data.
+
+.. _partial-network-capture:
+
+Partial-network capture
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
+If some of your network is unsafe to capture (e.g., due to dynamic control flow,
+dynamic shapes, CPU syncs, or essential CPU-side logic), you can run the unsafe
+part(s) eagerly and use :func:`torch.cuda.make_graphed_callables` to graph only
+the capture-safe part(s).
+
+By default, callables returned by :func:`~torch.cuda.make_graphed_callables`
+are autograd-aware, and can be used in the training loop as direct replacements
+for the functions or :class:`nn.Module<torch.nn.Module>`\ s you passed.
+
+:func:`~torch.cuda.make_graphed_callables` internally creates
+:class:`~torch.cuda.CUDAGraph` objects, runs warmup iterations, and maintains
+static inputs and outputs as needed.  Therefore (unlike with
+:class:`torch.cuda.graph`) you don't need to handle those manually.
+
+In the following example, data-dependent dynamic control flow means the
+network isn't capturable end-to-end, but
+:func:`~torch.cuda.make_graphed_callables`
+lets us capture and run graph-safe sections as graphs regardless::
+
+    N, D_in, H, D_out = 640, 4096, 2048, 1024
+
+    module1 = torch.nn.Linear(D_in, H).cuda()
+    module2 = torch.nn.Linear(H, D_out).cuda()
+    module3 = torch.nn.Linear(H, D_out).cuda()
+
+    loss_fn = torch.nn.MSELoss()
+    optimizer = torch.optim.SGD(itertools.chain(module1.parameters(),
+                                                module2.parameters(),
+                                                module3.parameters()),
+                                lr=0.1)
+
+    # Sample inputs used for capture
+    # requires_grad state of sample inputs must match
+    # requires_grad state of real inputs each callable will see.
+    x = torch.randn(N, D_in, device='cuda')
+    h = torch.randn(N, H, device='cuda', requires_grad=True)
+
+    module1 = torch.cuda.make_graphed_callables(module1, (x,))
+    module2 = torch.cuda.make_graphed_callables(module2, (h,))
+    module3 = torch.cuda.make_graphed_callables(module3, (h,))
+
+    real_inputs = [torch.rand_like(x) for _ in range(10)]
+    real_targets = [torch.randn(N, D_out, device="cuda") for _ in range(10)]
+
+    for data, target in zip(real_inputs, real_targets):
+        optimizer.zero_grad(set_to_none=True)
+
+        tmp = module1(data)  # forward ops run as a graph
+
+        if tmp.sum().item() > 0:
+            tmp = module2(tmp)  # forward ops run as a graph
+        else:
+            tmp = module3(tmp)  # forward ops run as a graph
+
+        loss = loss_fn(tmp, target)
+        # module2's or module3's (whichever was chosen) backward ops,
+        # as well as module1's backward ops, run as graphs
+        loss.backward()
+        optimizer.step()
+
+.. _graphs-with-amp:
+
+Usage with torch.cuda.amp
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
+For typical optimizers, :meth:`GradScaler.step<torch.cuda.amp.GradScaler.step>` syncs
+the CPU with the GPU, which is prohibited during capture. To avoid errors, either use
+:ref:`partial-network capture<partial-network-capture>`, or (if forward, loss,
+and backward are capture-safe) capture forward, loss, and backward but not the
+optimizer step::
+
+    # warmup
+    # In a real setting, use a few batches of real data.
+    s = torch.cuda.Stream()
+    s.wait_stream(torch.cuda.current_stream())
+    with torch.cuda.stream(s):
+        for i in range(3):
+            optimizer.zero_grad(set_to_none=True)
+            with torch.cuda.amp.autocast():
+                y_pred = model(static_input)
+                loss = loss_fn(y_pred, static_target)
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+    torch.cuda.current_stream().wait_stream(s)
+
+    # capture
+    g = torch.cuda.CUDAGraph()
+    optimizer.zero_grad(set_to_none=True)
+    with torch.cuda.graph(g):
+        with torch.cuda.amp.autocast():
+            static_y_pred = model(static_input)
+            static_loss = loss_fn(static_y_pred, static_target)
+        scaler.scale(static_loss).backward()
+        # don't capture scaler.step(optimizer) or scaler.update()
+
+    real_inputs = [torch.rand_like(static_input) for _ in range(10)]
+    real_targets = [torch.rand_like(static_target) for _ in range(10)]
+
+    for data, target in zip(real_inputs, real_targets):
+        static_input.copy_(data)
+        static_target.copy_(target)
+        g.replay()
+        # Runs scaler.step and scaler.update eagerly
+        scaler.step(optimizer)
+        scaler.update()
+
+.. _multistream-capture:
+
+Usage with multiple streams
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Capture mode automatically propagates to any streams that sync with a capturing stream.
+Within capture, you may expose parallelism by issuing calls to different streams,
+but the overall stream dependency DAG must branch out from the
+initial capturing stream after capture begins and rejoin the initial stream
+before capture ends::
+
+    with torch.cuda.graph(g):
+        # at context manager entrance, torch.cuda.current_stream()
+        # is the initial capturing stream
+
+        # INCORRECT (does not branch out from or rejoin initial stream)
+        with torch.cuda.stream(s):
+            cuda_work()
+
+        # CORRECT:
+        # branches out from initial stream
+        s.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(s):
+            cuda_work()
+        # rejoins initial stream before capture ends
+        torch.cuda.current_stream().wait_stream(s)
+
+.. note::
+
+    To avoid confusion for power users looking at replays in Nsight Systems or nvprof:
+    Unlike eager execution, the graph interprets a nontrivial stream DAG in capture
+    as a hint, not a command. During replay, the graph may reorganize independent ops
+    onto different streams or enqueue them in a different order (while respecting your
+    original DAG's overall dependencies).
+
+Usage with DistributedDataParallel
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+NCCL < 2.9.6
+~~~~~~~~~~~~
+
+NCCL versions earlier than 2.9.6 don't allow collectives to be captured.
+You must use :ref:`partial-network capture<partial-network-capture>`,
+which defers allreduces to happen outside graphed sections of backward.
+
+Call :func:`~torch.cuda.make_graphed_callables` on graphable network sections
+*before* wrapping the network with DDP.
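+
+A minimal sketch of that ordering (``model.graphable_section`` and ``sample_input``
+are hypothetical names used here for illustration)::
+
+    model.graphable_section = torch.cuda.make_graphed_callables(
+        model.graphable_section, (sample_input,))
+    model = DistributedDataParallel(model)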
+
+NCCL >= 2.9.6
+~~~~~~~~~~~~~
+
+NCCL versions 2.9.6 or later allow collectives in the graph.
+Approaches that capture an :ref:`entire backward pass<whole-network-capture>`
+are viable, but require three setup steps.
+
+1. Disable DDP's internal async error handling::
+
+    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
+    torch.distributed.init_process_group(...)
+
+2. Before full-backward capture, DDP must be constructed in a side-stream context::
+
+    with torch.cuda.stream(s):
+        model = DistributedDataParallel(model)
+
+3. Your warmup must run at least 11 DDP-enabled eager iterations before capture
+   (see the sketch below).
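+
+A rough sketch of such a warmup (``warmup_batches``, ``loss_fn``, and ``opt`` are
+hypothetical names; ``model`` is assumed to be DDP-wrapped already)::
+
+    s = torch.cuda.Stream()
+    s.wait_stream(torch.cuda.current_stream())
+    with torch.cuda.stream(s):
+        for data, target in warmup_batches:  # at least 11 DDP-enabled iterations
+            opt.zero_grad(set_to_none=True)
+            loss = loss_fn(model(data), target)
+            loss.backward()  # DDP allreduces run eagerly here
+            opt.step()
+    torch.cuda.current_stream().wait_stream(s)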
+
+.. _graph-memory-management:
+
+Graph memory management
+^^^^^^^^^^^^^^^^^^^^^^^
+
+A captured graph acts on the same virtual addresses every time it replays.
+If PyTorch frees the memory, a later replay can hit an illegal memory access.
+If PyTorch reassigns the memory to new tensors, the replay can corrupt the values
+seen by those tensors.  Therefore, the virtual addresses used by the graph must be
+reserved for the graph across replays. The PyTorch caching allocator achieves this
+by detecting when capture is underway and satisfying the capture's allocations
+from a graph-private memory pool. The private pool stays alive until its
+:class:`~torch.cuda.CUDAGraph` object and all tensors created during capture
+go out of scope.
+
+Private pools are maintained automatically. By default, the allocator creates a
+separate private pool for each capture. If you capture multiple graphs,
+this conservative approach ensures graph replays never corrupt each other's values,
+but sometimes needlessly wastes memory.
+
+To economize the memory stashed in private pools, :class:`torch.cuda.graph`
+and :func:`torch.cuda.make_graphed_callables` optionally allow different
+captures to share the same private pool.
+It's safe for a set of graphs to share a private pool if you know they'll always
+be replayed in the same order they were captured,
+and never be replayed concurrently.
+
+Sharing memory across captures with torch.cuda.graph
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+:class:`torch.cuda.graph`'s ``pool`` argument is a hint to use a particular private pool,
+and can be used to share memory across graphs as shown::
+
+    g1 = torch.cuda.CUDAGraph()
+    g2 = torch.cuda.CUDAGraph()
+
+    # (create static inputs for g1 and g2, run warmups of their workloads...)
+
+    # Captures g1
+    with torch.cuda.graph(g1):
+        static_out_1 = g1_workload(static_in_1)
+
+    # Captures g2, hinting that g2 may share a memory pool with g1
+    with torch.cuda.graph(g2, pool=g1.pool()):
+        static_out_2 = g2_workload(static_in_2)
+
+    static_in_1.copy_(real_data_1)
+    static_in_2.copy_(real_data_2)
+    g1.replay()
+    g2.replay()
+
+Sharing memory across captures with torch.cuda.make_graphed_callables
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+With :func:`torch.cuda.make_graphed_callables`, if you want to graph several
+callables and you know they'll always run in the same order (and never concurrently),
+pass them as a tuple in the same order they'll run in the live workload, and
+:func:`~torch.cuda.make_graphed_callables` will capture their graphs using a shared
+private pool.
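+
+For example, assuming two hypothetical modules ``section1`` and ``section2`` that always
+run in this order and never concurrently (``N``, ``D_in``, ``H``, ``D_out`` as in the
+earlier examples)::
+
+    section1 = torch.nn.Linear(D_in, H).cuda()
+    section2 = torch.nn.Linear(H, D_out).cuda()
+
+    x = torch.randn(N, D_in, device="cuda")
+    h = torch.randn(N, H, device="cuda", requires_grad=True)
+
+    # A single invocation captures both callables with a shared private pool
+    section1, section2 = torch.cuda.make_graphed_callables((section1, section2),
+                                                           ((x,), (h,)))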
+
+If, in the live workload, your callables will run in an order that occasionally changes,
+or if they'll run concurrently, passing them as a tuple to a single invocation of
+:func:`~torch.cuda.make_graphed_callables` is not allowed. Instead, you must call
+:func:`~torch.cuda.make_graphed_callables` separately for each one.
index e90cb17..70f5a6e 100644 (file)
@@ -3089,7 +3089,7 @@ torch.cuda.synchronize()
 
         with torch.cuda.stream(s):
             a = torch.full((1000,), 1, device="cuda")
-            g = torch.cuda._Graph()
+            g = torch.cuda.CUDAGraph()
             torch.cuda.empty_cache()
             g.capture_begin()
             b = a
@@ -3125,7 +3125,7 @@ torch.cuda.synchronize()
             with torch.cuda.stream(stream):
                 torch.cuda.manual_seed(5)
 
-                g = torch.cuda._Graph()
+                g = torch.cuda.CUDAGraph()
                 torch.cuda.empty_cache()
                 g.capture_begin()
                 graph_out = graph_in
@@ -3212,7 +3212,7 @@ torch.cuda.synchronize()
             with torch.cuda.stream(stream):
                 torch.cuda.manual_seed(5)
 
-                g = torch.cuda._Graph()
+                g = torch.cuda.CUDAGraph()
                 torch.cuda.empty_cache()
                 if (module == "torch"):
                     g.capture_begin()
@@ -3279,14 +3279,14 @@ torch.cuda.synchronize()
         s = torch.cuda.Stream()
 
         for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"):
-            g0 = torch.cuda._Graph()
-            g1 = torch.cuda._Graph()
+            g0 = torch.cuda.CUDAGraph()
+            g1 = torch.cuda.CUDAGraph()
 
             a = torch.ones((size,), device="cuda")
 
             s.wait_stream(torch.cuda.current_stream())
             with torch.cuda.stream(s):
-                g0_args = (torch.cuda._graph_pool_handle(),) if share_mem == "via graph_pool_handle()" else ()
+                g0_args = (torch.cuda.graph_pool_handle(),) if share_mem == "via graph_pool_handle()" else ()
                 g0.capture_begin(*g0_args)
                 b = a.clone()
                 for _ in range(5):
@@ -3343,8 +3343,8 @@ torch.cuda.synchronize()
         s = torch.cuda.Stream()
 
         for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"):
-            g0 = torch.cuda._Graph()
-            g1 = torch.cuda._Graph()
+            g0 = torch.cuda.CUDAGraph()
+            g1 = torch.cuda.CUDAGraph()
 
             s0 = torch.cuda.Stream()
             s1 = torch.cuda.Stream()
@@ -3353,7 +3353,7 @@ torch.cuda.synchronize()
 
             s.wait_stream(torch.cuda.current_stream())
             with torch.cuda.stream(s):
-                g0_args = (torch.cuda._graph_pool_handle(),) if share_mem == "via graph_pool_handle()" else ()
+                g0_args = (torch.cuda.graph_pool_handle(),) if share_mem == "via graph_pool_handle()" else ()
                 g0.capture_begin(*g0_args)
                 b = a.clone()
                 for _ in range(5):
@@ -3407,13 +3407,13 @@ torch.cuda.synchronize()
         for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"):
             a = torch.ones((size,), device="cuda")
 
-            g0 = torch.cuda._Graph()
-            g1 = torch.cuda._Graph()
-            g2 = torch.cuda._Graph()
+            g0 = torch.cuda.CUDAGraph()
+            g1 = torch.cuda.CUDAGraph()
+            g2 = torch.cuda.CUDAGraph()
 
             s.wait_stream(torch.cuda.current_stream())
             with torch.cuda.stream(s):
-                g0_args = (torch.cuda._graph_pool_handle(),) if share_mem == "via graph_pool_handle()" else ()
+                g0_args = (torch.cuda.graph_pool_handle(),) if share_mem == "via graph_pool_handle()" else ()
                 g0.capture_begin(*g0_args)
                 b = a.clone()
                 c = b + 1
@@ -3499,7 +3499,7 @@ torch.cuda.synchronize()
                 delta_active_blocks = 1  # We only check the large pool, which isn't affected by rng offset holder
                 delta_active_bytes = numel * elem
 
-            g = torch.cuda._Graph()
+            g = torch.cuda.CUDAGraph()
             s.wait_stream(torch.cuda.current_stream())
             with torch.cuda.stream(s):
                 # Allocation stat estimates assume input is created on the same stream as capture_begin()
@@ -3573,7 +3573,7 @@ torch.cuda.synchronize()
         s0 = torch.cuda.Stream()
         s1 = torch.cuda.Stream()
         s2 = torch.cuda.Stream()
-        g = torch.cuda._Graph()
+        g = torch.cuda.CUDAGraph()
 
         torch.cuda.synchronize()
         with torch.cuda.stream(s0):
@@ -3620,7 +3620,7 @@ torch.cuda.synchronize()
 
         y = model(x)
 
-        g = torch.cuda._Graph()
+        g = torch.cuda.CUDAGraph()
         s = torch.cuda.Stream()
         s.wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(s):
@@ -3638,7 +3638,7 @@ torch.cuda.synchronize()
         torch.cuda.empty_cache()
 
         scaler = torch.cuda.amp.GradScaler(init_scale=4.)
-        g = torch.cuda._Graph()
+        g = torch.cuda.CUDAGraph()
         s = torch.cuda.Stream()
 
         weight = torch.ones((100,), device="cuda", requires_grad=True)
@@ -3646,18 +3646,15 @@ torch.cuda.synchronize()
         static_input = torch.ones_like(weight)
         static_grad = torch.ones_like(weight)
 
-        s.wait_stream(torch.cuda.current_stream())
-        with torch.cuda.stream(s):
-            # warmup
-            loss = (weight.half() * static_input).sum()
-            scaler.scale(loss).backward()
-            opt.zero_grad(set_to_none=True)
-            # capture
-            g.capture_begin()
+        # warmup
+        loss = (weight.half() * static_input).sum()
+        scaler.scale(loss).backward()
+        opt.zero_grad(set_to_none=True)
+
+        # capture
+        with torch.cuda.graph(g):
             loss = (weight.half() * static_input).sum()
             scaler.scale(loss).backward()
-            g.capture_end()
-        torch.cuda.current_stream().wait_stream(s)
 
         input_vals = [5, 20000, 5, 40000]
         # If the scale gets updated properly, these are the scale, growth tracker,
@@ -3678,6 +3675,71 @@ torch.cuda.synchronize()
             self.assertEqual(scaler._scale, scale)
             self.assertEqual(scaler._growth_tracker, growth_tracker)
 
+    @unittest.skipIf((not TEST_CUDA) or
+                     TEST_WITH_ROCM or
+                     int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
+    def test_graph_make_graphed_callables(self):
+        torch.manual_seed(5)
+        torch.cuda.manual_seed(5)
+
+        N, D_in, H, D_out = 640, 4096, 2048, 1024
+
+        models = []
+        for _ in range(2):
+            model_section1 = torch.nn.Sequential(torch.nn.Linear(D_in, H),
+                                                 torch.nn.Dropout(p=0.1)).cuda()
+            model_section2 = torch.nn.Sequential(torch.nn.Linear(H, D_out),
+                                                 torch.nn.Dropout(p=0.2)).cuda()
+            models.append(torch.nn.Sequential(model_section1, model_section2))
+
+        model_graphed = models[0]
+        model_control = models[1]
+
+        model_graphed.load_state_dict(model_control.state_dict())
+
+        opt_graphed = torch.optim.SGD(model_graphed.parameters(), lr=0.1)
+        opt_control = torch.optim.SGD(model_control.parameters(), lr=0.1)
+
+        x = torch.randn(N, D_in, device='cuda')
+        h = torch.randn(N, H, device='cuda', requires_grad=True)
+        y_pred = torch.randn(N, D_out, device='cuda', requires_grad=True)
+        y = torch.randn(N, D_out, device='cuda')
+
+        loss_fn_control = torch.nn.functional.mse_loss
+        relu_control = torch.nn.functional.relu
+
+        # This is a good stress test. It graphs four callables: two Modules and two python functions.
+        model_graphed[0], model_graphed[1], relu_graphed, loss_fn_graphed = \
+            torch.cuda.make_graphed_callables((model_graphed[0], model_graphed[1], relu_control, loss_fn_control),
+                                              ((x,), (h,), (y_pred,), (y_pred, y)))
+
+        real_inputs = [torch.rand_like(x) for _ in range(10)]
+        real_targets = [torch.rand_like(y) for _ in range(10)]
+
+        for m, opt, relu, loss_fn in zip((model_graphed, model_control),
+                                         (opt_graphed, opt_control),
+                                         (relu_graphed, relu_control),
+                                         (loss_fn_graphed, loss_fn_control)):
+            # Resets RNG states before iterations for graphed and ungraphed models,
+            # so dropout math should be bitwise identical for both.
+            torch.manual_seed(5)
+            torch.cuda.manual_seed(5)
+            for data, target in zip(real_inputs, real_targets):
+                opt.zero_grad(set_to_none=True)
+                y_pred = m(data)
+                y_pred = relu(y_pred)
+                loss = loss_fn(y_pred, target)
+                loss.backward()
+                opt.step()
+
+        for p, pc in zip(model_graphed.parameters(), model_control.parameters()):
+            self.assertEqual(p, pc)
+
+        # We graphed the models in training mode. Eval should still run ungraphed.
+        model_graphed.eval()
+        model_control.eval()
+        self.assertEqual(model_graphed(real_inputs[0]), model_control(real_inputs[0]))
+
     def test_batch_norm_gather_stats(self):
         input = torch.randn(1, 3, 3, 3, device='cuda')
         mean, invstd = torch.batch_norm_gather_stats(
index c847e8d..352edbe 100644 (file)
@@ -888,8 +888,13 @@ class _CudaEventBase:
     def ipc_handle(self) -> bytes: ...
 
 # Defined in torch/csrc/cuda/Graph.cpp
-class _CudaGraphBase:
-    ...
+class _CUDAGraph:
+    def capture_begin(self,
+                      pool: Optional[Tuple[_int, _int]]=...) -> None: ...
+    def capture_end(self) -> None: ...
+    def replay(self) -> None: ...
+    def reset(self) -> None: ...
+    def pool(self) -> Tuple[_int, _int]: ...
 
 def _graph_pool_handle() -> Tuple[_int, _int]: ...
 
index 123abb9..beacefa 100644 (file)
@@ -23,36 +23,29 @@ void THCPGraph_init(PyObject *module) {
   auto torch_C_m = py::handle(module).cast<py::module>();
 
   torch_C_m
-      .def("_graph_pool_handle", &::at::cuda::graph_pool_handle);
+      .def("_graph_pool_handle",
+           &::at::cuda::graph_pool_handle);
 
-  shared_ptr_class_<::at::cuda::CUDAGraph>(torch_C_m, "_CudaGraphBase")
+  shared_ptr_class_<::at::cuda::CUDAGraph>
+      (torch_C_m,
+       "_CUDAGraph")
       .def(py::init<>())
       // I'm not sure this is the correct order of all the arguments. Pybind11 docs
       // aren't clear. But it works.
       .def("capture_begin",
            &::at::cuda::CUDAGraph::capture_begin,
            py::call_guard<py::gil_scoped_release>(),
-           R"(``capture_begin`` begins Cuda graph capture on the current stream.)",
            py::arg("pool") = c10::cuda::MempoolId_t{0, 0})
       .def("capture_end",
            &::at::cuda::CUDAGraph::capture_end,
-           py::call_guard<py::gil_scoped_release>(),
-           R"(``capture_end`` ends Cuda graph capture on the current stream.
-           After ``capture_end``, ``replay`` may be called on this instance.)")
+           py::call_guard<py::gil_scoped_release>())
       .def("replay",
            &::at::cuda::CUDAGraph::replay,
-           py::call_guard<py::gil_scoped_release>(),
-           R"(``replay`` replays the Cuda graph captured by this instance.)")
-      // reset is called in __del__ on the Python side
-      // (see class Graph in torch/cuda/streams.py for reasons and caveats)
+           py::call_guard<py::gil_scoped_release>())
       .def("reset",
            &::at::cuda::CUDAGraph::reset,
-           py::call_guard<py::gil_scoped_release>(),
-           R"(``reset`` deletes the graph currently held by this instance.)")
+           py::call_guard<py::gil_scoped_release>())
       .def("pool",
            &::at::cuda::CUDAGraph::pool,
-           py::call_guard<py::gil_scoped_release>(),
-           R"(``pool`` retrieves the id of this graph's memory pool.
-           This id can optionally be passed to another graph's capture_begin,
-           which hints that other graph may share the same memory pool.)");
+           py::call_guard<py::gil_scoped_release>());
 }
index d5a9cbb..924782d 100644 (file)
@@ -16,7 +16,8 @@ import warnings
 import threading
 from typing import List, Optional, Tuple, Union, Any
 from ._utils import _get_device_index, _dummy_type
-from .streams import Stream, Event, _Graph, _graph_pool_handle
+from .graphs import CUDAGraph, graph_pool_handle, graph, make_graphed_callables
+from .streams import Stream, Event
 from .. import device as _device
 import torch._C
 
diff --git a/torch/cuda/graphs.py b/torch/cuda/graphs.py
new file mode 100644 (file)
index 0000000..ff8a07f
--- /dev/null
@@ -0,0 +1,408 @@
+import gc
+import torch
+
+from ._utils import _dummy_type
+
+
+if not hasattr(torch._C, '_CudaStreamBase'):
+    # Define dummy base classes
+    torch._C.__dict__['_CUDAGraph'] = _dummy_type('_CUDAGraph')
+    torch._C.__dict__['_graph_pool_handle'] = _dummy_type('_graph_pool_handle')
+
+from torch._C import _CUDAGraph  # noqa: F401
+from torch._C import _graph_pool_handle
+
+
+# Python shim helps Sphinx process docstrings more reliably.
+def graph_pool_handle():
+    r"""
+    Returns an opaque token representing the id of a graph memory pool.
+    See :ref:`Graph memory management<graph-memory-management>`.
+
+    .. warning::
+        This API is a prototype and may change in future releases.
+    """
+    return _graph_pool_handle()
+
+
+# Python shim helps Sphinx process docstrings more reliably.
+class CUDAGraph(torch._C._CUDAGraph):
+    r"""
+    Wrapper around a CUDA graph.
+
+    .. warning::
+        This API is a prototype and may change in future releases.
+    """
+    def __new__(cls):
+        return super(CUDAGraph, cls).__new__(cls)
+
+    def __init__(self):
+        super(CUDAGraph, self).__init__()
+
+    def capture_begin(self, pool=None):
+        r"""
+        Begins capturing CUDA work on the current stream.
+
+        Typically, you shouldn't call ``capture_begin`` yourself.
+        Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
+        which call ``capture_begin`` internally.
+
+        Arguments:
+            pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
+                :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
+                with the indicated pool.  See :ref:`Graph memory management<graph-memory-management>`.
+        """
+        # I'm not sure if pybind11 converts a None arg to the default defined on the C++ side,
+        # so I'm not taking any chances.
+        if pool is None:
+            super(CUDAGraph, self).capture_begin()
+        else:
+            super(CUDAGraph, self).capture_begin(pool)
+
+    def capture_end(self):
+        r"""
+        Ends CUDA graph capture on the current stream.
+        After ``capture_end``, ``replay`` may be called on this instance.
+
+        Typically, you shouldn't call ``capture_end`` yourself.
+        Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
+        which call ``capture_end`` internally.
+        """
+        super(CUDAGraph, self).capture_end()
+
+    def replay(self):
+        r"""
+        Replays the CUDA work captured by this graph.
+        """
+        super(CUDAGraph, self).replay()
+
+    def reset(self):
+        r"""
+        Deletes the graph currently held by this instance.
+        """
+        super(CUDAGraph, self).reset()
+
+    def pool(self):
+        r"""
+        Returns an opaque token representing the id of this graph's memory pool.
+        This id can optionally be passed to another graph's ``capture_begin``,
+        which hints the other graph may share the same memory pool.
+        """
+        return super(CUDAGraph, self).pool()
+
+
+class graph(object):
+    r"""
+    Context-manager that captures CUDA work into a :class:`torch.cuda.CUDAGraph`
+    object for later replay.
+
+    See :ref:`CUDA Graphs <cuda-graph-semantics>` for a general introduction,
+    detailed use, and constraints.
+
+    Arguments:
+        cuda_graph (torch.cuda.CUDAGraph): Graph object used for capture.
+        pool (optional): Opaque token (returned by a call to :func:`~torch.cuda.graph_pool_handle()` or
+            :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) hinting this graph's capture
+            may share memory from the specified pool. See :ref:`Graph memory management<graph-memory-management>`.
+        stream (torch.cuda.Stream, optional): If supplied, will be set as the current stream in the context.
+            If not supplied, ``graph`` sets its own internal side stream as the current stream in the context.
+
+    .. note::
+        For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture
+        used an explicit ``stream`` argument, you should pass the same ``stream`` argument to this capture.
+
+    .. warning::
+        This API is a prototype and may change in future releases.
+    """
+    default_capture_stream = None
+
+    def __init__(self,
+                 cuda_graph,
+                 pool=None,
+                 stream=None):
+        # Lazy-init of default_capture_stream helps avoid circular-import errors.
+        # Not thread safe, but graphs already have the general (explicitly documented)
+        # restriction that only one capture may be underway at a time in the process.
+        if self.__class__.default_capture_stream is None:
+            self.__class__.default_capture_stream = torch.cuda.Stream()
+
+        self.pool = () if pool is None else (pool,)
+        self.capture_stream = stream if stream is not None else self.__class__.default_capture_stream
+        assert self.capture_stream is not None
+        self.stream_ctx = torch.cuda.stream(self.capture_stream)
+        self.cuda_graph = cuda_graph
+
+    def __enter__(self):
+        # Free as much memory as we can for the graph
+        torch.cuda.synchronize()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        # Stackoverflow seems comfortable with this pattern
+        # https://stackoverflow.com/questions/26635684/calling-enter-and-exit-manually#39172487
+        self.stream_ctx.__enter__()
+
+        self.cuda_graph.capture_begin(*self.pool)
+
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.cuda_graph.capture_end()
+        self.stream_ctx.__exit__(exc_type, exc_value, traceback)
+        # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__()
+
+
+def make_graphed_callables(callables, sample_args):
+    r"""
+    Accepts callables (functions or :class:`nn.Module<torch.nn.Module>`\ s)
+    and returns graphed versions.
+
+    Each graphed callable's forward pass runs its source callable's
+    forward CUDA work as a CUDA graph inside a single autograd node.
+
+    The graphed callable's forward pass also appends
+    a backward node to the autograd graph. During backward, this node runs the
+    callable's backward work as a CUDA graph.
+
+    Therefore, each graphed callable should be a drop-in replacement for its source callable
+    in an autograd-enabled training loop.
+
+    See :ref:`Partial-network capture<partial-network-capture>` for detailed use and constraints.
+
+    If you pass a tuple of several callables, their captures will use the same memory pool.
+    See :ref:`Graph memory management<graph-memory-management>` for when this is appropriate.
+
+    Arguments:
+        callables (torch.nn.Module or Python function, or tuple of these): Callable or callables to graph.
+            See :ref:`Graph memory management<graph-memory-management>` for when passing a tuple of callables
+            is appropriate.  If you pass a tuple of callables, their order in the tuple must be the same order
+            they'll run in the live workload.
+        sample_args (tuple of Tensors, or tuple of tuples of Tensors): Sample args for each callable.
+            If a single callable was passed, ``sample_args`` must be a single tuple of argument Tensors.
+            If a tuple of callables was passed, ``sample_args`` must be a tuple of tuples of argument Tensors.
+
+    .. note::
+        The ``requires_grad`` state of each Tensor in ``sample_args`` must match the state
+        that's expected for the corresponding real input in the training loop.
+
+    .. warning::
+        This API is a prototype and may change in future releases.
+
+    .. warning::
+        ``sample_args`` for each callable must be a tuple of Tensors. Other types and keyword args
+        are not allowed.
+
+    .. warning::
+        Returned callables do not support higher order differentiation (e.g., double backward).
+
+    .. warning::
+        In any :class:`~torch.nn.Module` passed to :func:`~make_graphed_callables`, only parameters
+        may be trainable. Buffers must have ``requires_grad=False``.
+
+    .. warning::
+        After you pass a :class:`torch.nn.Module` through :func:`~make_graphed_callables`,
+        you may not add or remove any of that Module's parameters or buffers.
+
+    .. warning::
+        :class:`torch.nn.Module`\s passed to :func:`~torch.cuda.make_graphed_callables` must not have module hooks
+        registered on them at the time they are passed. However, registering hooks on modules *after* passing them
+        through :func:`~torch.cuda.make_graphed_callables` is allowed.
+
+    .. warning::
+        When running a graphed callable, you must pass its arguments in the same order and format
+        they appeared in that callable's ``sample_args``.
+
+    .. warning::
+        All Tensor outputs of graphed callables must require grad.
+    """
+    just_one_callable = False
+
+    if not isinstance(callables, tuple):
+        just_one_callable = True
+        callables = (callables,)
+        sample_args = (sample_args,)
+
+    for c, args in zip(callables, sample_args):
+        if isinstance(c, torch.nn.Module):
+            assert len(c._backward_hooks) == 0 and len(c._forward_hooks) == 0 and len(c._forward_pre_hooks) == 0, \
+                "Modules must not have hooks registered at the time they are passed. However, registering hooks " + \
+                "on modules after passing them through make_graphed_callables is allowed."
+            assert all(b.requires_grad is False for b in c.buffers()), "In any :class:`~torch.nn.Module` passed to " + \
+                ":func:`~make_graphed_callables`, only parameters may be trainable. All buffers must have " + \
+                "``requires_grad=False``."
+        assert all(isinstance(arg, torch.Tensor) for arg in args), "In the prototype API, sample_args " + \
+            "for each callable must be a tuple of Tensors. Other types and keyword args are not allowed."
+
+
+    # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
+    # passes to forward (i.e., its sample_args) AND the module's parameter attributes.
+    per_callable_len_user_args = [len(args) for args in sample_args]
+    per_callable_module_params = [tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
+                                  for c in callables]
+    per_callable_static_input_surfaces = [sample_args[i] + per_callable_module_params[i]
+                                          for i in range(len(callables))]
+
+    fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]
+    bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]
+
+    mempool = graph_pool_handle()
+
+    # Warmup
+    # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
+    # from ending up in any captures.
+    torch.cuda.synchronize()
+    with torch.cuda.stream(torch.cuda.Stream()):
+        for func, args, static_input_surface in zip(callables,
+                                                    sample_args,
+                                                    per_callable_static_input_surfaces):
+            for _ in range(3):
+                outputs = func(*args)
+                outputs = (outputs,) if isinstance(outputs, torch.Tensor) else outputs
+                grad_inputs = torch.autograd.grad(outputs=outputs,
+                                                  inputs=tuple(i for i in static_input_surface if i.requires_grad),
+                                                  grad_outputs=tuple(torch.empty_like(o) for o in outputs),
+                                                  only_inputs=True,
+                                                  allow_unused=False)
+            del outputs, grad_inputs
+    torch.cuda.synchronize()
+
+    # All captures here share a mempool. To avoid replays corrupting each other's memory,
+    # the safest approach is to capture all passes in the same order they'll run:
+    # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.
+
+    # Capture forward graphs
+    per_callable_static_outputs = []
+    per_callable_output_was_tensor = []
+    for func, args, fwd_graph in zip(callables,
+                                     sample_args,
+                                     fwd_graphs):
+        with torch.cuda.graph(fwd_graph, pool=mempool):
+            outputs = func(*args)
+
+        # Assumes model output is a tensor or tuple of tensors
+        if isinstance(outputs, torch.Tensor):
+            per_callable_output_was_tensor.append(True)
+            outputs = (outputs,)
+        else:
+            per_callable_output_was_tensor.append(False)
+
+        per_callable_static_outputs.append(outputs)
+
+    # Capture backward graphs in reverse order
+    per_callable_static_grad_outputs = []
+    per_callable_static_grad_inputs = []
+    for static_input_surface, static_outputs, bwd_graph, module_params in \
+            zip(reversed(per_callable_static_input_surfaces),
+                reversed(per_callable_static_outputs),
+                reversed(bwd_graphs),
+                reversed(per_callable_module_params)):
+
+        # For now, assumes all static_outputs require grad
+        assert all(o.requires_grad for o in static_outputs), "Outputs of graphed callables must require grad."
+        static_grad_outputs = tuple(torch.empty_like(o) for o in static_outputs)
+
+        with torch.cuda.graph(bwd_graph, pool=mempool):
+            grad_inputs = torch.autograd.grad(outputs=static_outputs,
+                                              inputs=tuple(i for i in static_input_surface if i.requires_grad),
+                                              grad_outputs=static_grad_outputs,
+                                              only_inputs=True,
+                                              allow_unused=False)
+
+        # Constructs a tuple suitable for returning from Graphed.backward:
+        # Pads out the actually-needed grads with Nones in gradient slots for inputs that don't require grad.
+        # I couldn't think of a slick one-liner for this pattern.
+        static_grad_inputs = []
+        grad_idx = 0
+        for arg in static_input_surface:
+            if arg.requires_grad:
+                static_grad_inputs.append(grad_inputs[grad_idx])
+                grad_idx += 1
+            else:
+                static_grad_inputs.append(None)  # type: ignore[arg-type]
+        static_grad_inputs = tuple(static_grad_inputs)  # type: ignore[assignment]
+
+        per_callable_static_grad_outputs.append(static_grad_outputs)
+        per_callable_static_grad_inputs.append(static_grad_inputs)
+
+    # Reverses the most recent two lists
+    per_callable_static_grad_outputs = list(reversed(per_callable_static_grad_outputs))
+    per_callable_static_grad_inputs = list(reversed(per_callable_static_grad_inputs))
+    # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.
+
+    def make_graphed_autograd_function(fwd_graph,
+                                       bwd_graph,
+                                       module_params,
+                                       len_user_args,
+                                       output_was_tensor,
+                                       static_input_surface,
+                                       static_outputs,
+                                       static_grad_outputs,
+                                       static_grad_inputs):
+        class Graphed(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, *inputs):
+                # At this stage, only the user args may (potentially) be new tensors.
+                for i in range(len_user_args):
+                    if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
+                        static_input_surface[i].copy_(inputs[i])
+                fwd_graph.replay()
+                assert isinstance(static_outputs, tuple)
+                return tuple(o.detach() for o in static_outputs)
+
+            @staticmethod
+            @torch.autograd.function.once_differentiable
+            def backward(ctx, *grads):
+                for g, grad in zip(static_grad_outputs, grads):
+                    if g is None:
+                        assert grad is None
+                    else:
+                        # don't copy if autograd gods have been kind and the
+                        # incoming grad is already in the right place
+                        if g.data_ptr() != grad.data_ptr():
+                            g.copy_(grad)
+                bwd_graph.replay()
+
+                # Input args that didn't require grad expect a None gradient.
+                assert isinstance(static_grad_inputs, tuple)
+                return tuple(b.detach() if b is not None else b for b in static_grad_inputs)
+
+        def functionalized(*user_args):
+            # Runs the autograd function with inputs == all inputs to the graph that might require grad
+            # (explicit user args + module parameters)
+            # Assumes module params didn't change since capture.
+            out = Graphed.apply(*(user_args + module_params))
+            return out[0] if output_was_tensor else out
+
+        return functionalized
+
+    # Put together the final graphed callables
+    ret = []
+    for i, func in enumerate(callables):
+        graphed = make_graphed_autograd_function(fwd_graphs[i],
+                                                 bwd_graphs[i],
+                                                 per_callable_module_params[i],
+                                                 per_callable_len_user_args[i],
+                                                 per_callable_output_was_tensor[i],
+                                                 per_callable_static_input_surfaces[i],
+                                                 per_callable_static_outputs[i],
+                                                 per_callable_static_grad_outputs[i],
+                                                 per_callable_static_grad_inputs[i])
+
+        if isinstance(func, torch.nn.Module):
+            def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
+                def new_fwd(*user_args):
+                    # If the module's training-or-eval state matches what we graphed,
+                    # run the graph, otherwise run the original forward method
+                    if func.training == graph_training_state:
+                        return graphed(*user_args)
+                    else:
+                        return orig_fwd(*user_args)
+                return new_fwd
+            func.forward = make_graphed_forward(func, func.training, graphed, func.forward)  # type: ignore[assignment]
+            ret.append(func)
+        else:
+            ret.append(graphed)
+
+    if just_one_callable:
+        return ret[0]
+
+    return tuple(ret)
index 0f98372..2b4cc47 100644 (file)
@@ -8,8 +8,6 @@ if not hasattr(torch._C, '_CudaStreamBase'):
     # Define dummy base classes
     torch._C.__dict__['_CudaStreamBase'] = _dummy_type('_CudaStreamBase')
     torch._C.__dict__['_CudaEventBase'] = _dummy_type('_CudaEventBase')
-    torch._C.__dict__['_CudaGraphBase'] = _dummy_type('_CudaGraphBase')
-    torch._C.__dict__['_graph_pool_handle'] = _dummy_type('_graph_pool_handle')
 
 class Stream(torch._C._CudaStreamBase):
     r"""Wrapper around a CUDA stream.
@@ -226,6 +224,3 @@ class Event(torch._C._CudaEventBase):
             return '<torch.cuda.Event {0:#x}>'.format(self._as_parameter_.value)
         else:
             return '<torch.cuda.Event uninitialized>'
-
-_Graph = torch._C._CudaGraphBase
-_graph_pool_handle = torch._C._graph_pool_handle