Turn off layer norm in jit symbolic differentiation (#63816)
author     Xiaodong Wang <xdwang@fb.com>
           Tue, 24 Aug 2021 22:45:59 +0000 (15:45 -0700)
committer  Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
           Tue, 24 Aug 2021 22:47:13 +0000 (15:47 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63816

Test Plan:
Confirmed that this change recovers the NE (normalized entropy):

https://www.internalfb.com/mast/job/torchx_xdwang-SparseNNApplication_72cf593d

Reviewed By: ngimel

Differential Revision: D30498746

fbshipit-source-id: 4a387f32ee2f70685de6104459c7f21bfbddc187
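
For context, symbolic_script.cpp holds TorchScript source strings that give the JIT hand-written forward/backward pairs; the autodiff pass matches them by function name against an op's schema when building DifferentiableGraph nodes. A minimal sketch of the pattern (illustrative names, not taken from this PR):

    # Illustrative sketch of the symbolic-script pattern compiled from
    # symbolic_script.cpp: the forward returns the output together with a
    # backward closure that computes the input gradients.
    def scale_by_two(input):
        output = input * 2

        def backward(grad_output):
            grad_input = grad_output * 2
            return grad_input

        return output, backward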

torch/csrc/jit/runtime/symbolic_script.cpp
torch/testing/_internal/jit_metaprogramming_utils.py

diff --git a/torch/csrc/jit/runtime/symbolic_script.cpp b/torch/csrc/jit/runtime/symbolic_script.cpp
index 29ce74a..6f2acca 100644
@@ -1141,7 +1141,7 @@ const std::vector<std::string> functions = {
 
             return output, backward
 
-        def layer_norm(input : Tensor,
+        def layer_norm_disabled(input : Tensor,
                        normalized_shape : List[int],
                        weight : Optional[Tensor],
                        bias : Optional[Tensor],
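
Because the lookup is keyed on the function name, renaming `layer_norm` to `layer_norm_disabled` means `aten::layer_norm` no longer matches any symbolic script, and the JIT falls back to the regular autograd path for the op. One hypothetical way to observe the fallback (not part of this PR; assumes the profiling executor is active):

    # Hypothetical check: script a layer_norm call, run forward/backward a
    # few times so the profiling executor differentiates the graph, then
    # inspect which ops the optimized graph contains.
    import torch
    import torch.nn.functional as F

    @torch.jit.script
    def ln(x: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(x, [5])

    x = torch.randn(4, 5, requires_grad=True)
    for _ in range(3):  # warm-up so the executor profiles and differentiates
        ln(x).sum().backward()
        x.grad = None

    # With the symbolic script disabled, aten::native_layer_norm should no
    # longer appear as an autodiff-supported node in a DifferentiableGraph.
    print(torch.jit.last_executed_optimized_graph())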
diff --git a/torch/testing/_internal/jit_metaprogramming_utils.py b/torch/testing/_internal/jit_metaprogramming_utils.py
index 350866c..75b1615 100644
@@ -144,14 +144,14 @@ nn_functional_tests = [
         'with_only_weight_inference', (True, 'aten::_batch_norm_impl_index')),
     ('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
     ('layer_norm', (S, S, S, S), ([5],), '',
-     (True, ['aten::native_layer_norm'])),
+     (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
     ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),), 'with_only_weight',
-     (True, ['aten::native_layer_norm'])),
+     (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
     ('layer_norm', (S, S, S, S), ([5], None, non_differentiable(torch.rand(S)),), 'with_only_bias',
-     (True, ['aten::native_layer_norm'])),
+     (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
     ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),
                                   non_differentiable(torch.rand(S))), 'with_weight_and_bias',
-     (True, ['aten::native_layer_norm'])),
+     (False, ['aten::contiguous', 'aten::_batch_norm_impl_index', 'aten::addcmul'])),
     ('group_norm', (S, S, S), (1, torch.rand(5),),),
     ('local_response_norm', (S, S, S), (2, ),),
     ('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),), '',),
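
For reference, each `nn_functional_tests` entry ends with an autodiff-check tuple of the form `(should_autodiff_node, nonfusible_nodes)`. Flipping the first element to `False` records that `layer_norm` no longer has a symbolic gradient, and the node lists change to the ops the fallback path is expected to emit instead. An annotated sketch of the base `layer_norm` entry above (the comments are descriptive glosses, not names from jit_metaprogramming_utils.py):

    # Annotated layout of one nn_functional_tests entry.
    ('layer_norm',                       # torch.nn.functional op under test
     (S, S, S, S),                       # shape of the input tensor
     ([5],),                             # remaining positional args: normalized_shape
     '',                                 # variant-name suffix ('' for the base test)
     (False,                             # expect no symbolic autodiff for this op
      ['aten::contiguous',               # nonfusible nodes expected in the graph
       'aten::_batch_norm_impl_index'])),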