Pool prim::None nodes (#15745)
authorElias Ellison <eellison@fb.com>
Mon, 7 Jan 2019 17:58:08 +0000 (09:58 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 7 Jan 2019 18:00:51 +0000 (10:00 -0800)
Summary:
Make the constant pooling pass pool prim::None nodes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15745

Differential Revision: D13583518

Pulled By: eellison

fbshipit-source-id: 7f8aa70522515805ab0991c6db3d96b5a96cdede

test/expect/TestScript.test_if_is_none_dispatch.expect
test/expect/TestScript.test_mutable_dce_graph_input.expect
test/test_jit.py
torch/csrc/jit/passes/constant_pooling.cpp

index 4109f5e..bc15fd3 100644 (file)
@@ -1,25 +1,24 @@
 graph(%input : Tensor
       %opt.1 : Tensor?) {
-  %2 : int = prim::Constant[value=1]()
-  %3 : int = prim::Constant[value=2]()
-  %4 : int = prim::Constant[value=4]()
-  %x.1 : Tensor = aten::add(%input, %3, %2)
-  %6 : None = prim::None()
-  %7 : bool = aten::__isnot__(%opt.1, %6)
+  %2 : None = prim::None()
+  %3 : int = prim::Constant[value=1]()
+  %4 : int = prim::Constant[value=2]()
+  %5 : int = prim::Constant[value=4]()
+  %x.1 : Tensor = aten::add(%input, %4, %3)
+  %7 : bool = aten::__isnot__(%opt.1, %2)
   %opt : Tensor?, %x.3 : Tensor = prim::If(%7)
     block0() {
       %opt.2 : Tensor = aten::_unwrap_optional(%opt.1)
-      %x.2 : Tensor = aten::add(%opt.2, %x.1, %2)
+      %x.2 : Tensor = aten::add(%opt.2, %x.1, %3)
       -> (%opt.2, %x.2)
     }
     block1() {
       -> (%opt.1, %x.1)
     }
-  %12 : None = prim::None()
-  %13 : bool = aten::__is__(%opt, %12)
-  %x : Tensor = prim::If(%13)
+  %12 : bool = aten::__is__(%opt, %2)
+  %x : Tensor = prim::If(%12)
     block0() {
-      %x.4 : Tensor = aten::add(%x.3, %4, %2)
+      %x.4 : Tensor = aten::add(%x.3, %5, %3)
       -> (%x.4)
     }
     block1() {
index 0ac2218..2ac2166 100644 (file)
@@ -1,13 +1,13 @@
 graph(%a.1 : Tensor) {
-  %1 : int = prim::Constant[value=1]()
-  %2 : Device = prim::Constant[value="cpu"]()
-  %3 : int = prim::Constant[value=0]()
-  %4 : int = prim::Constant[value=6]()
-  %5 : int = prim::Constant[value=2]()
-  %6 : int = prim::Constant[value=3]()
-  %7 : int[] = prim::ListConstruct(%5, %6)
-  %8 : Tensor = aten::rand(%7, %4, %3, %2)
-  %a : Tensor = aten::add_(%a.1, %8, %1)
-  %10 : None = prim::None()
-  return (%10);
+  %1 : None = prim::None()
+  %2 : int = prim::Constant[value=1]()
+  %3 : Device = prim::Constant[value="cpu"]()
+  %4 : int = prim::Constant[value=0]()
+  %5 : int = prim::Constant[value=6]()
+  %6 : int = prim::Constant[value=2]()
+  %7 : int = prim::Constant[value=3]()
+  %8 : int[] = prim::ListConstruct(%6, %7)
+  %9 : Tensor = aten::rand(%8, %5, %4, %3)
+  %a : Tensor = aten::add_(%a.1, %9, %2)
+  return (%1);
 }
index caf2427..6d5b640 100644 (file)
@@ -3339,6 +3339,25 @@ a")
         self.run_pass('constant_pooling', graph)
         self.assertExpectedGraph(graph)
 
+    def test_constant_pooling_none(self):
+        @torch.jit.script
+        def typed_nones(a=None, b=None, c=None):
+            # type: (Optional[int], Optional[bool], Optional[Tensor]) -> Tuple[Optional[int], Optional[bool], Optional[Tensor]] # noqa
+            return a, b, c
+
+        @torch.jit.script
+        def test(a):
+            # type: (bool) -> None
+            if a:
+                print(typed_nones())
+            else:
+                print(typed_nones())
+
+        graph_str = str(test.graph)
+        self.assertTrue(graph_str.count("bool? = prim::None") == 1)
+        self.assertTrue(graph_str.count("int? = prim::None") == 1)
+        self.assertTrue(graph_str.count("None = prim::None") == 1)
+
     def test_literal(self):
         def func1(a, b):
             c = a, b
index af6de16..907d074 100644 (file)
@@ -26,7 +26,7 @@ void ConstantPooling(
       continue;
     }
 
-    if (node->kind() != prim::Constant) {
+    if (node->kind() != prim::Constant && node->kind() != prim::None) {
       continue;
     }