Add an identity module (#19249)
authorMilesCranmer <miles.cranmer@gmail.com>
Fri, 19 Apr 2019 17:08:50 +0000 (10:08 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 19 Apr 2019 17:12:18 +0000 (10:12 -0700)
Summary:
This is a simple yet useful addition to the torch.nn modules: an identity module. This is a first draft - please let me know what you think and I will edit my PR.

 There is no identity module - nn.Sequential() can be used, however it is argument sensitive so can't be used interchangably with any other module. This adds nn.Identity(...) which can be swapped with any module because it has dummy arguments. It's also more understandable than seeing an empty Sequential inside a model.

See discussion on #9160. The current solution is to use nn.Sequential(). However this won't work as follows:

```python
batch_norm = nn.BatchNorm2d
if dont_use_batch_norm:
    batch_norm = Identity
```

Then in your network, you have:

```python
nn.Sequential(
    ...
    batch_norm(N, momentum=0.05),
    ...
)
```

If you try to simply set `Identity = nn.Sequential`, this will fail since `nn.Sequential` expects modules as arguments. Of course there are many ways to get around this, including:

- Conditionally adding modules to an existing Sequential module
- Not using Sequential but writing the usual `forward` function with an if statement
- ...

**However, I think that an identity module is the most pythonic strategy,** assuming you want to use nn.Sequential.

Using the very simple class (this isn't the same as the one in my commit):

```python
class Identity(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
    def forward(self, x):
        return x
```

we can get around using nn.Sequential, and `batch_norm(N, momentum=0.05)` will work. There are of course other situations this would be useful.

Thank you.
Best,
Miles
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19249

Differential Revision: D15012969

Pulled By: ezyang

fbshipit-source-id: 9f47e252137a1679e306fd4c169dca832eb82c0c

docs/source/nn.rst
torch/nn/modules/__init__.py
torch/nn/modules/linear.py

index 1d4bc3e..24e2caf 100644 (file)
@@ -529,6 +529,12 @@ Recurrent layers
 Linear layers
 ----------------------------------
 
+:hidden:`Identity`
+~~~~~~~~~~~~~~~~
+
+.. autoclass:: Identity
+    :members:
+
 :hidden:`Linear`
 ~~~~~~~~~~~~~~~~
 
index 55b294b..be0519b 100644 (file)
@@ -1,5 +1,5 @@
 from .module import Module
-from .linear import Linear, Bilinear
+from .linear import Identity, Linear, Bilinear
 from .conv import Conv1d, Conv2d, Conv3d, \
     ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
 from .activation import Threshold, ReLU, Hardtanh, ReLU6, Sigmoid, Tanh, \
@@ -29,7 +29,7 @@ from .fold import Fold, Unfold
 from .adaptive import AdaptiveLogSoftmaxWithLoss
 
 __all__ = [
-    'Module', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d',
+    'Module', 'Identity', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d',
     'ConvTranspose2d', 'ConvTranspose3d', 'Threshold', 'ReLU', 'Hardtanh', 'ReLU6',
     'Sigmoid', 'Tanh', 'Softmax', 'Softmax2d', 'LogSoftmax', 'ELU', 'SELU', 'CELU', 'GLU', 'Hardshrink',
     'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Softmin',
index 442034e..f5c2c02 100644 (file)
@@ -9,6 +9,31 @@ from ..._jit_internal import weak_module, weak_script_method
 
 
 @weak_module
+class Identity(Module):
+    r"""A placeholder identity operator that is argument-insensitive.
+
+    Args:
+        args: any argument (unused)
+        kwargs: any keyword argument (unused)
+
+    Examples::
+
+        >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
+        >>> input = torch.randn(128, 20)
+        >>> output = m(input)
+        >>> print(output.size())
+        torch.Size([128, 20])
+
+    """
+    def __init__(self, *args, **kwargs):
+        super(Identity, self).__init__()
+
+    @weak_script_method
+    def forward(self, input):
+        return input
+
+
+@weak_module
 class Linear(Module):
     r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`