From: Adam Paszke Date: Wed, 5 Dec 2018 05:35:48 +0000 (-0800) Subject: Don't DCE PythonOp X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~2464 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c79e305addc4f5494f29826da8c34768fa80f942;p=platform%2Fupstream%2Fpytorch.git Don't DCE PythonOp Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/14773 Reviewed By: eellison Differential Revision: D13327673 Pulled By: suo fbshipit-source-id: 236db3407c7eacac470530836e3d4d0dc323110c --- diff --git a/test/test_jit.py b/test/test_jit.py index 0821fa4..52d1a7d 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -8513,15 +8513,6 @@ a") foo(torch.tensor(1)) @torch.jit.script - def foo(): - a = Exception() - raise a - - # a gets DCEd because the expression following raise is ignored - with self.assertRaisesRegex(torch.jit.Error, "failed in interpreter"): - foo() - - @torch.jit.script def foo_except_used(): a = Exception() print(a) diff --git a/torch/csrc/jit/passes/dead_code_elimination.cpp b/torch/csrc/jit/passes/dead_code_elimination.cpp index 637f39e..bd5d5a3 100644 --- a/torch/csrc/jit/passes/dead_code_elimination.cpp +++ b/torch/csrc/jit/passes/dead_code_elimination.cpp @@ -177,14 +177,12 @@ class DeadCodeEliminator { } bool hasSideEffects(Node* node) { - // FIXME: PythonOp should be treated as having side effects as well! - // Unfortunately ONNX depends on it getting removed in this pass, so - // it's not a simple change. auto it = memo_.find(node); if (it != memo_.end()) return it->second; bool has_side_effects = node->kind() == prim::Print || node->kind() == prim::RaiseException || + node->kind() == prim::PythonOp || std::any_of(node->blocks().begin(), node->blocks().end(), [&](Block* b) { diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py index e6dbc21..fbd40be 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -162,6 +162,7 @@ class BatchNorm1d(_BatchNorm): https://arxiv.org/abs/1502.03167 """ + @weak_script_method def _check_input_dim(self, input): if input.dim() != 2 and input.dim() != 3: raise ValueError('expected 2D or 3D input (got {}D input)' @@ -235,6 +236,7 @@ class BatchNorm2d(_BatchNorm): https://arxiv.org/abs/1502.03167 """ + @weak_script_method def _check_input_dim(self, input): if input.dim() != 4: raise ValueError('expected 4D input (got {}D input)' @@ -309,6 +311,7 @@ class BatchNorm3d(_BatchNorm): https://arxiv.org/abs/1502.03167 """ + @weak_script_method def _check_input_dim(self, input): if input.dim() != 5: raise ValueError('expected 5D input (got {}D input)'