From 7d601715e5ca21cd15319d06a2f52aca77e60f9b Mon Sep 17 00:00:00 2001
From: Elias Ellison
Date: Tue, 15 Jan 2019 10:56:17 -0800
Subject: [PATCH] Constant prop prim::None (#15979)

Summary:
Previously we were only constant propping prim::Constants, but we should be constant propping prim::None as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/15979

Differential Revision: D13664692

Pulled By: eellison

fbshipit-source-id: 01839403576c21fc030c427e49275b8e1210fa8f
---
 test/test_jit.py                               | 20 ++++++++++++++++++++
 torch/csrc/jit/passes/constant_propagation.cpp | 17 ++++++++++++++---
 torch/csrc/jit/register_prim_ops.cpp           |  2 +-
 torch/nn/functional.py                         |  4 ++--
 4 files changed, 37 insertions(+), 6 deletions(-)

diff --git a/test/test_jit.py b/test/test_jit.py
index 22e0d4d..43a7407 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1699,6 +1699,26 @@ class TestJit(JitTestCase):
         self.run_pass('constant_propagation', constant_prop.graph)
         self.assertExpected(canonical(constant_prop.graph))
 
+    def test_constant_prop_none(self):
+        @torch.jit.script
+        def typed_none():
+            # type: () -> Optional[int]
+            return None
+
+        @torch.jit.script
+        def constant_prop():
+            a = typed_none()
+            b = typed_none()
+            if (a is None and b is None):
+                a = 2
+            else:
+                a = 1
+            return a
+
+        self.run_pass('constant_propagation', constant_prop.graph)
+        graph_str = str(constant_prop.graph)
+        self.assertTrue(graph_str.count("prim::None") == 0)
+
     def test_trace_records_names(self):
         def foo(bar, baz):
             baz = bar + 3
diff --git a/torch/csrc/jit/passes/constant_propagation.cpp b/torch/csrc/jit/passes/constant_propagation.cpp
index b8d6f2b..2665982 100644
--- a/torch/csrc/jit/passes/constant_propagation.cpp
+++ b/torch/csrc/jit/passes/constant_propagation.cpp
@@ -29,7 +29,11 @@ std::vector<IValue> runNode(Node* n) {
   auto op = getOperation(n);
   Stack stack;
   for (auto input : n->inputs()) {
-    stack.push_back(*(toIValue(input)));
+    if (input->node()->kind() == prim::None) {
+      stack.emplace_back(IValue());
+    } else {
+      stack.push_back(*(toIValue(input)));
+    }
   }
   op(stack);
   auto var_outputs = fmap(stack, [&](IValue v) -> IValue {
@@ -48,7 +52,14 @@ std::vector<IValue> runNode(Node* n) {
 }
 
 void propagateNode(Node* n) {
-  auto outputs = runNode(n);
+  std::vector<IValue> outputs;
+  try {
+    outputs = runNode(n);
+  } catch (const c10::Error& e) {
+    // Catch AT_ASSERT errors. This op may never be reached,
+    // so catch the error here & leave the op in the graph.
+    return;
+  }
   auto graph = n->owningGraph();
   WithInsertPoint guard(n);
   for (size_t i = 0; i < outputs.size(); ++i) {
@@ -119,7 +130,7 @@ void ConstantPropagation(Block* block, const AliasDb& aliasDb, bool recurse);
 void ConstantPropagation(Node* n, const AliasDb& aliasDb, bool recurse) {
   bool constant_inputs =
       std::all_of(n->inputs().begin(), n->inputs().end(), [&](Value* v) {
-        return v->node()->kind() == prim::Constant;
+        return v->node()->kind() == prim::Constant || v->node()->kind() == prim::None;
       });
   bool supported_node = !n->kind().is_onnx() &&
       skip_list.count(n->kind()) == 0 && !n->isNondeterministic() &&
diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp
index ea6e5c6..54ca90d 100644
--- a/torch/csrc/jit/register_prim_ops.cpp
+++ b/torch/csrc/jit/register_prim_ops.cpp
@@ -782,7 +782,7 @@ RegisterOperators reg({
     [](const Node* node) -> Operation {
       return [=](Stack& stack) {
         auto val = pop(stack);
-        JIT_ASSERTM(!val.isNone(), "Unwrapping null optional");
+        AT_CHECK(!val.isNone(), "Unwrapping null optional");
         push(stack, val);
         return 0;
       };
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index dfa47a1..9f1494c 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -1855,9 +1855,9 @@ def poisson_nll_loss(input, target, log_input=True, full=False, size_average=Non
     if full:
         mask = target > 1
         loss[mask] += (target * torch.log(target) - target + 0.5 * torch.log(2 * math.pi * target))[mask]
-    if reduction is 'none':
+    if reduction == 'none':
         ret = loss
-    if reduction is 'mean':
+    if reduction == 'mean':
         ret = torch.mean(loss)
     else:
         ret = torch.sum(loss)
-- 
2.7.4
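
The behavior the patch enables can be observed outside the test harness with the
sketch below. It is a minimal sketch, not part of the patch: it assumes a PyTorch
build contemporary with this commit, where prim::None is still a distinct node
kind and script function calls are inlined at compile time, and it drives the
pass through the internal torch._C._jit_pass_constant_propagation binding, which
is what run_pass('constant_propagation', ...) resolves to in JitTestCase.

    import torch
    from typing import Optional

    @torch.jit.script
    def typed_none():
        # type: () -> Optional[int]
        return None

    @torch.jit.script
    def fn():
        a = typed_none()
        if a is None:
            r = 2
        else:
            r = 1
        return r

    # typed_none() is inlined into fn's graph, so the None it returns appears
    # as a prim::None node feeding the "a is None" check. With this patch,
    # constant propagation treats prim::None like prim::Constant: it folds the
    # check to True, collapses the prim::If, and the prim::None nodes drop out.
    torch._C._jit_pass_constant_propagation(fn.graph)
    assert str(fn.graph).count("prim::None") == 0
    print(fn.graph)

After the pass, the graph should reduce to returning a constant, which is the
same condition the new test_constant_prop_none test asserts.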