Note that :class:`~torch.nn.Sequential` automatically feeds the output of the first ``MyLinear`` module as input
into the :class:`~torch.nn.ReLU`, and the output of that as input into the second ``MyLinear`` module. As
-shown, it is limited to in-order chaining of modules.
+shown, it is limited to in-order chaining of modules with a single input and output.
In general, it is recommended to define a custom module for anything beyond the simplest use cases, as this gives
full flexibility on how submodules are used for a module's computation.
dynamic_net(torch.randn(5, device='cuda', dtype=torch.float64))
: tensor([6.5166], device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)
-These examples show how elaborate neural networks can be formed through module composition. To allow for
-quick and easy construction of neural networks with minimal boilerplate, PyTorch provides a large library of
-performant modules within the :mod:`torch.nn` namespace that perform computation commonly found within neural
-networks, including pooling, convolutions, loss functions, etc.
+More generally, an arbitrary function can be applied to a module and its submodules recursively by
+using the :func:`~torch.nn.Module.apply` function. For example, to apply custom initialization to parameters
+of a module and its submodules:
+
+.. code-block:: python
+
+ # Define a function to initialize Linear weights.
+ # Note that no_grad() is used here to avoid tracking this computation in the autograd graph.
+ @torch.no_grad()
+ def init_weights(m):
+     if isinstance(m, nn.Linear):
+         nn.init.xavier_normal_(m.weight)
+         m.bias.fill_(0.0)
+
+ # Apply the function recursively on the module and its submodules.
+ dynamic_net.apply(init_weights)
+
+These examples show how elaborate neural networks can be formed through module composition and conveniently
+manipulated. To allow for quick and easy construction of neural networks with minimal boilerplate, PyTorch
+provides a large library of performant modules within the :mod:`torch.nn` namespace that perform common neural
+network operations like pooling, convolutions, loss functions, etc.
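+
+For instance, a small image classifier can be assembled entirely from :mod:`torch.nn`-provided modules
+(the layer shapes below are arbitrary and chosen purely for illustration):
+
+.. code-block:: python
+
+ # Convolution, pooling, and a fully-connected classification head, all provided by torch.nn.
+ classifier = nn.Sequential(
+     nn.Conv2d(3, 16, kernel_size=3, padding=1),
+     nn.ReLU(),
+     nn.MaxPool2d(2),
+     nn.Flatten(),
+     nn.Linear(16 * 16 * 16, 10),
+ )
+ loss_fn = nn.CrossEntropyLoss()
+
+ # A batch of 8 RGB 32x32 images with integer class labels.
+ inputs = torch.randn(8, 3, 32, 32)
+ targets = torch.randint(10, (8,))
+ loss = loss_fn(classifier(inputs), targets)
+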
In the next section, we give a full example of training a neural network.
For more information, check out:
-* Recursively :func:`~torch.nn.Module.apply` a function to a module and its submodules
* Library of PyTorch-provided modules: `torch.nn <https://pytorch.org/docs/stable/nn.html>`_
* Defining neural net modules: https://pytorch.org/tutorials/beginner/examples_nn/two_layer_net_module.html
loss.backward()
optimizer.step()
+ # After training, switch the module to eval mode to do inference, compute performance metrics, etc.
+ # (see discussion below for a description of training and evaluation modes)
+ ...
+ net.eval()
+ ...
+
In this simplified example, the network learns to simply output zero, as any non-zero output is "penalized" according
to its absolute value by employing :func:`torch.abs` as a loss function. While this is not a very interesting task, the
key parts of training are present:
[ 0.0030],
[-0.0008]], requires_grad=True)
+Note that the above process is done entirely while the network module is in "training mode". Modules default to
+training mode and can be switched between training and evaluation modes using :func:`~torch.nn.Module.train` and
+:func:`~torch.nn.Module.eval`. They can behave differently depending on which mode they are in. For example, the
+:class:`~torch.nn.BatchNorm2d` module maintains a running mean and variance during training that are not updated
+when the module is in evaluation mode. In general, modules should be in training mode during training
+and only switched to evaluation mode for inference or evaluation. Below is an example of a custom module
+that behaves differently between the two modes:
+
+.. code-block:: python
+
+ class ModalModule(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x):
+         if self.training:
+             # Add a constant only in training mode.
+             return x + 1.
+         else:
+             return x
+
+
+ m = ModalModule()
+ x = torch.randn(5)
+
+ print('training mode output: {}'.format(m(x)))
+ : tensor([1.6614, 1.2669, 1.0617, 1.6213, 0.5481])
+
+ m.eval()
+ print('evaluation mode output: {}'.format(m(x)))
+ : tensor([ 0.6614, 0.2669, 0.0617, 0.6213, -0.4519])
+
Training neural networks can often be tricky. For more information, check out:
* Using Optimizers: https://pytorch.org/tutorials/beginner/examples_nn/two_layer_net_optim.html.
Buffers of a module can be iterated over using :func:`~torch.nn.Module.buffers` or
:func:`~torch.nn.Module.named_buffers`.
+.. code-block:: python
+
+ for buffer in m.named_buffers():
+     print(buffer)
+
+The following class demonstrates the various ways of registering parameters and buffers within a module:
+
+.. code-block:: python
+
+ class StatefulModule(nn.Module):
+     def __init__(self):
+         super().__init__()
+         # Setting a nn.Parameter as an attribute of the module automatically registers the tensor
+         # as a parameter of the module.
+         self.param1 = nn.Parameter(torch.randn(2))
+
+         # Alternative string-based way to register a parameter.
+         self.register_parameter('param2', nn.Parameter(torch.randn(3)))
+
+         # Reserves the "param3" attribute as a parameter, preventing it from being set to anything
+         # except a parameter. "None" entries like this will not be present in the module's state_dict.
+         self.register_parameter('param3', None)
+
+         # Registers a list of parameters.
+         self.param_list = nn.ParameterList([nn.Parameter(torch.randn(2)) for i in range(3)])
+
+         # Registers a dictionary of parameters.
+         self.param_dict = nn.ParameterDict({
+             'foo': nn.Parameter(torch.randn(3)),
+             'bar': nn.Parameter(torch.randn(4))
+         })
+
+         # Registers a persistent buffer (one that appears in the module's state_dict).
+         self.register_buffer('buffer1', torch.randn(4), persistent=True)
+
+         # Registers a non-persistent buffer (one that does not appear in the module's state_dict).
+         self.register_buffer('buffer2', torch.randn(5), persistent=False)
+
+         # Reserves the "buffer3" attribute as a buffer, preventing it from being set to anything
+         # except a buffer. "None" entries like this will not be present in the module's state_dict.
+         self.register_buffer('buffer3', None)
+
+         # Adding a submodule registers its parameters as parameters of the module.
+         self.linear = nn.Linear(2, 3)
+
+ m = StatefulModule()
+
+ # Save and load state_dict.
+ torch.save(m.state_dict(), 'state.pt')
+ m_loaded = StatefulModule()
+ m_loaded.load_state_dict(torch.load('state.pt'))
+
+ # Note that non-persistent buffer "buffer2" and reserved attributes "param3" and "buffer3" do
+ # not appear in the state_dict.
+ print(m_loaded.state_dict())
+ : OrderedDict([('param1', tensor([-0.0322, 0.9066])),
+                ('param2', tensor([-0.4472, 0.1409, 0.4852])),
+                ('buffer1', tensor([ 0.6949, -0.1944, 1.2911, -2.1044])),
+                ('param_list.0', tensor([ 0.4202, -0.1953])),
+                ('param_list.1', tensor([ 1.5299, -0.8747])),
+                ('param_list.2', tensor([-1.6289, 1.4898])),
+                ('param_dict.bar', tensor([-0.6434, 1.5187, 0.0346, -0.4077])),
+                ('param_dict.foo', tensor([-0.0845, -1.4324, 0.7022])),
+                ('linear.weight', tensor([[-0.3915, -0.6176],
+                                          [ 0.6062, -0.5992],
+                                          [ 0.4452, -0.2843]])),
+                ('linear.bias', tensor([-0.3710, -0.0795, -0.3947]))])
+
For more information, check out:
* Saving and loading: https://pytorch.org/tutorials/beginner/saving_loading_models.html
* Serialization semantics: https://pytorch.org/docs/master/notes/serialization.html
* What is a state dict? https://pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html
+Module Initialization
+---------------------
+
+By default, parameters and floating-point buffers for modules provided by :mod:`torch.nn` are initialized during
+module instantiation as 32-bit floating point values on the CPU using an initialization scheme determined to
+perform well historically for the module type. For certain use cases, it may be desired to initialize with a different
+dtype, device (e.g. GPU), or initialization technique.
+
+Examples:
+
+.. code-block:: python
+
+ # Initialize module directly onto GPU.
+ m = nn.Linear(5, 3, device='cuda')
+
+ # Initialize module with 16-bit floating point parameters.
+ m = nn.Linear(5, 3, dtype=torch.half)
+
+ # Skip default parameter initialization and perform custom (e.g. orthogonal) initialization.
+ m = torch.nn.utils.skip_init(nn.Linear, 5, 3)
+ nn.init.orthogonal_(m.weight)
+
+Note that the device and dtype options demonstrated above also apply to any floating-point buffers registered
+for the module:
+
+.. code-block:: python
+
+ m = nn.BatchNorm2d(3, dtype=torch.half)
+ print(m.running_mean)
+ : tensor([0., 0., 0.], dtype=torch.float16)
+
+While module writers can use any device or dtype to initialize parameters in their custom modules, good practice is
+to use ``dtype=torch.float`` and ``device='cpu'`` by default as well. Optionally, you can provide full flexibility
+in these areas for your custom module by conforming to the convention demonstrated above that all
+:mod:`torch.nn` modules follow:
+
+* Provide a ``device`` constructor kwarg that applies to any parameters / buffers registered by the module.
+* Provide a ``dtype`` constructor kwarg that applies to any parameters / floating-point buffers registered by
+ the module.
+* Only use initialization functions (i.e. functions from :mod:`torch.nn.init`) on parameters and buffers within the
+ module's constructor. Note that this is only required if :func:`~torch.nn.utils.skip_init` support is desired; see
+ `this page <https://pytorch.org/tutorials/prototype/skip_param_init.html#updating-modules-to-support-skipping-initialization>`_ for an explanation.
+
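+As a sketch, a custom module following this convention might look like the following (the module and
+attribute names here are purely illustrative):
+
+.. code-block:: python
+
+ class MyModule(nn.Module):
+     def __init__(self, in_features, out_features, device=None, dtype=None):
+         super().__init__()
+         # Forward the device / dtype kwargs to every parameter / buffer creation.
+         factory_kwargs = {'device': device, 'dtype': dtype}
+         self.weight = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
+         self.register_buffer('some_buffer', torch.zeros(out_features, **factory_kwargs))
+         # Use functions from torch.nn.init for initialization so that skip_init() can skip it.
+         nn.init.xavier_normal_(self.weight)
+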
+For more information, check out:
+
+* Skipping module parameter initialization: https://pytorch.org/tutorials/prototype/skip_param_init.html
+
Module Hooks
------------
Thus, these hooks can be used to either execute arbitrary code along the regular module forward/backward or
modify some inputs/outputs without having to change the module's ``forward()`` function.
+Below is an example demonstrating usage of forward and backward hooks:
+
+.. code-block:: python
+
+ torch.manual_seed(1)
+
+ def forward_pre_hook(m, inputs):
+     # Allows for examination and modification of the input before the forward pass.
+     # Note that inputs are always wrapped in a tuple.
+     input = inputs[0]
+     return input + 1.
+
+ def forward_hook(m, inputs, output):
+     # Allows for examination of inputs / outputs and modification of the outputs
+     # after the forward pass. Note that inputs are always wrapped in a tuple while outputs
+     # are passed as-is.
+
+     # Residual computation a la ResNet.
+     return output + inputs[0]
+
+ def backward_hook(m, grad_inputs, grad_outputs):
+     # Allows for examination of grad_inputs / grad_outputs and modification of
+     # grad_inputs used in the rest of the backwards pass. Note that grad_inputs and
+     # grad_outputs are always wrapped in tuples.
+     new_grad_inputs = [torch.ones_like(gi) * 42. for gi in grad_inputs]
+     return new_grad_inputs
+
+ # Create sample module & input.
+ m = nn.Linear(3, 3)
+ x = torch.randn(2, 3, requires_grad=True)
+
+ # ==== Demonstrate forward hooks. ====
+ # Run input through module before and after adding hooks.
+ print('output with no forward hooks: {}'.format(m(x)))
+ : output with no forward hooks: tensor([[-0.5059, -0.8158, 0.2390],
+                                         [-0.0043, 0.4724, -0.1714]], grad_fn=<AddmmBackward>)
+
+ # Note that the modified input results in a different output.
+ forward_pre_hook_handle = m.register_forward_pre_hook(forward_pre_hook)
+ print('output with forward pre hook: {}'.format(m(x)))
+ : output with forward pre hook: tensor([[-0.5752, -0.7421, 0.4942],
+                                         [-0.0736, 0.5461, 0.0838]], grad_fn=<AddmmBackward>)
+
+ # Note the modified output.
+ forward_hook_handle = m.register_forward_hook(forward_hook)
+ print('output with both forward hooks: {}'.format(m(x)))
+ : output with both forward hooks: tensor([[-1.0980, 0.6396, 0.4666],
+                                           [ 0.3634, 0.6538, 1.0256]], grad_fn=<AddBackward0>)
+
+ # Remove hooks; note that the output here matches the output before adding hooks.
+ forward_pre_hook_handle.remove()
+ forward_hook_handle.remove()
+ print('output after removing forward hooks: {}'.format(m(x)))
+ : output after removing forward hooks: tensor([[-0.5059, -0.8158, 0.2390],
+                                                [-0.0043, 0.4724, -0.1714]], grad_fn=<AddmmBackward>)
+
+ # ==== Demonstrate backward hooks. ====
+ m(x).sum().backward()
+ print('x.grad with no backwards hook: {}'.format(x.grad))
+ : x.grad with no backwards hook: tensor([[ 0.4497, -0.5046, 0.3146],
+                                          [ 0.4497, -0.5046, 0.3146]])
+
+ # Clear gradients before running backward pass again.
+ m.zero_grad()
+ x.grad.zero_()
+
+ m.register_full_backward_hook(backward_hook)
+ m(x).sum().backward()
+ print('x.grad with backwards hook: {}'.format(x.grad))
+ : x.grad with backwards hook: tensor([[42., 42., 42.],
+                                       [42., 42., 42.]])
+
Advanced Features
-----------------
PyTorch also provides several more advanced features that are designed to work with modules. All these functionalities
-are "inherited" when writing a new module. In-depth discussion of these features can be found in the links below.
+are available for custom-written modules, with the small caveat that certain features may require modules to conform
+to particular constraints in order to be supported. In-depth discussion of these features and the corresponding
+requirements can be found in the links below.
-For more information, check out:
+Distributed Training
+********************
+
+Various methods for distributed training exist within PyTorch, both for scaling up training using multiple GPUs
+and for training across multiple machines. Check out the
+`distributed training overview page <https://pytorch.org/tutorials/beginner/dist_overview.html>`_ for
+detailed information on how to utilize these.
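+
+As a minimal sketch, wrapping a module in :class:`~torch.nn.parallel.DistributedDataParallel` might look
+like the following (this assumes one process per GPU, launched with a utility such as ``torchrun`` that
+sets the environment variables read by ``init_process_group()``):
+
+.. code-block:: python
+
+ import torch.distributed as dist
+ from torch.nn.parallel import DistributedDataParallel as DDP
+
+ # Join the default process group; each process drives a single GPU.
+ dist.init_process_group(backend='nccl')
+ rank = dist.get_rank()
+
+ # Gradients for the wrapped module are synchronized across processes during backward().
+ model = DDP(nn.Linear(5, 3).to(rank), device_ids=[rank])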
+
+Profiling Performance
+*********************
+
+The `PyTorch Profiler <https://pytorch.org/tutorials/beginner/profiler.html>`_ can be useful for identifying
+performance bottlenecks within your models. It measures and outputs performance characteristics for
+both memory usage and time spent.
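+
+For example, a brief sketch of profiling a single forward pass with the :mod:`torch.profiler` API
+(the module and input sizes here are arbitrary):
+
+.. code-block:: python
+
+ from torch.profiler import profile, ProfilerActivity
+
+ m = nn.Linear(100, 100)
+ x = torch.randn(128, 100)
+
+ # Record CPU activity for the operations run inside the context manager.
+ with profile(activities=[ProfilerActivity.CPU]) as prof:
+     m(x)
+
+ print(prof.key_averages().table(sort_by='cpu_time_total'))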
+
+Improving Performance with Quantization
+***************************************
+
+Applying quantization techniques to modules can improve performance and memory usage by utilizing lower
+bitwidths than floating-point precision. Check out the various PyTorch-provided mechanisms for quantization
+`here <https://pytorch.org/docs/stable/quantization.html>`_.
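+
+As one example among these mechanisms, dynamic quantization can be applied to an existing float model
+in a single call (a sketch; the model here is arbitrary):
+
+.. code-block:: python
+
+ float_model = nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 10))
+
+ # Convert nn.Linear weights to int8; activations are quantized dynamically at runtime.
+ quantized_model = torch.quantization.quantize_dynamic(
+     float_model, {nn.Linear}, dtype=torch.qint8)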
+
+Improving Memory Usage with Pruning
+***********************************
+
+Large deep learning models are often over-parametrized, resulting in high memory usage. To combat this, PyTorch
+provides mechanisms for model pruning, which can help reduce memory usage while maintaining task accuracy. The
+`Pruning tutorial <https://pytorch.org/tutorials/intermediate/pruning_tutorial.html>`_ describes how to utilize
+the pruning techniques PyTorch provides or define custom pruning techniques as necessary.
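+
+As a small sketch, L1 unstructured pruning can be applied to a single module's weight as follows:
+
+.. code-block:: python
+
+ from torch.nn.utils import prune
+
+ m = nn.Linear(5, 3)
+
+ # Zero out the 30% of weight entries with the smallest absolute values.
+ prune.l1_unstructured(m, name='weight', amount=0.3)
+
+ # Optionally make the pruning permanent by removing the re-parametrization.
+ prune.remove(m, 'weight')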
+
+Deploying with TorchScript
+**************************
+
+When deploying a model for use in production, the performance overhead of the Python interpreter can be
+unacceptable. For cases like this,
+`TorchScript <https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html>`_ provides a way to load
+and run an optimized model program from outside of Python, such as within a C++ program.
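+
+For instance, a module can be compiled and serialized with a couple of calls (a sketch; the saved file
+can then be loaded in C++ via ``torch::jit::load()``):
+
+.. code-block:: python
+
+ m = nn.Linear(5, 3)
+
+ # Compile the module to TorchScript and serialize it to disk.
+ scripted = torch.jit.script(m)
+ scripted.save('linear.pt')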
+
+Parametrizations
+****************
+
+For certain applications, it can be beneficial to constrain the parameter space during model training. For example,
+enforcing orthogonality of the learned parameters can improve convergence for RNNs. PyTorch provides a mechanism for
+applying `parametrizations <https://pytorch.org/tutorials/intermediate/parametrizations.html>`_ such as this, and
+further allows for custom constraints to be defined.
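+
+As a brief sketch, a custom parametrization that constrains a weight to be symmetric can be registered
+as follows (the ``Symmetric`` module here is illustrative):
+
+.. code-block:: python
+
+ from torch.nn.utils import parametrize
+
+ class Symmetric(nn.Module):
+     def forward(self, X):
+         # Build a symmetric matrix from the upper-triangular part of X.
+         return X.triu() + X.triu(1).transpose(-1, -2)
+
+ m = nn.Linear(3, 3)
+ parametrize.register_parametrization(m, 'weight', Symmetric())
+
+ # m.weight is now computed through the parametrization and remains symmetric during training.
+ print(m.weight)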
+
+Transforming Modules with FX
+****************************
-* Profiling: https://pytorch.org/tutorials/beginner/profiler.html
-* Pruning: https://pytorch.org/tutorials/intermediate/pruning_tutorial.html
-* Quantization: https://pytorch.org/tutorials/recipes/quantization.html
-* Exporting modules to TorchScript (e.g. for usage from C++):
- https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html
+The `FX <https://pytorch.org/docs/stable/fx.html>`_ component of PyTorch provides a flexible way to transform
+modules by operating directly on module computation graphs. This can be used to programmatically generate or
+manipulate modules for a broad array of use cases. To explore FX, check out these examples of using FX for
+`convolution + batch norm fusion <https://pytorch.org/tutorials/intermediate/fx_conv_bn_fuser.html>`_ and
+`CPU performance analysis <https://pytorch.org/tutorials/intermediate/fx_profiling_tutorial.html>`_.
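+
+As a starting point, a module can be symbolically traced into a graph representation that is then open
+to programmatic inspection and transformation (a small sketch):
+
+.. code-block:: python
+
+ from torch.fx import symbolic_trace
+
+ m = nn.Sequential(nn.Linear(5, 3), nn.ReLU())
+
+ # Trace the module into a GraphModule whose graph can be examined and rewritten.
+ gm = symbolic_trace(m)
+ print(gm.graph)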