Modules note v2 (#63963)
authorJoel Schlosser <jbschlosser@fb.com>
Fri, 27 Aug 2021 18:28:03 +0000 (11:28 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 27 Aug 2021 18:30:18 +0000 (11:30 -0700)
Summary:
This PR expands the [note on modules](https://pytorch.org/docs/stable/notes/modules.html) with additional info for 1.10.

It adds the following:
* Examples of using hooks
* Examples of using apply()
* Examples for ParameterList / ParameterDict
* register_parameter() / register_buffer() usage
* Discussion of train() / eval() modes
* Distributed training overview / links
* TorchScript overview / links
* Quantization overview / links
* FX overview / links
* Parametrization overview / link to tutorial

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63963

Reviewed By: albanD

Differential Revision: D30606604

Pulled By: jbschlosser

fbshipit-source-id: c1030b19162bcb5fe7364bcdc981a2eb6d6e89b4

docs/source/notes/modules.rst

index 4eba022..c1d978d 100644 (file)
@@ -117,7 +117,7 @@ multiple modules:
 
 Note that :class:`~torch.nn.Sequential` automatically feeds the output of the first ``MyLinear`` module as input
 into the :class:`~torch.nn.ReLU`, and the output of that as input into the second ``MyLinear`` module. As
-shown, it is limited to in-order chaining of modules.
+shown, it is limited to in-order chaining of modules with a single input and output.
 
 In general, it is recommended to define a custom module for anything beyond the simplest use cases, as this gives
 full flexibility on how submodules are used for a module's computation.
@@ -258,16 +258,32 @@ It's also easy to move all parameters to a different device or change their prec
    dynamic_net(torch.randn(5, device='cuda', dtype=torch.float64))
    : tensor([6.5166], device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)
 
-These examples show how elaborate neural networks can be formed through module composition. To allow for
-quick and easy construction of neural networks with minimal boilerplate, PyTorch provides a large library of
-performant modules within the :mod:`torch.nn` namespace that perform computation commonly found within neural
-networks, including pooling, convolutions, loss functions, etc.
+More generally, an arbitrary function can be applied to a module and its submodules recursively by
+using the :func:`~torch.nn.Module.apply` function. For example, to apply custom initialization to parameters
+of a module and its submodules:
+
+.. code-block:: python
+
+   # Define a function to initialize Linear weights.
+   # Note that no_grad() is used here to avoid tracking this computation in the autograd graph.
+   @torch.no_grad()
+   def init_weights(m):
+     if isinstance(m, nn.Linear):
+       nn.init.xavier_normal_(m.weight)
+       m.bias.fill_(0.0)
+
+   # Apply the function recursively on the module and its submodules.
+   dynamic_net.apply(init_weights)
+
+These examples show how elaborate neural networks can be formed through module composition and conveniently
+manipulated. To allow for quick and easy construction of neural networks with minimal boilerplate, PyTorch
+provides a large library of performant modules within the :mod:`torch.nn` namespace that perform common neural
+network operations like pooling, convolutions, loss functions, etc.
 
 In the next section, we give a full example of training a neural network.
 
 For more information, check out:
 
-* Recursively :func:`~torch.nn.Module.apply` a function to a module and its submodules
 * Library of PyTorch-provided modules: `torch.nn <https://pytorch.org/docs/stable/nn.html>`_
 * Defining neural net modules: https://pytorch.org/tutorials/beginner/examples_nn/two_layer_net_module.html
 
@@ -295,6 +311,12 @@ Optimizers from :mod:`torch.optim`:
      loss.backward()
      optimizer.step()
 
+   # After training, switch the module to eval mode to do inference, compute performance metrics, etc.
+   # (see discussion below for a description of training and evaluation modes)
+   ...
+   net.eval()
+   ...
+
 In this simplified example, the network learns to simply output zero, as any non-zero output is "penalized" according
 to its absolute value by employing :func:`torch.abs` as a loss function. While this is not a very interesting task, the
 key parts of training are present:
@@ -321,6 +343,38 @@ value of ``l1``\ 's ``weight`` parameter shows that its values are now much clos
            [ 0.0030],
            [-0.0008]], requires_grad=True)
 
+Note that the above process is done entirely while the network module is in "training mode". Modules default to
+training mode and can be switched between training and evaluation modes using :func:`~torch.nn.Module.train` and
+:func:`~torch.nn.Module.eval`. They can behave differently depending on which mode they are in. For example, the
+:class:`~torch.nn.BatchNorm` module maintains a running mean and variance during training that are not updated
+when the module is in evaluation mode. In general, modules should be in training mode during training
+and only switched to evaluation mode for inference or evaluation. Below is an example of a custom module
+that behaves differently between the two modes:
+
+.. code-block:: python
+
+   class ModalModule(nn.Module):
+     def __init__(self):
+       super().__init__()
+
+     def forward(self, x):
+       if self.training:
+         # Add a constant only in training mode.
+         return x + 1.
+       else:
+         return x
+
+
+   m = ModalModule()
+   x = torch.randn(4)
+
+   print('training mode output: {}'.format(m(x)))
+   : tensor([1.6614, 1.2669, 1.0617, 1.6213, 0.5481])
+
+   m.eval()
+   print('evaluation mode output: {}'.format(m(x)))
+   : tensor([ 0.6614,  0.2669,  0.0617,  0.6213, -0.4519])
+
 Training neural networks can often be tricky. For more information, check out:
 
 * Using Optimizers: https://pytorch.org/tutorials/beginner/examples_nn/two_layer_net_optim.html.
@@ -409,12 +463,127 @@ Both persistent and non-persistent buffers are affected by model-wide device / d
 Buffers of a module can be iterated over using :func:`~torch.nn.Module.buffers` or
 :func:`~torch.nn.Module.named_buffers`.
 
+.. code-block:: python
+
+   for buffer in m.named_buffers():
+     print(buffer)
+
+The following class demonstrates the various ways of registering parameters and buffers within a module:
+
+.. code-block:: python
+
+   class StatefulModule(nn.Module):
+     def __init__(self):
+       super().__init__()
+       # Setting a nn.Parameter as an attribute of the module automatically registers the tensor
+       # as a parameter of the module.
+       self.param1 = nn.Parameter(torch.randn(2))
+
+       # Alternative string-based way to register a parameter.
+       self.register_parameter('param2', nn.Parameter(torch.randn(3)))
+
+       # Reserves the "param3" attribute as a parameter, preventing it from being set to anything
+       # except a parameter. "None" entries like this will not be present in the module's state_dict.
+       self.register_parameter('param3', None)
+
+       # Registers a list of parameters.
+       self.param_list = nn.ParameterList([nn.Parameter(torch.randn(2)) for i in range(3)])
+
+       # Registers a dictionary of parameters.
+       self.param_dict = nn.ParameterDict({
+         'foo': nn.Parameter(torch.randn(3)),
+         'bar': nn.Parameter(torch.randn(4))
+       })
+
+       # Registers a persistent buffer (one that appears in the module's state_dict).
+       self.register_buffer('buffer1', torch.randn(4), persistent=True)
+
+       # Registers a non-persistent buffer (one that does not appear in the module's state_dict).
+       self.register_buffer('buffer2', torch.randn(5), persistent=False)
+
+       # Reserves the "buffer3" attribute as a buffer, preventing it from being set to anything
+       # except a buffer. "None" entries like this will not be present in the module's state_dict.
+       self.register_buffer('buffer3', None)
+
+       # Adding a submodule registers its parameters as parameters of the module.
+       self.linear = nn.Linear(2, 3)
+
+   m = StatefulModule()
+
+   # Save and load state_dict.
+   torch.save(m.state_dict(), 'state.pt')
+   m_loaded = StatefulModule()
+   m_loaded.load_state_dict(torch.load('state.pt'))
+
+   # Note that non-persistent buffer "buffer2" and reserved attributes "param3" and "buffer3" do
+   # not appear in the state_dict.
+   print(m_loaded.state_dict())
+   : OrderedDict([('param1', tensor([-0.0322,  0.9066])),
+                  ('param2', tensor([-0.4472,  0.1409,  0.4852])),
+                  ('buffer1', tensor([ 0.6949, -0.1944,  1.2911, -2.1044])),
+                  ('param_list.0', tensor([ 0.4202, -0.1953])),
+                  ('param_list.1', tensor([ 1.5299, -0.8747])),
+                  ('param_list.2', tensor([-1.6289,  1.4898])),
+                  ('param_dict.bar', tensor([-0.6434,  1.5187,  0.0346, -0.4077])),
+                  ('param_dict.foo', tensor([-0.0845, -1.4324,  0.7022])),
+                  ('linear.weight', tensor([[-0.3915, -0.6176],
+                                            [ 0.6062, -0.5992],
+                                            [ 0.4452, -0.2843]])),
+                  ('linear.bias', tensor([-0.3710, -0.0795, -0.3947]))])
+
 For more information, check out:
 
 * Saving and loading: https://pytorch.org/tutorials/beginner/saving_loading_models.html
 * Serialization semantics: https://pytorch.org/docs/master/notes/serialization.html
 * What is a state dict? https://pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html
 
+Module Initialization
+---------------------
+
+By default, parameters and floating-point buffers for modules provided by :mod:`torch.nn` are initialized during
+module instantiation as 32-bit floating point values on the CPU using an initialization scheme determined to
+perform well historically for the module type. For certain use cases, it may be desired to initialize with a different
+dtype, device (e.g. GPU), or initialization technique.
+
+Examples:
+
+.. code-block:: python
+
+   # Initialize module directly onto GPU.
+   m = nn.Linear(5, 3, device='cuda')
+
+   # Initialize module with 16-bit floating point parameters.
+   m = nn.Linear(5, 3, dtype=torch.half)
+
+   # Skip default parameter initialization and perform custom (e.g. orthogonal) initialization.
+   m = torch.nn.utils.skip_init(nn.Linear, 5, 3)
+   nn.init.orthogonal_(m.weight)
+
+Note that the device and dtype options demonstrated above also apply to any floating-point buffers registered
+for the module:
+
+.. code-block:: python
+
+   m = nn.BatchNorm2d(3, dtype=torch.half)
+   print(m.running_mean)
+   : tensor([0., 0., 0.], dtype=torch.float16)
+
+While module writers can use any device or dtype to initialize parameters in their custom modules, good practice is
+to use ``dtype=torch.float`` and ``device='cpu'`` by default as well. Optionally, you can provide full flexibility
+in these areas for your custom module by conforming to the convention demonstrated above that all
+:mod:`torch.nn` modules follow:
+
+* Provide a ``device`` constructor kwarg that applies to any parameters / buffers registered by the module.
+* Provide a ``dtype`` constructor kwarg that applies to any parameters / floating-point buffers registered by
+  the module.
+* Only use initialization functions (i.e. functions from :mod:`torch.nn.init`) on parameters and buffers within the
+  module's constructor. Note that this is only required to use :func:`~torch.nn.utils.skip_init`; see
+  `this page <https://pytorch.org/tutorials/prototype/skip_param_init.html#updating-modules-to-support-skipping-initialization>`_ for an explanation.
+
+For more information, check out:
+
+* Skipping module parameter initialization: https://pytorch.org/tutorials/prototype/skip_param_init.html
+
 Module Hooks
 ------------
 
@@ -443,16 +612,137 @@ All hooks allow the user to return an updated value that will be used throughout
 Thus, these hooks can be used to either execute arbitrary code along the regular module forward/backward or
 modify some inputs/outputs without having to change the module's ``forward()`` function.
 
+Below is an example demonstrating usage of forward and backward hooks:
+
+.. code-block:: python
+
+   torch.manual_seed(1)
+
+   def forward_pre_hook(m, inputs):
+     # Allows for examination and modification of the input before the forward pass.
+     # Note that inputs are always wrapped in a tuple.
+     input = inputs[0]
+     return input + 1.
+
+   def forward_hook(m, inputs, output):
+     # Allows for examination of inputs / outputs and modification of the outputs
+     # after the forward pass. Note that inputs are always wrapped in a tuple while outputs
+     # are passed as-is.
+
+     # Residual computation a la ResNet.
+     return output + inputs[0]
+
+   def backward_hook(m, grad_inputs, grad_outputs):
+     # Allows for examination of grad_inputs / grad_outputs and modification of
+     # grad_inputs used in the rest of the backwards pass. Note that grad_inputs and
+     # grad_outputs are always wrapped in tuples.
+     new_grad_inputs = [torch.ones_like(gi) * 42. for gi in grad_inputs]
+     return new_grad_inputs
+
+   # Create sample module & input.
+   m = nn.Linear(3, 3)
+   x = torch.randn(2, 3, requires_grad=True)
+
+   # ==== Demonstrate forward hooks. ====
+   # Run input through module before and after adding hooks.
+   print('output with no forward hooks: {}'.format(m(x)))
+   : output with no forward hooks: tensor([[-0.5059, -0.8158,  0.2390],
+                                           [-0.0043,  0.4724, -0.1714]], grad_fn=<AddmmBackward>)
+
+   # Note that the modified input results in a different output.
+   forward_pre_hook_handle = m.register_forward_pre_hook(forward_pre_hook)
+   print('output with forward pre hook: {}'.format(m(x)))
+   : output with forward pre hook: tensor([[-0.5752, -0.7421,  0.4942],
+                                           [-0.0736,  0.5461,  0.0838]], grad_fn=<AddmmBackward>)
+
+   # Note the modified output.
+   forward_hook_handle = m.register_forward_hook(forward_hook)
+   print('output with both forward hooks: {}'.format(m(x)))
+   : output with both forward hooks: tensor([[-1.0980,  0.6396,  0.4666],
+                                             [ 0.3634,  0.6538,  1.0256]], grad_fn=<AddBackward0>)
+
+   # Remove hooks; note that the output here matches the output before adding hooks.
+   forward_pre_hook_handle.remove()
+   forward_hook_handle.remove()
+   print('output after removing forward hooks: {}'.format(m(x)))
+   : output after removing forward hooks: tensor([[-0.5059, -0.8158,  0.2390],
+                                                  [-0.0043,  0.4724, -0.1714]], grad_fn=<AddmmBackward>)
+
+   # ==== Demonstrate backward hooks. ====
+   m(x).sum().backward()
+   print('x.grad with no backwards hook: {}'.format(x.grad))
+   : x.grad with no backwards hook: tensor([[ 0.4497, -0.5046,  0.3146],
+                                            [ 0.4497, -0.5046,  0.3146]])
+
+   # Clear gradients before running backward pass again.
+   m.zero_grad()
+   x.grad.zero_()
+
+   m.register_full_backward_hook(backward_hook)
+   m(x).sum().backward()
+   print('x.grad with backwards hook: {}'.format(x.grad))
+   : x.grad with backwards hook: tensor([[42., 42., 42.],
+                                         [42., 42., 42.]])
+
 Advanced Features
 -----------------
 
 PyTorch also provides several more advanced features that are designed to work with modules. All these functionalities
-are "inherited" when writing a new module. In-depth discussion of these features can be found in the links below.
+are available for custom-written modules, with the small caveat that certain features may require modules to conform
+to particular constraints in order to be supported. In-depth discussion of these features and the corresponding
+requirements can be found in the links below.
 
-For more information, check out:
+Distributed Training
+********************
+
+Various methods for distributed training exist within PyTorch, both for scaling up training using multiple GPUs
+as well as training across multiple machines. Check out the
+`distributed training overview page <https://pytorch.org/tutorials/beginner/dist_overview.html>`_ for
+detailed information on how to utilize these.
+
+Profiling Performance
+*********************
+
+The `PyTorch Profiler <https://pytorch.org/tutorials/beginner/profiler.html>`_ can be useful for identifying
+performance bottlenecks within your models. It measures and outputs performance characteristics for
+both memory usage and time spent.
+
+Improving Performance with Quantization
+***************************************
+
+Applying quantization techniques to modules can improve performance and memory usage by utilizing lower
+bitwidths than floating-point precision. Check out the various PyTorch-provided mechanisms for quantization
+`here <https://pytorch.org/docs/stable/quantization.html>`_.
+
+Improving Memory Usage with Pruning
+***********************************
+
+Large deep learning models are often over-parametrized, resulting in high memory usage. To combat this, PyTorch
+provides mechanisms for model pruning, which can help reduce memory usage while maintaining task accuracy. The
+`Pruning tutorial <https://pytorch.org/tutorials/intermediate/pruning_tutorial.html>`_ describes how to utilize
+the pruning techniques PyTorch provides or define custom pruning techniques as necessary.
+
+Deploying with TorchScript
+**************************
+
+When deploying a model for use in production, the overhead of Python can be unacceptable due to its poor
+performance characteristics. For cases like this,
+`TorchScript <https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html>`_ provides a way to load
+and run an optimized model program from outside of Python, such as within a C++ program.
+
+Parametrizations
+****************
+
+For certain applications, it can be beneficial to constrain the parameter space during model training. For example,
+enforcing orthogonality of the learned parameters can improve convergence for RNNs. PyTorch provides a mechanism for
+applying `parametrizations <https://pytorch.org/tutorials/intermediate/parametrizations.html>`_ such as this, and
+further allows for custom constraints to be defined.
+
+Transforming Modules with FX
+****************************
 
-* Profiling: https://pytorch.org/tutorials/beginner/profiler.html
-* Pruning: https://pytorch.org/tutorials/intermediate/pruning_tutorial.html
-* Quantization: https://pytorch.org/tutorials/recipes/quantization.html
-* Exporting modules to TorchScript (e.g. for usage from C++):
-  https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html
+The `FX <https://pytorch.org/docs/stable/fx.html>`_ component of PyTorch provides a flexible way to transform
+modules by operating directly on module computation graphs. This can be used to programmatically generate or
+manipulate modules for a broad array of use cases. To explore FX, check out these examples of using FX for
+`convolution + batch norm fusion <https://pytorch.org/tutorials/intermediate/fx_conv_bn_fuser.html>`_ and
+`CPU performance analysis <https://pytorch.org/tutorials/intermediate/fx_profiling_tutorial.html>`_.