Constant prop prim::None (#15979)
authorElias Ellison <eellison@fb.com>
Tue, 15 Jan 2019 18:56:17 +0000 (10:56 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 15 Jan 2019 19:34:51 +0000 (11:34 -0800)
Summary:
Previously we were only constant propping prim::Constants, but we should be constant propping prim::None as well.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15979

Differential Revision: D13664692

Pulled By: eellison

fbshipit-source-id: 01839403576c21fc030c427e49275b8e1210fa8f

test/test_jit.py
torch/csrc/jit/passes/constant_propagation.cpp
torch/csrc/jit/register_prim_ops.cpp
torch/nn/functional.py

index 22e0d4d..43a7407 100644 (file)
@@ -1699,6 +1699,26 @@ class TestJit(JitTestCase):
         self.run_pass('constant_propagation', constant_prop.graph)
         self.assertExpected(canonical(constant_prop.graph))
 
+    def test_constant_prop_none(self):
+        @torch.jit.script
+        def typed_none():
+            # type: () -> Optional[int]
+            return None
+
+        @torch.jit.script
+        def constant_prop():
+            a = typed_none()
+            b = typed_none()
+            if (a is None and b is None):
+                a = 2
+            else:
+                a = 1
+            return a
+
+        self.run_pass('constant_propagation', constant_prop.graph)
+        graph_str = str(constant_prop.graph)
+        self.assertTrue(graph_str.count("prim::None") == 0)
+
     def test_trace_records_names(self):
         def foo(bar, baz):
             baz = bar + 3
index b8d6f2b..2665982 100644 (file)
@@ -29,7 +29,11 @@ std::vector<IValue> runNode(Node* n) {
   auto op = getOperation(n);
   Stack stack;
   for (auto input : n->inputs()) {
-    stack.push_back(*(toIValue(input)));
+    if (input->node()->kind() == prim::None) {
+      stack.emplace_back(IValue());
+    } else {
+      stack.push_back(*(toIValue(input)));
+    }
   }
   op(stack);
   auto var_outputs = fmap(stack, [&](IValue v) -> IValue {
@@ -48,7 +52,14 @@ std::vector<IValue> runNode(Node* n) {
 }
 
 void propagateNode(Node* n) {
-  auto outputs = runNode(n);
+  std::vector<IValue> outputs;
+  try {
+    outputs = runNode(n);
+  } catch (const c10::Error& e) {
+    // catch AT_ASSERT errors. This op may not be run reached,
+    // so catch the error here & leave the op in the graph
+    return;
+  }
   auto graph = n->owningGraph();
   WithInsertPoint guard(n);
   for (size_t i = 0; i < outputs.size(); ++i) {
@@ -119,7 +130,7 @@ void ConstantPropagation(Block* block, const AliasDb& aliasDb, bool recurse);
 void ConstantPropagation(Node* n, const AliasDb& aliasDb, bool recurse) {
   bool constant_inputs =
       std::all_of(n->inputs().begin(), n->inputs().end(), [&](Value* v) {
-        return v->node()->kind() == prim::Constant;
+        return v->node()->kind() == prim::Constant || v->node()->kind() == prim::None;
       });
   bool supported_node = !n->kind().is_onnx() &&
       skip_list.count(n->kind()) == 0 && !n->isNondeterministic() &&
index ea6e5c6..54ca90d 100644 (file)
@@ -782,7 +782,7 @@ RegisterOperators reg({
         [](const Node* node) -> Operation {
           return [=](Stack& stack) {
             auto val = pop(stack);
-            JIT_ASSERTM(!val.isNone(), "Unwrapping null optional");
+            AT_CHECK(!val.isNone(), "Unwrapping null optional");
             push(stack, val);
             return 0;
           };
index dfa47a1..9f1494c 100644 (file)
@@ -1855,9 +1855,9 @@ def poisson_nll_loss(input, target, log_input=True, full=False, size_average=Non
     if full:
         mask = target > 1
         loss[mask] += (target * torch.log(target) - target + 0.5 * torch.log(2 * math.pi * target))[mask]
-    if reduction is 'none':
+    if reduction == 'none':
         ret = loss
-    if reduction is 'mean':
+    if reduction == 'mean':
         ret = torch.mean(loss)
     else:
         ret = torch.sum(loss)