From c79e305addc4f5494f29826da8c34768fa80f942 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 4 Dec 2018 21:35:48 -0800 Subject: [PATCH] Don't DCE PythonOp Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/14773 Reviewed By: eellison Differential Revision: D13327673 Pulled By: suo fbshipit-source-id: 236db3407c7eacac470530836e3d4d0dc323110c --- test/test_jit.py | 9 --------- torch/csrc/jit/passes/dead_code_elimination.cpp | 4 +--- torch/nn/modules/batchnorm.py | 3 +++ 3 files changed, 4 insertions(+), 12 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 0821fa4..52d1a7d 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -8513,15 +8513,6 @@ a") foo(torch.tensor(1)) @torch.jit.script - def foo(): - a = Exception() - raise a - - # a gets DCEd because the expression following raise is ignored - with self.assertRaisesRegex(torch.jit.Error, "failed in interpreter"): - foo() - - @torch.jit.script def foo_except_used(): a = Exception() print(a) diff --git a/torch/csrc/jit/passes/dead_code_elimination.cpp b/torch/csrc/jit/passes/dead_code_elimination.cpp index 637f39e..bd5d5a3 100644 --- a/torch/csrc/jit/passes/dead_code_elimination.cpp +++ b/torch/csrc/jit/passes/dead_code_elimination.cpp @@ -177,14 +177,12 @@ class DeadCodeEliminator { } bool hasSideEffects(Node* node) { - // FIXME: PythonOp should be treated as having side effects as well! - // Unfortunately ONNX depends on it getting removed in this pass, so - // it's not a simple change. auto it = memo_.find(node); if (it != memo_.end()) return it->second; bool has_side_effects = node->kind() == prim::Print || node->kind() == prim::RaiseException || + node->kind() == prim::PythonOp || std::any_of(node->blocks().begin(), node->blocks().end(), [&](Block* b) { diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py index e6dbc21..fbd40be 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -162,6 +162,7 @@ class BatchNorm1d(_BatchNorm): https://arxiv.org/abs/1502.03167 """ + @weak_script_method def _check_input_dim(self, input): if input.dim() != 2 and input.dim() != 3: raise ValueError('expected 2D or 3D input (got {}D input)' @@ -235,6 +236,7 @@ class BatchNorm2d(_BatchNorm): https://arxiv.org/abs/1502.03167 """ + @weak_script_method def _check_input_dim(self, input): if input.dim() != 4: raise ValueError('expected 4D input (got {}D input)' @@ -309,6 +311,7 @@ class BatchNorm3d(_BatchNorm): https://arxiv.org/abs/1502.03167 """ + @weak_script_method def _check_input_dim(self, input): if input.dim() != 5: raise ValueError('expected 5D input (got {}D input)' -- 2.7.4