From: MilesCranmer
Date: Fri, 19 Apr 2019 17:08:50 +0000 (-0700)
Subject: Add an identity module (#19249)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~128
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=30292d994fabf24146fe0dbcfe9621d4f87325e2;p=platform%2Fupstream%2Fpytorch.git

Add an identity module (#19249)

Summary:
This is a simple yet useful addition to the torch.nn modules: an identity module. This is a first draft - please let me know what you think and I will edit my PR.

There is currently no identity module. nn.Sequential() can be used as one, but it is argument-sensitive, so it cannot be swapped in interchangeably for an arbitrary module. This PR adds nn.Identity(...), which can be swapped with any module because it accepts (and ignores) arbitrary arguments. It is also easier to read than an empty Sequential inside a model.

See the discussion on #9160. The current workaround is to use nn.Sequential(). However, that breaks in cases like the following:

```python
batch_norm = nn.BatchNorm2d
if dont_use_batch_norm:
    batch_norm = Identity
```

Then in your network, you have:

```python
nn.Sequential(
    ...
    batch_norm(N, momentum=0.05),
    ...
)
```

If you simply set `Identity = nn.Sequential`, this fails because `nn.Sequential` expects modules as arguments.

Of course there are many ways to work around this, including:

- conditionally adding modules to an existing Sequential module,
- not using Sequential and instead writing the usual `forward` function with an if statement,
- ...

**However, I think that an identity module is the most Pythonic strategy,** assuming you want to use nn.Sequential.

Using a very simple class (this is not the same as the one in my commit):

```python
class Identity(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x):
        return x
```

we can keep using nn.Sequential, and `batch_norm(N, momentum=0.05)` will work either way. There are of course other situations where this would be useful.
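To make the pattern above concrete, here is a minimal end-to-end sketch. The `dont_use_batch_norm` flag, the layer sizes, and the toy network are made up purely for illustration, and the `Identity` class used here is the simplified one above, not the implementation in this commit:

```python
import torch
import torch.nn as nn


class Identity(nn.Module):
    """Argument-insensitive no-op module (simplified illustration)."""
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x):
        return x


# Hypothetical switch: disable batch norm without touching the network definition.
dont_use_batch_norm = True
batch_norm = Identity if dont_use_batch_norm else nn.BatchNorm2d

model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    batch_norm(16, momentum=0.05),  # becomes a no-op when batch norm is disabled
    nn.ReLU(),
)

output = model(torch.randn(1, 3, 32, 32))
print(output.shape)  # torch.Size([1, 16, 32, 32])
```

Because `Identity` ignores its constructor arguments, the same `batch_norm(16, momentum=0.05)` call is valid in both branches of the switch.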
Thank you.
Best,
Miles

Pull Request resolved: https://github.com/pytorch/pytorch/pull/19249

Differential Revision: D15012969

Pulled By: ezyang

fbshipit-source-id: 9f47e252137a1679e306fd4c169dca832eb82c0c
---

diff --git a/docs/source/nn.rst b/docs/source/nn.rst
index 1d4bc3e..24e2caf 100644
--- a/docs/source/nn.rst
+++ b/docs/source/nn.rst
@@ -529,6 +529,12 @@ Recurrent layers
 Linear layers
 ----------------------------------
 
+:hidden:`Identity`
+~~~~~~~~~~~~~~~~
+
+.. autoclass:: Identity
+    :members:
+
 :hidden:`Linear`
 ~~~~~~~~~~~~~~~~
 
diff --git a/torch/nn/modules/__init__.py b/torch/nn/modules/__init__.py
index 55b294b..be0519b 100644
--- a/torch/nn/modules/__init__.py
+++ b/torch/nn/modules/__init__.py
@@ -1,5 +1,5 @@
 from .module import Module
-from .linear import Linear, Bilinear
+from .linear import Identity, Linear, Bilinear
 from .conv import Conv1d, Conv2d, Conv3d, \
     ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
 from .activation import Threshold, ReLU, Hardtanh, ReLU6, Sigmoid, Tanh, \
@@ -29,7 +29,7 @@ from .fold import Fold, Unfold
 from .adaptive import AdaptiveLogSoftmaxWithLoss
 
 __all__ = [
-    'Module', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d',
+    'Module', 'Identity', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d',
     'ConvTranspose2d', 'ConvTranspose3d', 'Threshold', 'ReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Tanh',
     'Softmax', 'Softmax2d', 'LogSoftmax', 'ELU', 'SELU', 'CELU', 'GLU', 'Hardshrink',
     'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Softmin',
diff --git a/torch/nn/modules/linear.py b/torch/nn/modules/linear.py
index 442034e..f5c2c02 100644
--- a/torch/nn/modules/linear.py
+++ b/torch/nn/modules/linear.py
@@ -9,6 +9,31 @@ from ..._jit_internal import weak_module, weak_script_method
 
 
 @weak_module
+class Identity(Module):
+    r"""A placeholder identity operator that is argument-insensitive.
+
+    Args:
+        args: any argument (unused)
+        kwargs: any keyword argument (unused)
+
+    Examples::
+
+        >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
+        >>> input = torch.randn(128, 20)
+        >>> output = m(input)
+        >>> print(output.size())
+        torch.Size([128, 20])
+
+    """
+    def __init__(self, *args, **kwargs):
+        super(Identity, self).__init__()
+
+    @weak_script_method
+    def forward(self, input):
+        return input
+
+
+@weak_module
 class Linear(Module):
     r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`