From: Wanchao Liang
Date: Wed, 3 Apr 2019 23:50:46 +0000 (-0700)
Subject: Fix layernorm ad formula on weight and bias (#18233)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~433
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=843e6234f5f87f281a487fd4f8434e07101ee3ed;p=platform%2Fupstream%2Fpytorch.git

Fix layernorm ad formula on weight and bias (#18233)

Summary:
Fix the layernorm formula when weight and bias are passed in.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18233

Differential Revision: D14760375

Pulled By: wanchaol

fbshipit-source-id: d6bd3b137bc04c391aa5c24d021d1f811ba2a877
---

diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp
index e4be451..b1d3b31 100644
--- a/aten/src/ATen/native/Normalization.cpp
+++ b/aten/src/ATen/native/Normalization.cpp
@@ -384,7 +384,7 @@ Tensor instance_norm(
 Tensor layer_norm(const Tensor& input, IntArrayRef normalized_shape,
     const Tensor& weight /* optional */, const Tensor& bias /* optional */,
     double eps, bool cudnn_enabled) {
-
+
   int64_t normalized_ndim = normalized_shape.size();

   AT_CHECK(normalized_ndim >= 1,
diff --git a/test/test_jit.py b/test/test_jit.py
index 70d01df..79cfba3 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -12170,10 +12170,15 @@ nn_functional_tests = [
     ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), ),
         '', (True, 'aten::_batch_norm_impl_index')),
     ('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
-    ('layer_norm', (S, S, S, S), ([5],),),
-    ('layer_norm', (S, S, S, S), ([5], (S,)), 'with_only_weight'),
-    ('layer_norm', (S, S, S, S), ([5], None, (S,)), 'with_only_bias'),
-    ('layer_norm', (S, S, S, S), ([5], (S,), (S,)), 'with_weight_and_bias'),
+    ('layer_norm', (S, S, S, S), ([5],), '',
+     (True, ['prim::Loop', 'aten::_batch_norm_impl_index'])),
+    ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),), 'with_only_weight',
+     (True, ['prim::Loop', 'aten::_batch_norm_impl_index'])),
+    ('layer_norm', (S, S, S, S), ([5], None, non_differentiable(torch.rand(S)),), 'with_only_bias',
+     (True, ['prim::Loop', 'aten::_batch_norm_impl_index'])),
+    ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),
+                                  non_differentiable(torch.rand(S))), 'with_weight_and_bias',
+     (True, ['prim::Loop', 'aten::_batch_norm_impl_index'])),
     ('group_norm', (S, S, S), (1, torch.rand(5),),),
     ('local_response_norm', (S, S, S), (2, ),),
     ('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),), '', (True, 'aten::nll_loss_forward')),
diff --git a/torch/csrc/jit/symbolic_script.cpp b/torch/csrc/jit/symbolic_script.cpp
index c897d73..3cfcd7c 100644
--- a/torch/csrc/jit/symbolic_script.cpp
+++ b/torch/csrc/jit/symbolic_script.cpp
@@ -660,6 +660,7 @@ const std::vector<std::string> functions = {

             return torch.adaptive_avg_pool3d(self, output_size), backward

+
         def batch_norm(input : Tensor,
                        weight : Optional[Tensor],
                        bias : Optional[Tensor],
@@ -685,21 +686,27 @@ const std::vector<std::string> functions = {
             return output, backward

         def layer_norm(input : Tensor,
-                       normalied_shape : List[int],
+                       normalized_shape : List[int],
                        weight : Optional[Tensor],
                        bias : Optional[Tensor],
                        eps : float,
                        cudnn_enable : bool):
+            input_ndim = input.dim()
+            normalized_ndim = len(normalized_shape)
+            n = 1
+            for i in range(input_ndim - normalized_ndim):
+                n *= input.size(i)
+
+            input_reshape = input.contiguous().view(1, n, -1)
+
             bn_out, save1, save2, impl_idx = torch._batch_norm_impl_index(
-                input, weight, bias, None, None, True,
+                input_reshape, None, None, None, None, True,
                 0.0, eps, cudnn_enable)
-            has_weight = weight is not None
-            has_bias = bias is not None

-            bn_out = bn_out.view(input.sizes())
+            bn_out = bn_out.view(input.size())
             if weight is not None and bias is not None:
-                output = bias.addcmul(bn_out, weight)
+                output = bias.addcmul(bn_out, weight, value=1)
             elif weight is not None:
                 output = bn_out.mul(weight)
             elif bias is not None:
@@ -708,16 +715,32 @@ const std::vector<std::string> functions = {
                 output = bn_out

             def backward(grad_output):
-                if weight is not None:
-                    grad_output = grad_output * torch.t(weight)
-                    weight = grad_output * torch.t(bn_out)
+                if weight is not None and bias is not None:
+                    grad_bn_out = grad_output * weight
+                    grad_weight = (grad_output * bn_out)._grad_sum_to_size(weight.size())
+                    grad_bias = grad_output._grad_sum_to_size(bias.size())
+                elif weight is not None:
+                    grad_bn_out = grad_output * weight
+                    grad_weight = (grad_output * bn_out)._grad_sum_to_size(weight.size())
+                    grad_bias = None
+                elif bias is not None:
+                    grad_bn_out = grad_output
+                    grad_weight= None
+                    grad_bias = grad_output._grad_sum_to_size(bias.size())
+                else:
+                    grad_bn_out = grad_output
+                    grad_weight= None
+                    grad_bias = None

-                grad_output = grad_output.reshape(input.sizes())
-                dinput, dweight, dbias = torch._batch_norm_impl_index_backward(
-                    impl_idx, input, grad_output, weight, None, None,
-                    save1, save2, True, eps, [True, has_weight, has_bias])
-                return dinput, None, dweight, dbias, None, None
+                grad_bn_out = grad_bn_out.contiguous().view(1, n, -1)
+
+                grad_input, _, _ = torch._batch_norm_impl_index_backward(
+                    impl_idx, input_reshape, grad_bn_out, None, None, None,
+                    save1, save2, True, eps, [True, False, False])
+
+                grad_input = grad_input.view(input.size())
+                return grad_input, None, grad_weight, grad_bias, None, None

             return output, backward
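
The corrected backward above reduces to two simple reductions: grad_weight is grad_output * normalized(input) summed down to weight's shape, and grad_bias is grad_output summed down to bias's shape. A minimal standalone sketch of that formula, checked against autograd, is shown below; it uses the public Python API (torch.nn.functional.layer_norm, and Tensor.sum_to_size as a stand-in for the internal _grad_sum_to_size) rather than the JIT symbolic script, and the helper name layer_norm_weight_bias_grads is illustrative only, not part of PyTorch.

import torch
import torch.nn.functional as F

def layer_norm_weight_bias_grads(inp, normalized_shape, weight, bias, grad_out, eps=1e-5):
    # Normalize the input the same way layer_norm does, but without the affine step.
    bn_out = F.layer_norm(inp, normalized_shape, None, None, eps)
    # Mirrors the fixed backward: sum grad_out * bn_out down to weight's shape,
    # and sum grad_out down to bias's shape.
    grad_weight = (grad_out * bn_out).sum_to_size(weight.size())
    grad_bias = grad_out.sum_to_size(bias.size())
    return grad_weight, grad_bias

if __name__ == "__main__":
    torch.manual_seed(0)
    inp = torch.randn(4, 3, 5)
    weight = torch.rand(5, requires_grad=True)
    bias = torch.rand(5, requires_grad=True)

    out = F.layer_norm(inp, [5], weight, bias, eps=1e-5)
    grad_out = torch.randn_like(out)
    out.backward(grad_out)

    gw, gb = layer_norm_weight_bias_grads(inp, [5], weight, bias, grad_out)
    print(torch.allclose(gw, weight.grad, atol=1e-5))  # expected: True
    print(torch.allclose(gb, bias.grad, atol=1e-5))    # expected: True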