Add docs describing saved tensor hooks (#62362)
authorVictor Quach <quach@fb.com>
Fri, 20 Aug 2021 18:07:22 +0000 (11:07 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 20 Aug 2021 18:10:51 +0000 (11:10 -0700)
Summary:
Add section to the Autograd mechanics docs to describe the recently
exposed saved tensors (https://github.com/pytorch/pytorch/issues/52451), how to register packing / unpacking
hooks (https://github.com/pytorch/pytorch/issues/60975) and how to use default hooks (https://github.com/pytorch/pytorch/issues/61834)

Sister PR: https://github.com/pytorch/pytorch/issues/62361 (will add a link from autograd.rst to notes/autograd in whatever PR does not land first)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62362

Reviewed By: soulitzer

Differential Revision: D30453177

Pulled By: Varal7

fbshipit-source-id: f5759977b069ff0ef36a47b08856d297691a6caa

docs/source/autograd.rst
docs/source/notes/autograd.rst

index 6423d5d..8aace1e 100644 (file)
@@ -252,6 +252,7 @@ You can define how these saved tensors should be packed / unpacked using hooks.
 A common application is to trade compute for memory by saving those intermediary results
 to disk or to CPU instead of leaving them on the GPU. This is especially useful if you
 notice your model fits on GPU during evaluation, but not training.
+Also see :ref:`saved-tensors-hooks-doc`.
 
 .. autoclass:: torch.autograd.graph.saved_tensors_hooks
 
index 0c1eed3..2a59d97 100644 (file)
@@ -36,6 +36,57 @@ flow statements, that can change the overall shape and size of the graph at
 every iteration. You don't have to encode all possible paths before you
 launch the training - what you run is what you differentiate.
 
+.. _saved-tensors-doc:
+
+Saved tensors
+^^^^^^^^^^^^^
+
+Some operations need intermediary results to be saved during the forward pass
+in order to execute the backward pass. For example, the function
+:math:`x\mapsto x^2` saves the input :math:`x` to compute the gradient.
+
+When defining a custom Python :class:`~torch.autograd.Function`, you can use
+:func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` to save
+tensors during the forward pass and
+:attr:`~torch.autograd.function.Function.saved_tensors` to retrieve them
+during the backward pass. See :doc:`/notes/extending` for more information.
+
+For operations that PyTorch defines (e.g. :func:`torch.pow`), tensors are
+automatically saved as needed. You can explore (for educational or debugging
+purposes) which tensors are saved by a certain ``grad_fn`` by looking for its
+attributes starting with the prefix ``_saved``.
+
+.. code::
+
+    x = torch.randn(5, requires_grad=True)
+    y = x.pow(2)
+    print(x.equal(y.grad_fn._saved_self))  # True
+    print(x is y.grad_fn._saved_self)  # True
+
+
+In the previous code, ``y.grad_fn._saved_self`` refers to the same Tensor object as `x`.
+But that may not always be the case. For instance:
+
+.. code::
+
+    x = torch.randn(5, requires_grad=True)
+    y = x.exp()
+    print(y.equal(y.grad_fn._saved_result))  # True
+    print(y is y.grad_fn._saved_result)  # False
+
+
+Under the hood, to prevent reference cycles, PyTorch has *packed* the tensor
+upon saving and *unpacked* it into a different tensor for reading. Here, the
+tensor you get from accessing ``y.grad_fn._saved_result`` is a different tensor
+object than ``x`` (but they still share the same storage).
+
+Whether a tensor will be packed into a different tensor object depends on
+whether it is an output of its own `grad_fn`, which is an implementation detail
+subject to change and that users should not rely on.
+
+You can control how PyTorch does packing / unpacking with :ref:`saved-tensors-hooks-doc`.
+
+
 .. _locally-disable-grad-doc:
 
 Locally disabling gradient computation
@@ -598,3 +649,151 @@ chain rule:
 
         .. math::
             \frac{\partial L}{\partial z^*} = 2 * Re(grad\_out^* * \frac{\partial s}{\partial z^{*}})
+
+.. _saved-tensors-hooks-doc:
+
+Hooks for saved tensors
+-----------------------
+
+You can control :ref:`how saved tensors are packed / unpacked
+<saved-tensors-doc>` by defining a pair of ``pack_hook`` / ``unpack_hook``
+hooks.  The ``pack_hook`` function should take a tensor as its single argument
+but can return any python object (e.g. another tensor, a tuple, or even a
+string containing a filename). The ``unpack_hook`` function takes as its single
+argument the output of ``pack_hook`` and should return a tensor to be used in
+the backward pass. The tensor returned by ``unpack_hook`` only needs to have
+the same content as the tensor passed as input to ``pack_hook``. In particular,
+any autograd-related metadata can be ignored as they will be overwritten during
+unpacking.
+
+An example of such pair is:
+
+.. code::
+
+    class SelfDeletingTempFile():
+        def __init__(self):
+            self.name = os.path.join(tmp_dir, str(uuid.uuid4()))
+
+        def __del__(self):
+            os.remove(self.name)
+
+    def pack_hook(tensor):
+        temp_file = SelfDeletingTempFile()
+        torch.save(tensor, temp_file.name)
+        return temp_file
+
+    def unpack_hook(temp_file):
+        return torch.load(temp_file.name)
+
+Notice that the ``unpack_hook`` should not delete the temporary file because it
+might be called multiple times: the temporary file should be alive for as long
+as the returned `SelfDeletingTempFile` object is alive.  In the above example,
+we prevent leaking the temporary file by closing it when it is no longer needed
+(on deletion of the `SelfDeletingTempFile` object).
+
+.. note::
+
+    We guarantee that ``pack_hook`` will only be called once but ``unpack_hook`` can
+    be called as many times as the backward pass requires it and we expect it to
+    return the same data each time.
+
+.. warning::
+
+    Performing inplace operations on the input of any of the functions is forbidden
+    as they may lead to unexpected side-effects. PyTorch will throw an error if the
+    input to a pack hook is modified inplace but does not catch the case where the
+    input to an unpack hook is modified inplace.
+
+
+Registering hooks for a saved tensor
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+You can register a pair of hooks on a saved tensor by calling the
+:meth:`~torch.autograd.SavedTensor.register_hooks` method on a
+:class:`SavedTensor` object. Those objects are exposed as attributes of a
+``grad_fn`` and start with the ``_raw_saved_`` prefix.
+
+.. code::
+
+    x = torch.randn(5, requires_grad=True)
+    y = x.pow(2)
+    y.grad_fn._raw_saved_self.register_hooks(pack_hook, unpack_hook)
+
+The ``pack_hook`` method is called as soon as the pair is registered.
+The ``unpack_hook`` method is called each time the saved tensor needs to be
+accessed, either by means of ``y.grad_fn._saved_self`` or during the backward
+pass.
+
+.. warning::
+
+    If you maintain a reference to a :class:`SavedTensor` after the saved
+    tensors have been released (i.e. after backward has been called), calling
+    its :meth:`~torch.autograd.SavedTensor.register_hooks` is forbidden.
+    PyTorch will throw an error most of the time but it may fail
+    to do so in some cases and undefined behavior may arise.
+
+Registering default hooks for saved tensors
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Alternatively, you can use the context-manager
+:class:`~torch.autograd.graph.saved_tensors_hooks` to register a pair of
+hooks which will be applied to *all* saved tensors that are created in
+that context.
+
+Example:
+
+.. code::
+
+    # Only save on disk tensors that have size >= 1000
+    SAVE_ON_DISK_THRESHOLD = 1000
+
+    def pack_hook(x):
+        if x.numel() < SAVE_ON_DISK_THRESHOLD:
+            return x
+        temp_file = SelfDeletingTempFile()
+        torch.save(tensor, temp_file.name)
+        return temp_file
+
+    def unpack_hook(tensor_or_sctf):
+        if isinstance(tensor_or_sctf, torch.Tensor):
+            return tensor_or_sctf
+        return torch.load(tensor_or_sctf.name)
+
+    class Model(nn.Module):
+        def forward(self, x):
+            with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
+              # ... compute output
+              output = x
+            return output
+
+    model = Model()
+    net = nn.DataParallel(model)
+
+
+
+The hooks defined with this context manager are thread-local.
+Hence, the following code will not produce the desired effects because the hooks do not go
+through `DataParallel`.
+
+.. code::
+
+      # Example what NOT to do
+
+      net = nn.DataParallel(model)
+      with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
+          output = net(input)
+
+
+Note that using those hooks disables all the optimization in place to reduce
+Tensor object creation. For example:
+
+.. code::
+
+    with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
+        x = torch.randn(5, requires_grad=True)
+        y = x * x
+
+Without the hooks, ``x``, ``y.grad_fn._saved_self`` and
+``y.grad_fn._saved_other`` all refer to the same tensor object.
+With the hooks, PyTorch will pack and unpack `x` into two new tensor objects
+that share the same storage with the original `x` (no copy performed).