From: Wanchao Liang Date: Fri, 15 Feb 2019 05:37:08 +0000 (-0800) Subject: more test coverage on emitIf none dispatch (#16794) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~1270 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5f6ecd14c4aadb195cfa617c4dbca21dcb504594;p=platform%2Fupstream%2Fpytorch.git more test coverage on emitIf none dispatch (#16794) Summary: Follow up of #14533, add more test coverage for emitif metaprogramming conditions. Also delete some unwrap optional usage. Pull Request resolved: https://github.com/pytorch/pytorch/pull/16794 Differential Revision: D14096868 Pulled By: wanchaol fbshipit-source-id: ee1cec609c58d0dd65211249a90207be06649e71 --- diff --git a/test/expect/TestScript.test_if_is_none_dispatch.expect b/test/expect/TestScript.test_if_is_none_dispatch.expect deleted file mode 100644 index f9d9fd4..0000000 --- a/test/expect/TestScript.test_if_is_none_dispatch.expect +++ /dev/null @@ -1,24 +0,0 @@ -graph(%input : Tensor, - %opt.1 : Tensor?): - %2 : None = prim::None() - %3 : int = prim::Constant[value=1]() - %4 : int = prim::Constant[value=2]() - %5 : int = prim::Constant[value=4]() - %x.1 : Tensor = aten::add(%input, %4, %3) - %7 : bool = aten::__isnot__(%opt.1, %2) - %opt.4 : Tensor?, %x.3 : Tensor = prim::If(%7) - block0(): - %opt.2 : Tensor = prim::unchecked_unwrap_optional(%opt.1) - %opt.3 : Tensor = aten::_unwrap_optional(%opt.2) - %x.2 : Tensor = aten::add(%opt.3, %x.1, %3) - -> (%opt.3, %x.2) - block1(): - -> (%opt.1, %x.1) - %13 : bool = aten::__is__(%opt.4, %2) - %x : Tensor = prim::If(%13) - block0(): - %x.4 : Tensor = aten::add(%x.3, %5, %3) - -> (%x.4) - block1(): - -> (%x.3) - return (%x) diff --git a/test/test_jit.py b/test/test_jit.py index 344ad68..26ec716 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -4203,36 +4203,71 @@ a") self.checkScript(func, inputs, optimize=True) def test_if_is_none_dispatch(self): - class Test(torch.jit.ScriptModule): - __constants__ = ['b'] - def __init__(self, b=None): - super(Test, self).__init__() - self.b = b + @torch.jit.script + def test_lhs_none_rhs_none(): + # LHS, RHS both alwaysNone, dispatch always_none_branch + # only emit one prim::Constant + if None is None: + return 1 + elif None is not None: + return 2 + else: + return 3 - @torch.jit.script_method - def forward(self, input, opt=None): - # type: (Tensor, Optional[Tensor]) -> Tensor - x = input - if self.b is not None: - x = self.b(input) + self.assertTrue(str(test_lhs_none_rhs_none.graph).count(': int = prim::Constant') == 1) - if self.b is None: - x = input + 2 + @torch.jit.script + def test_lhs_opt_rhs_none(lhs=None): + # type: (Optional[Tensor]) -> int + # LHS maybeNone: emit normal if stmt that contains 3 constants + if lhs is not None: + return 2 + elif lhs is None: + return 1 + else: + return 3 - if opt is not None: - opt = torch.jit._unwrap_optional(opt) - x = opt + x + self.assertTrue(str(test_lhs_opt_rhs_none.graph).count(': int = prim::Constant') == 3) - if opt is None: - x = x + 4 + @torch.jit.script + def test_lhs_none_rhs_opt(rhs=None): + # type: (Optional[Tensor]) -> int + # RHS maybeNone, emit normal if stmt that contains 3 constants + if None is rhs: + return 1 + elif None is not rhs: + return 2 + else: + return 3 - return x + self.assertTrue(str(test_lhs_opt_rhs_none.graph).count(': int = prim::Constant') == 3) + + @torch.jit.script + def test_lhs_never_rhs_none(lhs): + # LHS neverNone, RHS alwaysNone dispatch never_none_branch + # only emit one prim::Constant + if lhs is None: + return 1 + elif lhs is not None: + return 2 + else: + return 3 + + self.assertTrue(str(test_lhs_never_rhs_none.graph).count(': int = prim::Constant') == 1) + + @torch.jit.script + def test_lhs_none_rhs_never(rhs): + # LHS alwaysNone, RHS neverNone dispatch never_none_branch + # only emit one prim::Constant + if None is rhs: + return 1 + elif None is not rhs: + return 2 + else: + return 3 - inputs = torch.zeros(1, 2) - self.assertExpectedGraph(Test().graph) - out = Test()(inputs) - self.assertEqual(out, inputs + 6) + self.assertTrue(str(test_lhs_none_rhs_never.graph).count(': int = prim::Constant') == 1) def test_explicit_bool_cast(self): with self.assertRaisesRegex(RuntimeError, "expected a boolean"): @@ -4507,7 +4542,7 @@ a") if x is None: res = res + 1 else: - res = torch.jit._unwrap_optional(x) + res = x return res fn = test_script_optional_tensor_none @@ -4522,7 +4557,7 @@ a") if x is None: res = res + 1.0 else: - res = torch.jit._unwrap_optional(x) + res = x return res fn = test_script_optional_other_none