Don't DCE PythonOp
authorAdam Paszke <adam.paszke@gmail.com>
Wed, 5 Dec 2018 05:35:48 +0000 (21:35 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 5 Dec 2018 05:37:36 +0000 (21:37 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/14773

Reviewed By: eellison

Differential Revision: D13327673

Pulled By: suo

fbshipit-source-id: 236db3407c7eacac470530836e3d4d0dc323110c

test/test_jit.py
torch/csrc/jit/passes/dead_code_elimination.cpp
torch/nn/modules/batchnorm.py

index 0821fa4..52d1a7d 100644 (file)
@@ -8513,15 +8513,6 @@ a")
             foo(torch.tensor(1))
 
         @torch.jit.script
-        def foo():
-            a = Exception()
-            raise a
-
-        # a gets DCEd because the expression following raise is ignored
-        with self.assertRaisesRegex(torch.jit.Error, "failed in interpreter"):
-            foo()
-
-        @torch.jit.script
         def foo_except_used():
             a = Exception()
             print(a)
index 637f39e..bd5d5a3 100644 (file)
@@ -177,14 +177,12 @@ class DeadCodeEliminator {
   }
 
   bool hasSideEffects(Node* node) {
-    // FIXME: PythonOp should be treated as having side effects as well!
-    //        Unfortunately ONNX depends on it getting removed in this pass, so
-    //        it's not a simple change.
     auto it = memo_.find(node);
     if (it != memo_.end())
       return it->second;
     bool has_side_effects = node->kind() == prim::Print ||
         node->kind() == prim::RaiseException ||
+        node->kind() == prim::PythonOp ||
         std::any_of(node->blocks().begin(),
                     node->blocks().end(),
                     [&](Block* b) {
index e6dbc21..fbd40be 100644 (file)
@@ -162,6 +162,7 @@ class BatchNorm1d(_BatchNorm):
         https://arxiv.org/abs/1502.03167
     """
 
+    @weak_script_method
     def _check_input_dim(self, input):
         if input.dim() != 2 and input.dim() != 3:
             raise ValueError('expected 2D or 3D input (got {}D input)'
@@ -235,6 +236,7 @@ class BatchNorm2d(_BatchNorm):
         https://arxiv.org/abs/1502.03167
     """
 
+    @weak_script_method
     def _check_input_dim(self, input):
         if input.dim() != 4:
             raise ValueError('expected 4D input (got {}D input)'
@@ -309,6 +311,7 @@ class BatchNorm3d(_BatchNorm):
         https://arxiv.org/abs/1502.03167
     """
 
+    @weak_script_method
     def _check_input_dim(self, input):
         if input.dim() != 5:
             raise ValueError('expected 5D input (got {}D input)'