Stop erroneously running aten::warn (#15124)
authorRichard Zou <zou3519@gmail.com>
Wed, 12 Dec 2018 19:32:05 +0000 (11:32 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 12 Dec 2018 19:35:23 +0000 (11:35 -0800)
Summary:
Fixes #15119. Before this PR, we were propagating constants through
aten::warn AND running it as a part of shape analysis.
This caused aten::warn to be run regardless of if it is
supposed to be run dynamically. This PR adds an exclusion for aten::warn
in constant propagation and shape analysis, similar to that of prim::RaiseException.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15124

Differential Revision: D13432815

Pulled By: zou3519

fbshipit-source-id: 15ab533ce2accb2da3fd4e569070c7979ce61708

test/test_jit.py
torch/csrc/jit/passes/constant_propagation.cpp
torch/csrc/jit/passes/shape_analysis.cpp

index e6696cb..0000e93 100644 (file)
@@ -2053,6 +2053,21 @@ class TestJit(JitTestCase):
 
         self.assertExpectedGraph(fn.graph)
 
+    def test_no_erroneous_warnings(self):
+        import warnings
+
+        def fn(x):
+            if bool(x > 0):
+                warnings.warn('This should NOT be printed')
+                x += 1
+            return x
+
+        with warnings.catch_warnings(record=True) as warns:
+            fn_script = torch.jit.script(fn)
+            fn_script(torch.tensor(0))
+        warns = [str(w.message) for w in warns]
+        self.assertEqual(len(warns), 0)
+
 
 class TestBatched(TestCase):
     # generate random examples and create an batchtensor with them
index 967be5d..a1a6c1a 100644 (file)
@@ -18,6 +18,7 @@ std::unordered_set<Symbol> skip_list = {
   prim::Loop, //TODO: handle Loop
   prim::Print,
   prim::RaiseException,
+  aten::warn,
   prim::PythonOp, //may have side effects
   prim::Constant,
   prim::Undefined,
index bdcf47e..cedc221 100644 (file)
@@ -437,6 +437,7 @@ class ShapePropagator {
       case prim::PythonOp:
       case prim::Print:
       case prim::RaiseException:
+      case aten::warn:
       case prim::Undefined: {
         setUnshapedType(node);
         return;