+++ /dev/null
-graph(%input : Tensor,
- %opt.1 : Tensor?):
- %2 : None = prim::None()
- %3 : int = prim::Constant[value=1]()
- %4 : int = prim::Constant[value=2]()
- %5 : int = prim::Constant[value=4]()
- %x.1 : Tensor = aten::add(%input, %4, %3)
- %7 : bool = aten::__isnot__(%opt.1, %2)
- %opt.4 : Tensor?, %x.3 : Tensor = prim::If(%7)
- block0():
- %opt.2 : Tensor = prim::unchecked_unwrap_optional(%opt.1)
- %opt.3 : Tensor = aten::_unwrap_optional(%opt.2)
- %x.2 : Tensor = aten::add(%opt.3, %x.1, %3)
- -> (%opt.3, %x.2)
- block1():
- -> (%opt.1, %x.1)
- %13 : bool = aten::__is__(%opt.4, %2)
- %x : Tensor = prim::If(%13)
- block0():
- %x.4 : Tensor = aten::add(%x.3, %5, %3)
- -> (%x.4)
- block1():
- -> (%x.3)
- return (%x)
self.checkScript(func, inputs, optimize=True)
def test_if_is_none_dispatch(self):
- class Test(torch.jit.ScriptModule):
- __constants__ = ['b']
- def __init__(self, b=None):
- super(Test, self).__init__()
- self.b = b
+ @torch.jit.script
+ def test_lhs_none_rhs_none():
+ # LHS, RHS both alwaysNone, dispatch always_none_branch
+ # only emit one prim::Constant
+ if None is None:
+ return 1
+ elif None is not None:
+ return 2
+ else:
+ return 3
- @torch.jit.script_method
- def forward(self, input, opt=None):
- # type: (Tensor, Optional[Tensor]) -> Tensor
- x = input
- if self.b is not None:
- x = self.b(input)
+ self.assertTrue(str(test_lhs_none_rhs_none.graph).count(': int = prim::Constant') == 1)
- if self.b is None:
- x = input + 2
+ @torch.jit.script
+ def test_lhs_opt_rhs_none(lhs=None):
+ # type: (Optional[Tensor]) -> int
+ # LHS maybeNone: emit normal if stmt that contains 3 constants
+ if lhs is not None:
+ return 2
+ elif lhs is None:
+ return 1
+ else:
+ return 3
- if opt is not None:
- opt = torch.jit._unwrap_optional(opt)
- x = opt + x
+ self.assertTrue(str(test_lhs_opt_rhs_none.graph).count(': int = prim::Constant') == 3)
- if opt is None:
- x = x + 4
+ @torch.jit.script
+ def test_lhs_none_rhs_opt(rhs=None):
+ # type: (Optional[Tensor]) -> int
+ # RHS maybeNone, emit normal if stmt that contains 3 constants
+ if None is rhs:
+ return 1
+ elif None is not rhs:
+ return 2
+ else:
+ return 3
- return x
+            self.assertTrue(str(test_lhs_none_rhs_opt.graph).count(': int = prim::Constant') == 3)
+
+ @torch.jit.script
+ def test_lhs_never_rhs_none(lhs):
+ # LHS neverNone, RHS alwaysNone dispatch never_none_branch
+ # only emit one prim::Constant
+ if lhs is None:
+ return 1
+ elif lhs is not None:
+ return 2
+ else:
+ return 3
+
+ self.assertTrue(str(test_lhs_never_rhs_none.graph).count(': int = prim::Constant') == 1)
+
+ @torch.jit.script
+ def test_lhs_none_rhs_never(rhs):
+ # LHS alwaysNone, RHS neverNone dispatch never_none_branch
+ # only emit one prim::Constant
+ if None is rhs:
+ return 1
+ elif None is not rhs:
+ return 2
+ else:
+ return 3
- inputs = torch.zeros(1, 2)
- self.assertExpectedGraph(Test().graph)
- out = Test()(inputs)
- self.assertEqual(out, inputs + 6)
+ self.assertTrue(str(test_lhs_none_rhs_never.graph).count(': int = prim::Constant') == 1)
def test_explicit_bool_cast(self):
with self.assertRaisesRegex(RuntimeError, "expected a boolean"):
if x is None:
res = res + 1
else:
- res = torch.jit._unwrap_optional(x)
+ res = x
return res
fn = test_script_optional_tensor_none
if x is None:
res = res + 1.0
else:
- res = torch.jit._unwrap_optional(x)
+ res = x
return res
fn = test_script_optional_other_none