From: crcrpar Date: Fri, 29 Mar 2019 13:41:49 +0000 (-0700) Subject: make InstanceNorm1d raise an error if the input is 2D (#11992) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~560 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=cf444f354450821a9e4288a8eb32a2a6571754e9;p=platform%2Fupstream%2Fpytorch.git make InstanceNorm1d raise an error if the input is 2D (#11992) Summary: Resolves #11991 . Any comment is welcome. Pull Request resolved: https://github.com/pytorch/pytorch/pull/11992 Differential Revision: D14680974 Pulled By: soumith fbshipit-source-id: 8e287a9c32bf43b35edc9d127f16ed6b72c61d91 --- diff --git a/torch/nn/modules/instancenorm.py b/torch/nn/modules/instancenorm.py index 3a0c452..375240c 100644 --- a/torch/nn/modules/instancenorm.py +++ b/torch/nn/modules/instancenorm.py @@ -57,7 +57,7 @@ class _InstanceNorm(_BatchNorm): @weak_module class InstanceNorm1d(_InstanceNorm): - r"""Applies Instance Normalization over a 2D or 3D input (a mini-batch of 1D + r"""Applies Instance Normalization over a 3D input (a mini-batch of 1D inputs with optional additional channel dimension) as described in the paper `Instance Normalization: The Missing Ingredient for Fast Stylization`_ . @@ -126,8 +126,15 @@ class InstanceNorm1d(_InstanceNorm): @weak_script_method def _check_input_dim(self, input): - if input.dim() != 2 and input.dim() != 3: - raise ValueError('expected 2D or 3D input (got {}D input)' + if input.dim() == 2: + raise ValueError( + 'InstanceNorm1d returns 0-filled tensor to 2D tensor.' + 'This is because InstanceNorm1d reshapes inputs to' + '(1, N * C, ...) from (N, C,...) and this makes' + 'variances 0.' + ) + if input.dim() != 3: + raise ValueError('expected 3D input (got {}D input)' .format(input.dim()))