From: Richard Zou Date: Wed, 12 Dec 2018 19:32:05 +0000 (-0800) Subject: Stop erroneously running aten::warn (#15124) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~2293 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=b34ab435efd9b4d839171d27bd6a0f321178d46b;p=platform%2Fupstream%2Fpytorch.git Stop erroneously running aten::warn (#15124) Summary: Fixes #15119. Before this PR, we were propagating constants through aten::warn AND running it as a part of shape analysis. This caused aten::warn to be run regardless of if it is supposed to be run dynamically. This PR adds an exclusion for aten::warn in constant propagation and shape analysis, similar to that of prim::RaiseException. Pull Request resolved: https://github.com/pytorch/pytorch/pull/15124 Differential Revision: D13432815 Pulled By: zou3519 fbshipit-source-id: 15ab533ce2accb2da3fd4e569070c7979ce61708 --- diff --git a/test/test_jit.py b/test/test_jit.py index e6696cb..0000e93 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -2053,6 +2053,21 @@ class TestJit(JitTestCase): self.assertExpectedGraph(fn.graph) + def test_no_erroneous_warnings(self): + import warnings + + def fn(x): + if bool(x > 0): + warnings.warn('This should NOT be printed') + x += 1 + return x + + with warnings.catch_warnings(record=True) as warns: + fn_script = torch.jit.script(fn) + fn_script(torch.tensor(0)) + warns = [str(w.message) for w in warns] + self.assertEqual(len(warns), 0) + class TestBatched(TestCase): # generate random examples and create an batchtensor with them diff --git a/torch/csrc/jit/passes/constant_propagation.cpp b/torch/csrc/jit/passes/constant_propagation.cpp index 967be5d..a1a6c1a 100644 --- a/torch/csrc/jit/passes/constant_propagation.cpp +++ b/torch/csrc/jit/passes/constant_propagation.cpp @@ -18,6 +18,7 @@ std::unordered_set skip_list = { prim::Loop, //TODO: handle Loop prim::Print, prim::RaiseException, + aten::warn, prim::PythonOp, //may have side effects prim::Constant, prim::Undefined, diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index bdcf47e..cedc221 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -437,6 +437,7 @@ class ShapePropagator { case prim::PythonOp: case prim::Print: case prim::RaiseException: + case aten::warn: case prim::Undefined: { setUnshapedType(node); return;