every iteration. You don't have to encode all possible paths before you
launch the training - what you run is what you differentiate.
+.. _saved-tensors-doc:
+
+Saved tensors
+^^^^^^^^^^^^^
+
+Some operations need intermediate results to be saved during the forward pass
+in order to execute the backward pass. For example, the function
+:math:`x\mapsto x^2` saves the input :math:`x` to compute the gradient.
+
+When defining a custom Python :class:`~torch.autograd.Function`, you can use
+:func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` to save
+tensors during the forward pass and
+:attr:`~torch.autograd.function.Function.saved_tensors` to retrieve them
+during the backward pass. See :doc:`/notes/extending` for more information.
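+
+For instance, a minimal version of the squaring function from above could be
+written as follows (a sketch using the standard custom-Function pattern):
+
+.. code::
+
+    import torch
+
+    class Square(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, x):
+            # Save the input: it is needed to compute the gradient.
+            ctx.save_for_backward(x)
+            return x ** 2
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            # Retrieve the tensor saved during the forward pass.
+            x, = ctx.saved_tensors
+            return grad_output * 2 * x
+
+    x = torch.randn(5, requires_grad=True)
+    y = Square.apply(x)
+    y.sum().backward()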
+
+For operations that PyTorch defines (e.g. :func:`torch.pow`), tensors are
+automatically saved as needed. You can explore (for educational or debugging
+purposes) which tensors are saved by a certain ``grad_fn`` by looking for its
+attributes starting with the prefix ``_saved``.
+
+.. code::
+
+ x = torch.randn(5, requires_grad=True)
+ y = x.pow(2)
+ print(x.equal(y.grad_fn._saved_self)) # True
+ print(x is y.grad_fn._saved_self) # True
+
+
+In the previous code, ``y.grad_fn._saved_self`` refers to the same Tensor
+object as ``x``. But that may not always be the case. For instance:
+
+.. code::
+
+ x = torch.randn(5, requires_grad=True)
+ y = x.exp()
+ print(y.equal(y.grad_fn._saved_result)) # True
+ print(y is y.grad_fn._saved_result) # False
+
+
+Under the hood, to prevent reference cycles, PyTorch has *packed* the tensor
+upon saving and *unpacked* it into a different tensor for reading. Here, the
+tensor you get from accessing ``y.grad_fn._saved_result`` is a different tensor
+object than ``y`` (but they still share the same storage).
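+
+You can verify that they indeed share the same storage by comparing data
+pointers (an illustrative check, not something to rely on in real code):
+
+.. code::
+
+    print(y.data_ptr() == y.grad_fn._saved_result.data_ptr())  # True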
+
+Whether a tensor will be packed into a different tensor object depends on
+whether it is an output of its own ``grad_fn``, an implementation detail that
+is subject to change and that users should not rely on.
+
+You can control how PyTorch does packing / unpacking with :ref:`saved-tensors-hooks-doc`.
+
+
.. _locally-disable-grad-doc:
Locally disabling gradient computation
.. math::
\frac{\partial L}{\partial z^*} = 2 * Re(grad\_out^* * \frac{\partial s}{\partial z^{*}})
+
+.. _saved-tensors-hooks-doc:
+
+Hooks for saved tensors
+-----------------------
+
+You can control :ref:`how saved tensors are packed / unpacked
+<saved-tensors-doc>` by defining a pair of ``pack_hook`` / ``unpack_hook``
+hooks. The ``pack_hook`` function should take a tensor as its single argument
+but can return any Python object (e.g. another tensor, a tuple, or even a
+string containing a filename). The ``unpack_hook`` function takes as its single
+argument the output of ``pack_hook`` and should return a tensor to be used in
+the backward pass. The tensor returned by ``unpack_hook`` only needs to have
+the same content as the tensor passed as input to ``pack_hook``. In particular,
+any autograd-related metadata can be ignored, as it will be overwritten during
+unpacking.
+
+An example of such a pair is (here ``tmp_dir`` is assumed to be a scratch
+directory for the temporary files):
+
+.. code::
+
+    import os
+    import tempfile
+    import uuid
+
+    import torch
+
+    # Assumed scratch location for the temporary files.
+    tmp_dir = tempfile.gettempdir()
+
+    class SelfDeletingTempFile:
+        def __init__(self):
+            self.name = os.path.join(tmp_dir, str(uuid.uuid4()))
+
+        def __del__(self):
+            os.remove(self.name)
+
+    def pack_hook(tensor):
+        # Serialize the tensor to disk and return a handle to the file.
+        temp_file = SelfDeletingTempFile()
+        torch.save(tensor, temp_file.name)
+        return temp_file
+
+    def unpack_hook(temp_file):
+        # Load the tensor back from disk when it is needed again.
+        return torch.load(temp_file.name)
+
+Notice that the ``unpack_hook`` should not delete the temporary file, because
+it might be called multiple times: the temporary file should stay alive for as
+long as the returned ``SelfDeletingTempFile`` object is alive. In the above
+example, we prevent leaking the temporary file by deleting it when it is no
+longer needed (on deletion of the ``SelfDeletingTempFile`` object).
+
+.. note::
+
+    We guarantee that ``pack_hook`` will only be called once, but ``unpack_hook``
+    can be called as many times as the backward pass requires, and we expect it
+    to return the same data each time.
+
+.. warning::
+
+    Performing in-place operations on the input of either hook is forbidden,
+    as they may lead to unexpected side effects. PyTorch will throw an error
+    if the input to a pack hook is modified in-place, but does not catch the
+    case where the input to an unpack hook is modified in-place.
+
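+For illustration, a pack hook like the following deliberately broken sketch
+modifies its input in-place, which PyTorch will flag with an error:
+
+.. code::
+
+    def bad_pack_hook(x):
+        x.mul_(2)  # forbidden: in-place modification of the hook's input
+        return x
+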
+
+Registering hooks for a saved tensor
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+You can register a pair of hooks on a saved tensor by calling the
+:meth:`~torch.autograd.SavedTensor.register_hooks` method on a
+:class:`SavedTensor` object. Those objects are exposed as attributes of a
+``grad_fn`` and start with the ``_raw_saved_`` prefix.
+
+.. code::
+
+ x = torch.randn(5, requires_grad=True)
+ y = x.pow(2)
+ y.grad_fn._raw_saved_self.register_hooks(pack_hook, unpack_hook)
+
+The ``pack_hook`` method is called as soon as the pair is registered.
+The ``unpack_hook`` method is called each time the saved tensor needs to be
+accessed, either by means of ``y.grad_fn._saved_self`` or during the backward
+pass.
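+
+For example, with a pair of hooks that simply log when they run (an
+illustrative sketch), you can observe both events:
+
+.. code::
+
+    def pack_hook(x):
+        print("packing")
+        return x
+
+    def unpack_hook(x):
+        print("unpacking")
+        return x
+
+    x = torch.randn(5, requires_grad=True)
+    y = x.pow(2)
+    y.grad_fn._raw_saved_self.register_hooks(pack_hook, unpack_hook)  # prints "packing"
+    y.sum().backward()  # prints "unpacking"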
+
+.. warning::
+
+    If you maintain a reference to a :class:`SavedTensor` after the saved
+    tensors have been released (i.e. after backward has been called), calling
+    its :meth:`~torch.autograd.SavedTensor.register_hooks` is forbidden.
+    PyTorch will throw an error in most cases, but may fail to do so in some,
+    in which case undefined behavior may arise.
+
+Registering default hooks for saved tensors
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Alternatively, you can use the context-manager
+:class:`~torch.autograd.graph.saved_tensors_hooks` to register a pair of
+hooks which will be applied to *all* saved tensors that are created in
+that context.
+
+Example:
+
+.. code::
+
+    import torch
+    from torch import nn
+
+    # Only save to disk tensors that have at least 1000 elements
+    SAVE_ON_DISK_THRESHOLD = 1000
+
+    def pack_hook(x):
+        if x.numel() < SAVE_ON_DISK_THRESHOLD:
+            return x
+        # Reuses the SelfDeletingTempFile helper defined above.
+        temp_file = SelfDeletingTempFile()
+        torch.save(x, temp_file.name)
+        return temp_file
+
+    def unpack_hook(tensor_or_sctf):
+        if isinstance(tensor_or_sctf, torch.Tensor):
+            return tensor_or_sctf
+        return torch.load(tensor_or_sctf.name)
+
+    class Model(nn.Module):
+        def forward(self, x):
+            with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
+                # ... compute output
+                output = x
+            return output
+
+    model = Model()
+    net = nn.DataParallel(model)
+
+
+
+The hooks defined with this context manager are thread-local.
+Hence, the following code will not produce the desired effect because the
+hooks do not go through ``DataParallel``.
+
+.. code::
+
+    # Example of what NOT to do
+
+    net = nn.DataParallel(model)
+    with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
+        output = net(input)
+
+
+Note that using those hooks disables all the optimizations in place to reduce
+Tensor object creation. For example:
+
+.. code::
+
+ with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
+ x = torch.randn(5, requires_grad=True)
+ y = x * x
+
+Without the hooks, ``x``, ``y.grad_fn._saved_self`` and
+``y.grad_fn._saved_other`` all refer to the same tensor object.
+With the hooks, PyTorch will pack and unpack ``x`` into two new tensor objects
+that share the same storage as the original ``x`` (no copy is performed).
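+
+A quick way to observe this (an illustrative check using the debugging
+attributes described above):
+
+.. code::
+
+    print(x is y.grad_fn._saved_self)  # False: a new tensor object is returned
+    print(x.data_ptr() == y.grad_fn._saved_self.data_ptr())  # True: same storage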