Summary:
Fixes #15119. Before this PR, we were propagating constants through
aten::warn AND running it as a part of shape analysis.
This caused aten::warn to be run regardless of whether it is
supposed to be run dynamically. This PR adds an exclusion for aten::warn
in constant propagation and shape analysis, similar to that of prim::RaiseException.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15124
Differential Revision: D13432815
Pulled By: zou3519
fbshipit-source-id: 15ab533ce2accb2da3fd4e569070c7979ce61708
self.assertExpectedGraph(fn.graph)
+ def test_no_erroneous_warnings(self):
+     """A warnings.warn() guarded by a false runtime condition must not fire.
+
+     Regression test for the bug described above: constant propagation and
+     shape analysis used to execute aten::warn at compile time, so the
+     warning was emitted regardless of whether its guard held at runtime.
+     """
+     import warnings
+
+     def fn(x):
+         # The warn is only reachable when x > 0; with x == 0 it must not run.
+         if bool(x > 0):
+             warnings.warn('This should NOT be printed')
+             x += 1
+         return x
+
+     # Record every warning emitted both while scripting fn (where the old
+     # bug fired the warn during compilation) and while running it.
+     with warnings.catch_warnings(record=True) as warns:
+         fn_script = torch.jit.script(fn)
+         fn_script(torch.tensor(0))
+     warns = [str(w.message) for w in warns]
+     self.assertEqual(len(warns), 0)
+
class TestBatched(TestCase):
# generate random examples and create a batchtensor with them
prim::Loop, //TODO: handle Loop
prim::Print,
prim::RaiseException,
+ aten::warn,
prim::PythonOp, //may have side effects
prim::Constant,
prim::Undefined,
case prim::PythonOp:
case prim::Print:
case prim::RaiseException:
+ case aten::warn:
case prim::Undefined: {
setUnshapedType(node);
return;