Fix layernorm AD formula on weight and bias (#18233)
authorWanchao Liang <wanchaol@users.noreply.github.com>
Wed, 3 Apr 2019 23:50:46 +0000 (16:50 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 3 Apr 2019 23:58:33 +0000 (16:58 -0700)
Summary:
Fix the layernorm autodiff formula when weight and bias are passed in.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18233

Differential Revision: D14760375

Pulled By: wanchaol

fbshipit-source-id: d6bd3b137bc04c391aa5c24d021d1f811ba2a877
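
With x_hat denoting the input standardized over normalized_shape (the affine part left out), layer norm with learnable parameters computes y = x_hat * weight + bias. The gradient formulas this change targets are therefore, roughly:

    grad_x_hat  = grad_output * weight
    grad_weight = grad_output * x_hat, reduced to weight's shape
    grad_bias   = grad_output, reduced to bias's shape

The symbolic_script.cpp rewrite below implements these reductions with _grad_sum_to_size and routes grad_x_hat through the batch-norm backward on a reshaped view.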

aten/src/ATen/native/Normalization.cpp
test/test_jit.py
torch/csrc/jit/symbolic_script.cpp

aten/src/ATen/native/Normalization.cpp
index e4be451..b1d3b31 100644
@@ -384,7 +384,7 @@ Tensor instance_norm(
 Tensor layer_norm(const Tensor& input, IntArrayRef normalized_shape,
     const Tensor& weight /* optional */, const Tensor& bias /* optional */,
     double eps, bool cudnn_enabled) {
-
+  
     int64_t normalized_ndim = normalized_shape.size();
 
     AT_CHECK(normalized_ndim >= 1,
test/test_jit.py
index 70d01df..79cfba3 100644
@@ -12170,10 +12170,15 @@ nn_functional_tests = [
     ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), ),
         '', (True, 'aten::_batch_norm_impl_index')),
     ('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
-    ('layer_norm', (S, S, S, S), ([5],),),
-    ('layer_norm', (S, S, S, S), ([5], (S,)), 'with_only_weight'),
-    ('layer_norm', (S, S, S, S), ([5], None, (S,)), 'with_only_bias'),
-    ('layer_norm', (S, S, S, S), ([5], (S,), (S,)), 'with_weight_and_bias'),
+    ('layer_norm', (S, S, S, S), ([5],), '',
+     (True, ['prim::Loop', 'aten::_batch_norm_impl_index'])),
+    ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),), 'with_only_weight',
+     (True, ['prim::Loop', 'aten::_batch_norm_impl_index'])),
+    ('layer_norm', (S, S, S, S), ([5], None, non_differentiable(torch.rand(S)),), 'with_only_bias',
+     (True, ['prim::Loop', 'aten::_batch_norm_impl_index'])),
+    ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),
+                                  non_differentiable(torch.rand(S))), 'with_weight_and_bias',
+     (True, ['prim::Loop', 'aten::_batch_norm_impl_index'])),
     ('group_norm', (S, S, S), (1, torch.rand(5),),),
     ('local_response_norm', (S, S, S), (2, ),),
     ('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),), '', (True, 'aten::nll_loss_forward')),
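
Outside the harness, the numeric half of a case such as 'with_weight_and_bias' can be reproduced with a plain gradcheck; the (True, ['prim::Loop', 'aten::_batch_norm_impl_index']) part additionally asserts that those nodes appear in the differentiated graph. A minimal sketch, with the sizes picked here only for illustration:

    import torch
    import torch.nn.functional as F

    S = 5                                    # stand-in for the suite's default size
    x = torch.randn(S, S, S, S, dtype=torch.double, requires_grad=True)
    w = torch.rand(S, dtype=torch.double)    # non_differentiable in the test entry
    b = torch.rand(S, dtype=torch.double)

    # only the input is differentiated, mirroring the non_differentiable() wrappers
    torch.autograd.gradcheck(lambda inp: F.layer_norm(inp, [S], w, b), (x,))
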
torch/csrc/jit/symbolic_script.cpp
index c897d73..3cfcd7c 100644
@@ -660,6 +660,7 @@ const std::vector<std::string> functions = {
 
             return torch.adaptive_avg_pool3d(self, output_size), backward
 
+
         def batch_norm(input : Tensor,
                        weight : Optional[Tensor],
                        bias : Optional[Tensor],
@@ -685,21 +686,27 @@ const std::vector<std::string> functions = {
             return output, backward
 
         def layer_norm(input : Tensor,
-                       normalied_shape : List[int],
+                       normalized_shape : List[int],
                        weight : Optional[Tensor],
                        bias : Optional[Tensor],
                        eps : float,
                        cudnn_enable : bool):
 
+            input_ndim = input.dim()
+            normalized_ndim = len(normalized_shape)
+            n = 1
+            for i in range(input_ndim - normalized_ndim):
+                n *= input.size(i)
+
+            input_reshape = input.contiguous().view(1, n, -1)
+
             bn_out, save1, save2, impl_idx = torch._batch_norm_impl_index(
-                input, weight, bias, None, None, True,
+                input_reshape, None, None, None, None, True,
                 0.0, eps, cudnn_enable)
-            has_weight = weight is not None
-            has_bias = bias is not None
 
-            bn_out = bn_out.view(input.sizes())
+            bn_out = bn_out.view(input.size())
             if weight is not None and bias is not None:
-                output = bias.addcmul(bn_out, weight)
+                output = bias.addcmul(bn_out, weight, value=1)
             elif weight is not None:
                 output = bn_out.mul(weight)
             elif bias is not None:
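
The forward path above swaps the old call, which handed weight and bias straight to batch norm, for a stat-only batch norm over a (1, n, -1) view (n being the product of the leading, non-normalized dims), applying the affine parameters afterwards. A rough equivalence check of that reshaping trick, using the public functional API instead of the internal _batch_norm_impl_index:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 3, 4, 5)
    normalized_shape = [5]
    w, b = torch.rand(5), torch.rand(5)
    eps = 1e-5

    n = 1
    for d in x.shape[:x.dim() - len(normalized_shape)]:
        n *= d                               # product of the leading dims

    # one "channel" per layer-norm instance; stats are taken over the last view dim
    bn_out = F.batch_norm(x.contiguous().view(1, n, -1), None, None,
                          training=True, eps=eps).view(x.shape)
    print(torch.allclose(bn_out * w + b,
                         F.layer_norm(x, normalized_shape, w, b, eps=eps),
                         atol=1e-5))
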
@@ -708,16 +715,32 @@ const std::vector<std::string> functions = {
                 output = bn_out
 
             def backward(grad_output):
-                if weight is not None:
-                    grad_output = grad_output * torch.t(weight)
-                    weight = grad_output * torch.t(bn_out)
+                if weight is not None and bias is not None:
+                    grad_bn_out = grad_output * weight
+                    grad_weight = (grad_output * bn_out)._grad_sum_to_size(weight.size())
+                    grad_bias = grad_output._grad_sum_to_size(bias.size())
+                elif weight is not None:
+                    grad_bn_out = grad_output * weight
+                    grad_weight = (grad_output * bn_out)._grad_sum_to_size(weight.size())
+                    grad_bias = None
+                elif bias is not None:
+                    grad_bn_out = grad_output
+                    grad_weight = None
+                    grad_bias = grad_output._grad_sum_to_size(bias.size())
+                else:
+                    grad_bn_out = grad_output
+                    grad_weight = None
+                    grad_bias = None
 
-                grad_output = grad_output.reshape(input.sizes())
 
-                dinput, dweight, dbias = torch._batch_norm_impl_index_backward(
-                    impl_idx, input, grad_output, weight, None, None,
-                    save1, save2, True, eps, [True, has_weight, has_bias])
-                return dinput, None, dweight, dbias, None, None
+                grad_bn_out = grad_bn_out.contiguous().view(1, n, -1)
+
+                grad_input, _, _ = torch._batch_norm_impl_index_backward(
+                    impl_idx, input_reshape, grad_bn_out, None, None, None,
+                    save1, save2, True, eps, [True, False, False])
+
+                grad_input = grad_input.view(input.size())
+                return grad_input, None, grad_weight, grad_bias, None, None
 
             return output, backward
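
A hedged numeric check of the new backward (shapes and names assumed here, not taken from the suite): the parameter gradients should equal grad_output * x_hat and grad_output reduced down to the parameter shapes, matching autograd on the eager layer_norm:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 3, 4, 5, dtype=torch.double, requires_grad=True)
    w = torch.rand(5, dtype=torch.double, requires_grad=True)
    b = torch.rand(5, dtype=torch.double, requires_grad=True)
    eps = 1e-5

    out = F.layer_norm(x, [5], w, b, eps=eps)
    grad_out = torch.randn_like(out)
    gw, gb = torch.autograd.grad(out, (w, b), grad_out)

    # recompute the normalized activations without the affine part
    xd = x.detach()
    mu = xd.mean(dim=-1, keepdim=True)
    var = xd.var(dim=-1, unbiased=False, keepdim=True)
    x_hat = (xd - mu) / torch.sqrt(var + eps)

    # the reductions _grad_sum_to_size performs in the script, spelled out as sums
    print(torch.allclose(gw, (grad_out * x_hat).sum(dim=(0, 1, 2))),
          torch.allclose(gb, grad_out.sum(dim=(0, 1, 2))))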