graph(%input : Tensor
%opt.1 : Tensor?) {
- %2 : int = prim::Constant[value=1]()
- %3 : int = prim::Constant[value=2]()
- %4 : int = prim::Constant[value=4]()
- %x.1 : Tensor = aten::add(%input, %3, %2)
- %6 : None = prim::None()
- %7 : bool = aten::__isnot__(%opt.1, %6)
+ %2 : None = prim::None()
+ %3 : int = prim::Constant[value=1]()
+ %4 : int = prim::Constant[value=2]()
+ %5 : int = prim::Constant[value=4]()
+ %x.1 : Tensor = aten::add(%input, %4, %3)
+ %7 : bool = aten::__isnot__(%opt.1, %2)
%opt : Tensor?, %x.3 : Tensor = prim::If(%7)
block0() {
%opt.2 : Tensor = aten::_unwrap_optional(%opt.1)
- %x.2 : Tensor = aten::add(%opt.2, %x.1, %2)
+ %x.2 : Tensor = aten::add(%opt.2, %x.1, %3)
-> (%opt.2, %x.2)
}
block1() {
-> (%opt.1, %x.1)
}
- %12 : None = prim::None()
- %13 : bool = aten::__is__(%opt, %12)
- %x : Tensor = prim::If(%13)
+ %12 : bool = aten::__is__(%opt, %2)
+ %x : Tensor = prim::If(%12)
block0() {
- %x.4 : Tensor = aten::add(%x.3, %4, %2)
+ %x.4 : Tensor = aten::add(%x.3, %5, %3)
-> (%x.4)
}
block1() {
graph(%a.1 : Tensor) {
- %1 : int = prim::Constant[value=1]()
- %2 : Device = prim::Constant[value="cpu"]()
- %3 : int = prim::Constant[value=0]()
- %4 : int = prim::Constant[value=6]()
- %5 : int = prim::Constant[value=2]()
- %6 : int = prim::Constant[value=3]()
- %7 : int[] = prim::ListConstruct(%5, %6)
- %8 : Tensor = aten::rand(%7, %4, %3, %2)
- %a : Tensor = aten::add_(%a.1, %8, %1)
- %10 : None = prim::None()
- return (%10);
+ %1 : None = prim::None()
+ %2 : int = prim::Constant[value=1]()
+ %3 : Device = prim::Constant[value="cpu"]()
+ %4 : int = prim::Constant[value=0]()
+ %5 : int = prim::Constant[value=6]()
+ %6 : int = prim::Constant[value=2]()
+ %7 : int = prim::Constant[value=3]()
+ %8 : int[] = prim::ListConstruct(%6, %7)
+ %9 : Tensor = aten::rand(%8, %5, %4, %3)
+ %a : Tensor = aten::add_(%a.1, %9, %2)
+ return (%1);
}
self.run_pass('constant_pooling', graph)
self.assertExpectedGraph(graph)
+ def test_constant_pooling_none(self):
+ @torch.jit.script
+ def typed_nones(a=None, b=None, c=None):
+ # type: (Optional[int], Optional[bool], Optional[Tensor]) -> Tuple[Optional[int], Optional[bool], Optional[Tensor]] # noqa
+ return a, b, c
+
+ @torch.jit.script
+ def test(a):
+ # type: (bool) -> None
+ if a:
+ print(typed_nones())
+ else:
+ print(typed_nones())
+
+ graph_str = str(test.graph)
+ self.assertTrue(graph_str.count("bool? = prim::None") == 1)
+ self.assertTrue(graph_str.count("int? = prim::None") == 1)
+ self.assertTrue(graph_str.count("None = prim::None") == 1)
+
def test_literal(self):
def func1(a, b):
c = a, b
continue;
}
- if (node->kind() != prim::Constant) {
+ if (node->kind() != prim::Constant && node->kind() != prim::None) {
continue;
}