make InstanceNorm1d raise an error if the input is 2D (#11992)
authorcrcrpar <masaki.kozuki.2014@gmail.com>
Fri, 29 Mar 2019 13:41:49 +0000 (06:41 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 29 Mar 2019 13:50:04 +0000 (06:50 -0700)
Summary:
Resolves #11991 .

Any comment is welcome.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11992

Differential Revision: D14680974

Pulled By: soumith

fbshipit-source-id: 8e287a9c32bf43b35edc9d127f16ed6b72c61d91

torch/nn/modules/instancenorm.py

index 3a0c452..375240c 100644 (file)
@@ -57,7 +57,7 @@ class _InstanceNorm(_BatchNorm):
 
 @weak_module
 class InstanceNorm1d(_InstanceNorm):
-    r"""Applies Instance Normalization over a 2D or 3D input (a mini-batch of 1D
+    r"""Applies Instance Normalization over a 3D input (a mini-batch of 1D
     inputs with optional additional channel dimension) as described in the paper
     `Instance Normalization: The Missing Ingredient for Fast Stylization`_ .
 
@@ -126,8 +126,15 @@ class InstanceNorm1d(_InstanceNorm):
 
     @weak_script_method
     def _check_input_dim(self, input):
-        if input.dim() != 2 and input.dim() != 3:
-            raise ValueError('expected 2D or 3D input (got {}D input)'
+        if input.dim() == 2:
+            raise ValueError(
+                'InstanceNorm1d returns 0-filled tensor to 2D tensor.'
+                'This is because InstanceNorm1d reshapes inputs to'
+                '(1, N * C, ...) from (N, C,...) and this makes'
+                'variances 0.'
+            )
+        if input.dim() != 3:
+            raise ValueError('expected 3D input (got {}D input)'
                              .format(input.dim()))