From cf444f354450821a9e4288a8eb32a2a6571754e9 Mon Sep 17 00:00:00 2001 From: crcrpar Date: Fri, 29 Mar 2019 06:41:49 -0700 Subject: [PATCH] make InstanceNorm1d raise an error if the input is 2D (#11992) Summary: Resolves #11991 . Any comment is welcome. Pull Request resolved: https://github.com/pytorch/pytorch/pull/11992 Differential Revision: D14680974 Pulled By: soumith fbshipit-source-id: 8e287a9c32bf43b35edc9d127f16ed6b72c61d91 --- torch/nn/modules/instancenorm.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/torch/nn/modules/instancenorm.py b/torch/nn/modules/instancenorm.py index 3a0c452..375240c 100644 --- a/torch/nn/modules/instancenorm.py +++ b/torch/nn/modules/instancenorm.py @@ -57,7 +57,7 @@ class _InstanceNorm(_BatchNorm): @weak_module class InstanceNorm1d(_InstanceNorm): - r"""Applies Instance Normalization over a 2D or 3D input (a mini-batch of 1D + r"""Applies Instance Normalization over a 3D input (a mini-batch of 1D inputs with optional additional channel dimension) as described in the paper `Instance Normalization: The Missing Ingredient for Fast Stylization`_ . @@ -126,8 +126,15 @@ class InstanceNorm1d(_InstanceNorm): @weak_script_method def _check_input_dim(self, input): - if input.dim() != 2 and input.dim() != 3: - raise ValueError('expected 2D or 3D input (got {}D input)' + if input.dim() == 2: + raise ValueError( + 'InstanceNorm1d returns 0-filled tensor to 2D tensor.' + 'This is because InstanceNorm1d reshapes inputs to' + '(1, N * C, ...) from (N, C,...) and this makes' + 'variances 0.' + ) + if input.dim() != 3: + raise ValueError('expected 3D input (got {}D input)' .format(input.dim())) -- 2.7.4