more test coverage on emitIf none dispatch (#16794)
authorWanchao Liang <wanchaol@users.noreply.github.com>
Fri, 15 Feb 2019 05:37:08 +0000 (21:37 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 15 Feb 2019 05:39:55 +0000 (21:39 -0800)
Summary:
Follow-up of #14533: add more test coverage for emitIf metaprogramming conditions. Also delete some unwrap-optional usage.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16794

Differential Revision: D14096868

Pulled By: wanchaol

fbshipit-source-id: ee1cec609c58d0dd65211249a90207be06649e71

test/expect/TestScript.test_if_is_none_dispatch.expect [deleted file]
test/test_jit.py

diff --git a/test/expect/TestScript.test_if_is_none_dispatch.expect b/test/expect/TestScript.test_if_is_none_dispatch.expect
deleted file mode 100644 (file)
index f9d9fd4..0000000
+++ /dev/null
@@ -1,24 +0,0 @@
-graph(%input : Tensor,
-      %opt.1 : Tensor?):
-  %2 : None = prim::None()
-  %3 : int = prim::Constant[value=1]()
-  %4 : int = prim::Constant[value=2]()
-  %5 : int = prim::Constant[value=4]()
-  %x.1 : Tensor = aten::add(%input, %4, %3)
-  %7 : bool = aten::__isnot__(%opt.1, %2)
-  %opt.4 : Tensor?, %x.3 : Tensor = prim::If(%7)
-    block0():
-      %opt.2 : Tensor = prim::unchecked_unwrap_optional(%opt.1)
-      %opt.3 : Tensor = aten::_unwrap_optional(%opt.2)
-      %x.2 : Tensor = aten::add(%opt.3, %x.1, %3)
-      -> (%opt.3, %x.2)
-    block1():
-      -> (%opt.1, %x.1)
-  %13 : bool = aten::__is__(%opt.4, %2)
-  %x : Tensor = prim::If(%13)
-    block0():
-      %x.4 : Tensor = aten::add(%x.3, %5, %3)
-      -> (%x.4)
-    block1():
-      -> (%x.3)
-  return (%x)
index 344ad68..26ec716 100644 (file)
@@ -4203,36 +4203,71 @@ a")
         self.checkScript(func, inputs, optimize=True)
 
     def test_if_is_none_dispatch(self):
-        class Test(torch.jit.ScriptModule):
-            __constants__ = ['b']
 
-            def __init__(self, b=None):
-                super(Test, self).__init__()
-                self.b = b
+        @torch.jit.script
+        def test_lhs_none_rhs_none():
+            # LHS, RHS both alwaysNone, dispatch always_none_branch
+            # only emit one prim::Constant
+            if None is None:
+                return 1
+            elif None is not None:
+                return 2
+            else:
+                return 3
 
-            @torch.jit.script_method
-            def forward(self, input, opt=None):
-                # type: (Tensor, Optional[Tensor]) -> Tensor
-                x = input
-                if self.b is not None:
-                    x = self.b(input)
+        self.assertTrue(str(test_lhs_none_rhs_none.graph).count(': int = prim::Constant') == 1)
 
-                if self.b is None:
-                    x = input + 2
+        @torch.jit.script
+        def test_lhs_opt_rhs_none(lhs=None):
+            # type: (Optional[Tensor]) -> int
+            # LHS maybeNone: emit normal if stmt that contains 3 constants
+            if lhs is not None:
+                return 2
+            elif lhs is None:
+                return 1
+            else:
+                return 3
 
-                if opt is not None:
-                    opt = torch.jit._unwrap_optional(opt)
-                    x = opt + x
+        self.assertTrue(str(test_lhs_opt_rhs_none.graph).count(': int = prim::Constant') == 3)
 
-                if opt is None:
-                    x = x + 4
+        @torch.jit.script
+        def test_lhs_none_rhs_opt(rhs=None):
+            # type: (Optional[Tensor]) -> int
+            # RHS maybeNone, emit normal if stmt that contains 3 constants
+            if None is rhs:
+                return 1
+            elif None is not rhs:
+                return 2
+            else:
+                return 3
 
-                return x
+        self.assertTrue(str(test_lhs_none_rhs_opt.graph).count(': int = prim::Constant') == 3)
+
+        @torch.jit.script
+        def test_lhs_never_rhs_none(lhs):
+            # LHS neverNone, RHS alwaysNone dispatch never_none_branch
+            # only emit one prim::Constant
+            if lhs is None:
+                return 1
+            elif lhs is not None:
+                return 2
+            else:
+                return 3
+
+        self.assertTrue(str(test_lhs_never_rhs_none.graph).count(': int = prim::Constant') == 1)
+
+        @torch.jit.script
+        def test_lhs_none_rhs_never(rhs):
+            # LHS alwaysNone, RHS neverNone dispatch never_none_branch
+            # only emit one prim::Constant
+            if None is rhs:
+                return 1
+            elif None is not rhs:
+                return 2
+            else:
+                return 3
 
-        inputs = torch.zeros(1, 2)
-        self.assertExpectedGraph(Test().graph)
-        out = Test()(inputs)
-        self.assertEqual(out, inputs + 6)
+        self.assertTrue(str(test_lhs_none_rhs_never.graph).count(': int = prim::Constant') == 1)
 
     def test_explicit_bool_cast(self):
         with self.assertRaisesRegex(RuntimeError, "expected a boolean"):
@@ -4507,7 +4542,7 @@ a")
             if x is None:
                 res = res + 1
             else:
-                res = torch.jit._unwrap_optional(x)
+                res = x
             return res
 
         fn = test_script_optional_tensor_none
@@ -4522,7 +4557,7 @@ a")
             if x is None:
                 res = res + 1.0
             else:
-                res = torch.jit._unwrap_optional(x)
+                res = x
             return res
 
         fn = test_script_optional_other_none