+++ /dev/null
-graph(%input : Tensor,
- %opt.1 : Tensor?):
- %2 : None = prim::Constant()
- %3 : int = prim::Constant[value=1]()
- %4 : int = prim::Constant[value=2]()
- %5 : int = prim::Constant[value=4]()
- %x.1 : Tensor = aten::add(%input, %4, %3)
- %7 : bool = aten::__isnot__(%opt.1, %2)
- %opt.4 : Tensor?, %x.3 : Tensor = prim::If(%7)
- block0():
- %opt.2 : Tensor = prim::unchecked_unwrap_optional(%opt.1)
- %opt.3 : Tensor = aten::_unwrap_optional(%opt.2)
- %x.2 : Tensor = aten::add(%opt.3, %x.1, %3)
- -> (%opt.3, %x.2)
- block1():
- -> (%opt.1, %x.1)
- %13 : bool = aten::__is__(%opt.4, %2)
- %x : Tensor = prim::If(%13)
- block0():
- %x.4 : Tensor = aten::add(%x.3, %5, %3)
- -> (%x.4)
- block1():
- -> (%x.3)
- return (%x)
+++ /dev/null
-graph(%x : Double(*, *)):
- %1 : int = prim::Constant[value=0]()
- %2 : bool = prim::Constant[value=1]()
- %c : Tensor[] = prim::If(%2)
- block0():
- %c.1 : Tensor[] = prim::ListConstruct(%x, %x)
- -> (%c.1)
- block1():
- %c.2 : Tensor[] = prim::ListConstruct(%x, %x, %x)
- -> (%c.2)
- %6 : Tensor = aten::cat(%c, %1)
- return (%6)
+++ /dev/null
-graph(%x.1 : Float(*, *),
- %y.1 : Long(*, *),
- %z.1 : Float(*, *)):
- %3 : bool = prim::Constant[value=1]()
- %x : Float(*, *), %y : Tensor, %z : Tensor = prim::If(%3)
- block0():
- -> (%x.1, %y.1, %z.1)
- block1():
- -> (%x.1, %x.1, %y.1)
- %7 : (Float(*, *), Tensor, Tensor) = prim::TupleConstruct(%x, %y, %z)
- return (%7)
+++ /dev/null
-graph(%target : Double(100),
- %indices.1 : Long(4),
- %rhs : Double(1, 1, 1, 4)):
- %3 : int = prim::Constant[value=4]()
- %4 : int[] = prim::ListConstruct(%3)
- %5 : Double(4) = aten::view(%rhs, %4)
- %6 : int = prim::Constant[value=4]()
- %7 : int = prim::Constant[value=0]()
- %8 : Device = prim::Constant[value="cpu"]()
- %9 : bool = prim::Constant[value=0]()
- %10 : bool = prim::Constant[value=0]()
- %indices : Long(4) = aten::to(%indices.1, %6, %7, %8, %9, %10)
- %12 : Tensor?[] = prim::ListConstruct(%indices)
- %13 : bool = prim::Constant[value=0]()
- %14 : Double(100) = aten::index_put_(%target, %12, %5, %13)
- return (%14)
+++ /dev/null
-graph(%target : Double(100),
- %indices.1 : Long(4),
- %rhs : Double(4)):
- %3 : int = prim::Constant[value=4]()
- %4 : int = prim::Constant[value=0]()
- %5 : Device = prim::Constant[value="cpu"]()
- %6 : bool = prim::Constant[value=0]()
- %7 : bool = prim::Constant[value=0]()
- %indices : Long(4) = aten::to(%indices.1, %3, %4, %5, %6, %7)
- %9 : Tensor?[] = prim::ListConstruct(%indices)
- %10 : bool = prim::Constant[value=0]()
- %11 : Double(100) = aten::index_put_(%target, %9, %rhs, %10)
- return (%11)
+++ /dev/null
-graph(%x : Double(2, 2),
- %y : Long(4)):
- %2 : int = prim::Constant[value=1]()
- %3 : Double(2, 4) = aten::index_select(%x, %2, %y)
- return (%3)
+++ /dev/null
-graph():
- %0 : int = prim::Constant[value=1]()
- %1 : Device = prim::Constant[value="cpu"]()
- %2 : int = prim::Constant[value=0]()
- %3 : int = prim::Constant[value=6]()
- %4 : int = prim::Constant[value=2]()
- %5 : int = prim::Constant[value=3]()
- %6 : int[] = prim::ListConstruct(%4, %5)
- %a.1 : Tensor = aten::rand(%6, %3, %2, %1)
- %8 : int[] = prim::ListConstruct(%4, %5)
- %9 : Tensor = aten::rand(%8, %3, %2, %1)
- %a : Tensor = aten::add_(%a.1, %9, %0)
- return (%a)
+++ /dev/null
-graph():
- %0 : int = prim::Constant[value=1]()
- %1 : Device = prim::Constant[value="cpu"]()
- %2 : int = prim::Constant[value=0]()
- %3 : int = prim::Constant[value=6]()
- %4 : int = prim::Constant[value=2]()
- %5 : int = prim::Constant[value=3]()
- %6 : int[] = prim::ListConstruct(%4, %5)
- %a.1 : Tensor = aten::rand(%6, %3, %2, %1)
- %8 : int[] = prim::ListConstruct(%4, %5)
- %9 : Tensor = aten::rand(%8, %3, %2, %1)
- %a.2 : Tensor = aten::add_(%a.1, %9, %0)
- %11 : int[] = prim::ListConstruct(%4, %5)
- %b.1 : Tensor = aten::rand(%11, %3, %2, %1)
- %13 : int[] = prim::ListConstruct(%4, %5)
- %14 : Tensor = aten::zeros(%13, %3, %2, %1)
- %15 : Tensor = aten::gt(%a.2, %14)
- %16 : bool = prim::Bool(%15)
- %b : Tensor = prim::If(%16)
- block0():
- %18 : int[] = prim::ListConstruct(%4, %5)
- %19 : Tensor = aten::rand(%18, %3, %2, %1)
- %b.2 : Tensor = aten::add_(%b.1, %19, %0)
- -> (%b.2)
- block1():
- -> (%b.1)
- return (%b)
+++ /dev/null
-graph(%a.1 : Tensor):
- %1 : None = prim::Constant()
- %2 : int = prim::Constant[value=1]()
- %3 : Device = prim::Constant[value="cpu"]()
- %4 : int = prim::Constant[value=0]()
- %5 : int = prim::Constant[value=6]()
- %6 : int = prim::Constant[value=2]()
- %7 : int = prim::Constant[value=3]()
- %8 : int[] = prim::ListConstruct(%6, %7)
- %9 : Tensor = aten::rand(%8, %5, %4, %3)
- %a : Tensor = aten::add_(%a.1, %9, %2)
- return (%1)
+++ /dev/null
-graph(%a : Tensor):
- %1 : int = prim::Constant[value=1]()
- %2 : Device = prim::Constant[value="cpu"]()
- %3 : int = prim::Constant[value=6]()
- %4 : int = prim::Constant[value=0]()
- %5 : int = prim::Constant[value=2]()
- %6 : int = prim::Constant[value=3]()
- %l : Tensor[] = prim::ListConstruct()
- %8 : Tensor[] = aten::append(%l, %a)
- %c.1 : Tensor = aten::select(%l, %4)
- %10 : int[] = prim::ListConstruct(%5, %6)
- %b : Tensor = aten::rand(%10, %3, %4, %2)
- %12 : int[] = prim::ListConstruct(%5, %6)
- %13 : Tensor = aten::rand(%12, %3, %4, %2)
- %c : Tensor = aten::add_(%c.1, %13, %1)
- return (%b)
+++ /dev/null
-graph(%a : Tensor):
- %1 : Device = prim::Constant[value="cpu"]()
- %2 : int = prim::Constant[value=6]()
- %i.1 : int = prim::Constant[value=0]()
- %4 : int = prim::Constant[value=2]()
- %5 : int = prim::Constant[value=3]()
- %6 : int = prim::Constant[value=9223372036854775807]()
- %7 : int = prim::Constant[value=1]()
- %l : Tensor[] = prim::ListConstruct()
- %9 : Tensor[] = aten::append(%l, %a)
- %10 : int[] = prim::ListConstruct(%4, %5)
- %b : Tensor = aten::rand(%10, %2, %i.1, %1)
- %12 : bool = aten::lt(%i.1, %7)
- %i : int = prim::Loop(%6, %12, %i.1)
- block0(%14 : int, %15 : int):
- %c.1 : Tensor = aten::select(%l, %i.1)
- %17 : int[] = prim::ListConstruct(%4, %5)
- %18 : Tensor = aten::rand(%17, %2, %i.1, %1)
- %c : Tensor = aten::add_(%c.1, %18, %7)
- %i.2 : int = aten::add(%15, %7)
- %21 : bool = aten::lt(%i.2, %7)
- -> (%21, %i.2)
- return (%b)
+++ /dev/null
-graph(%a : Tensor):
- %1 : int = prim::Constant[value=1]()
- %2 : int = prim::Constant[value=2]()
- %3 : int = prim::Constant[value=0]()
- %4 : bool = prim::Bool(%a)
- %b : (int, int) = prim::If(%4)
- block0():
- %b.1 : (int, int) = prim::TupleConstruct(%1, %2)
- -> (%b.1)
- block1():
- %b.2 : (int, int) = prim::TupleConstruct(%3, %2)
- -> (%b.2)
- %8 : int = prim::TupleIndex[index=0](%b)
- %9 : int = prim::TupleIndex[index=1](%b)
- %10 : (int, int) = prim::TupleConstruct(%8, %9)
- return (%10)
+++ /dev/null
-graph(%a : Tensor):
- %1 : int = prim::Constant[value=1]()
- %2 : int = prim::Constant[value=2]()
- %3 : int = prim::Constant[value=3]()
- %4 : int = prim::Constant[value=4]()
- %5 : bool = prim::Bool(%a)
- %b : (int, int, int, int) = prim::If(%5)
- block0():
- %b.1 : (int, int, int, int) = prim::TupleConstruct(%1, %2, %3, %4)
- -> (%b.1)
- block1():
- %b.2 : (int, int, int, int) = prim::TupleConstruct(%4, %3, %2, %1)
- -> (%b.2)
- %c : (int, int, int, int) = prim::TupleSlice[beg=0, end=4](%b)
- %e : (int, int) = prim::TupleSlice[beg=1, end=3](%c)
- return (%e)
print(c0)
return 1
- def test_if_list(self):
- # testing that different length lists don't throw error
+ def test_if_list_cat(self):
+ # testing that different length lists don't throw error on cat in shape prop
@torch.jit.script
def test_list(x):
- if True:
+ if bool(x.sum() < 1):
c = [x, x]
else:
c = [x, x, x]
b = torch.zeros(2, 4)
test_list.graph.propagate_shapes((b,), False)
- self.assertExpected(canonical(test_list.graph))
def test_if_supertype(self):
@torch.jit.script
def tensor_unifying(x, y, z):
-
# testing dynamic is appropriately set for y and z
if True:
x, y, z = x, y, z
c = torch.zeros(2, 4, dtype=torch.float)
tensor_unifying.graph.propagate_shapes((a, b, c), False)
- self.assertExpected(canonical(tensor_unifying.graph))
+ if_outputs = list(tensor_unifying.graph.findNode("prim::If").outputs())
+ self.assertTrue(if_outputs[0].type().str() == "Float(*, *)")
+ self.assertTrue(if_outputs[1].type().str() == "Tensor")
+ self.assertTrue(if_outputs[2].type().str() == "Tensor")
def test_list_unify(self):
# allowing a unififed int?[] would cause a runtime error b/c
a = torch.zeros(2, 2)
b = torch.zeros(4, dtype=torch.long)
torch._C._jit_pass_complete_shape_analysis(foo.graph, (a, b), False)
- self.assertExpected(canonical(foo.graph))
+ FileCheck().check("Double(2, 4)").run(str(foo.graph))
def test_onnx_export_speculate(self):
target[indices] = rhs
return target
- self.assertExpectedGraph(test_index_put.graph)
+ FileCheck().check("aten::view").check("index_put_").run(str(test_index_put.graph))
def test_index_put_trace_without_view(self):
@_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(4))
target[indices] = rhs
return target
- self.assertExpectedGraph(test_index_put.graph)
+ FileCheck().check_not("aten::view").check("index_put_").run(str(test_index_put.graph))
def test_tuple_indexing(self):
def tuple_index(a):
b = (0, 2)
return b[-2], b[1]
+ self.checkScript(tuple_index, (torch.tensor([0]),))
self.checkScript(tuple_index, (torch.tensor([1]),))
self.checkScript(tuple_index, (torch.tensor([1]),), optimize=True)
tuple_comp = torch.jit.script(tuple_index)
- self.assertExpectedGraph(tuple_comp.graph)
- self.assertEqual(tuple_comp(torch.tensor(1)), (1, 2))
+ FileCheck().check_count("TupleIndex", 2, exactly=True).run(str(tuple_comp.graph))
with self.assertRaisesRegex(RuntimeError, "tuple indices must be integer constants"):
@torch.jit.script
else:
b = (4, 3, 2, 1)
c = b[-4:4]
- d = b[0:]
e = c[1:-1]
return e
self.checkScript(tuple_slice, (torch.tensor([1]),), optimize=True)
- tuple_graph = torch.jit.script(tuple_slice)
- self.assertExpectedGraph(tuple_graph.graph)
- self.run_pass('lower_all_tuples', tuple_graph.graph)
- self.assertTrue('Tuple' not in str(tuple_graph.graph))
+ tuple_graph = torch.jit.script(tuple_slice).graph
+ slices = tuple_graph.findAllNodes("prim::TupleSlice")
+ num_outputs = set(map(lambda x: len(x.output().type().elements()), slices))
+ # one tuple slice should have an output with 2 elements, other 4
+ self.assertTrue(num_outputs == set([2, 4]))
+ self.run_pass('lower_all_tuples', tuple_graph)
+ self.assertTrue('Tuple' not in str(tuple_graph))
tuple_comp = torch.jit.script(tuple_slice)
self.assertEqual(tuple_comp(torch.tensor(1)), (2, 3))
# b should be cleaned up but not a
return a
- self.assertExpectedGraph(foo.graph)
+ FileCheck().check_count("aten::rand", 2, exactly=True) \
+ .check_count("aten::add", 1, exactly=True).run(str(foo.graph))
def test_mutable_dce_block(self):
@torch.jit.script
# a should be cleaned up but not b
return b
- self.assertExpectedGraph(foo.graph)
+ FileCheck().check("prim::If").check_count("aten::rand", 1, exactly=True) \
+ .run(str(foo.graph))
def test_mutable_dce_graph_input(self):
@torch.jit.script
a += torch.rand(2, 3)
# shouldn't clean up `a` even though it's not used in the output
- self.assertExpectedGraph(foo.graph)
+ FileCheck().check("aten::rand").check("aten::add").run(str(foo.graph))
def test_mutable_dce_list(self):
@torch.jit.script
c += torch.rand(2, 3)
return b
- self.assertExpectedGraph(foo.graph)
+ # c does not get cleaned up because there is a wildcard + mutation
+ FileCheck().check_count("aten::rand", 2, exactly=True).run(str(foo.graph))
def test_mutable_dce_loop(self):
@torch.jit.script
i += 1
return b
- self.assertExpectedGraph(foo.graph)
+ FileCheck().check("prim::Loop").check_not("aten::rand").check("aten::select") \
+ .check_count("aten::rand", 1, exactly=True).run(str(foo.graph))
def test_mutable_dce_wildcards(self):
def fn():