def test_constant_prop_simple(self):
@torch.jit.script
- def constant_prop(input_tensor):
+ def constant_prop(input_int):
+ # type: (int) -> int
a = 2 * 3
b = a + 2
- return b + input_tensor
+ return b - input_int
- x = torch.tensor(2)
- out_ref = constant_prop(x)
+ out_ref = constant_prop(2)
self.run_pass('constant_propagation', constant_prop.graph)
- out_test = constant_prop(torch.tensor(2))
+ out_test = constant_prop(2)
self.assertEqual(out_ref, out_test)
- self.assertExpected(canonical(constant_prop.graph))
+ graph_str = str(constant_prop.graph)
+ self.assertTrue("aten::add" not in graph_str and "aten::mul" not in graph_str)
+ const = constant_prop.graph.findNode("prim::Constant").output().toIValue()
+ self.assertEqual(const, 8)
def test_constant_prop_nested(self):
@torch.jit.script
self.run_pass('constant_propagation', constant_prop.graph)
out_test = constant_prop(torch.tensor(2))
self.assertEqual(out_ref, out_test)
- self.assertExpected(canonical(constant_prop.graph))
+ if_node = constant_prop.graph.findNode("prim::If")
+ for block in if_node.blocks():
+ for node in block.nodes():
+ self.assertTrue(node.kind() == "prim::Constant")
def test_constant_prop_print(self):
@torch.jit.script
return b + input_tensor
self.run_pass('constant_propagation', constant_prop.graph)
- self.assertExpected(canonical(constant_prop.graph))
+ graph = constant_prop.graph
+ print_node = graph.findNode("prim::Print")
+ self.assertTrue(print_node.input().toIValue() == 6)
def test_constant_prop_rand(self):
@torch.jit.script
return b
self.run_pass('constant_propagation', constant_prop.graph)
- self.assertExpected(canonical(constant_prop.graph))
+ self.assertTrue("aten::randn" in str(constant_prop.graph))
def test_constant_prop_none(self):
@torch.jit.script
c2 = c2 + 4 # set to 5
return a + c0 + c1 + c2
- self.run_pass('constant_propagation', constant_prop.graph)
- self.assertExpected(canonical(constant_prop.graph))
+ graph = constant_prop.graph
+ self.run_pass('constant_propagation', graph)
+ ifs = graph.findAllNodes("prim::If", recurse=False)
+ snd_if_inlined = len(ifs) == 1
+ self.assertTrue(snd_if_inlined)
+ first_if = ifs[0]
+ self.assertTrue(first_if.outputsSize() == 2)
+ second_if = first_if.findNode("prim::If", recurse=False)
+ self.assertTrue(second_if.outputsSize() == 1)
+ self.assertTrue(second_if.findNode("prim::If") is None)
def test_constant_prop_loop_constant(self):
@torch.jit.script
}
}
-std::vector<Node*> findAllNodes(Block* block, Symbol kind) {
+std::vector<Node*> findAllNodes(
+ c10::ArrayRef<torch::jit::Block*> blocks,
+ Symbol kind,
+ bool recurse = true) {
std::vector<Node*> ret;
- for (Node* n : block->nodes()) {
- for (Block* b : n->blocks()) {
- auto nodes = findAllNodes(b, kind);
- ret.insert(ret.end(), nodes.begin(), nodes.end());
- }
- if (n->kind() == kind) {
- ret.push_back(n);
+ for (Block* block : blocks) {
+ for (Node* n : block->nodes()) {
+ if (n->kind() == kind) {
+ ret.push_back(n);
+ }
+ if (recurse) {
+ auto nodes = findAllNodes(n->blocks(), kind, recurse);
+ ret.insert(ret.end(), nodes.begin(), nodes.end());
+ }
}
}
return ret;
}
-Node* findNode(Block* block, Symbol kind) {
- for (Node* n : block->nodes()) {
- for (Block* b : n->blocks()) {
- auto node = findNode(b, kind);
- if (node != nullptr) {
- return node;
+std::vector<Node*> findAllNodes(Block* block, Symbol kind, bool recurse = true) {
+ std::vector<Block*> blocks = {block};
+ return findAllNodes(blocks, kind, recurse);
+}
+
+Node* findNode(
+ c10::ArrayRef<torch::jit::Block*> blocks,
+ Symbol kind,
+ bool recurse = true) {
+ for (Block* block : blocks) {
+ for (Node* n : block->nodes()) {
+ if (n->kind() == kind) {
+ return n;
+ }
+ if (recurse) {
+ auto node = findNode(n->blocks(), kind, recurse);
+ if (node != nullptr) {
+ return node;
+ }
}
- }
- if (n->kind() == kind) {
- return n;
}
}
return nullptr;
}
+Node* findNode(Block* block, Symbol kind, bool recurse = true) {
+ std::vector<Block*> blocks = {block};
+ return findNode(blocks, kind, recurse);
+}
+
// execute a Python function, used for Ops we can't optimize but that we want to
// optimize around
struct ConcretePythonOp : public PythonOp {
})
.def(
"findNode",
- [](Graph& g, const std::string& kind) {
- return findNode(g.block(), Symbol::fromQualString(kind));
- })
+ [](Graph& g, const std::string& kind, bool recurse) {
+ return findNode(g.block(), Symbol::fromQualString(kind), recurse);
+ }, "Find Node", py::arg("kind"), py::arg("recurse") = true)
.def(
"findAllNodes",
- [](Graph& g, const std::string& kind) {
- return findAllNodes(g.block(), Symbol::fromQualString(kind));
- })
+ [](Graph& g, const std::string& kind, bool recurse) {
+ return findAllNodes(
+ g.block(), Symbol::fromQualString(kind), recurse);
+ }, "Find all nodes", py::arg("kind"), py::arg("recurse") = true)
.def("addInput", [](Graph& g) { return g.addInput(); })
.def("copy", [](Graph& g) { return g.copy(); })
.GS(eraseInput)
py::class_<Block, std::unique_ptr<Block, py::nodelete>>(m, "Block")
.def("nodes", [](Block& b) {
return py::make_iterator(b.nodes().begin(), b.nodes().end());
- });
+ })
+ .def(
+ "findNode",
+ [](Block& b, const std::string& kind, bool recurse) {
+ return findNode(&b, Symbol::fromQualString(kind), recurse);
+ }, "Find Node", py::arg("kind"), py::arg("recurse") = true)
+ .def(
+ "findAllNodes",
+ [](Block& b, const std::string& kind, bool recurse) {
+ return findAllNodes(&b, Symbol::fromQualString(kind), recurse);
+ }, "Find all nodes", py::arg("kind"), py::arg("recurse") = true);
+
#define NS(name) def(#name, &Node ::name)
py::class_<Node, std::unique_ptr<Node, py::nodelete>>(m, "Node")
.def("outputsAt", [](Node& n, size_t i) { return n.outputs().at(i); })
.def(
"findNode",
- [](Node& n, const std::string& kind) {
- Node* node;
- for (Block* b : n.blocks()) {
- node = findNode(b, Symbol::fromQualString(kind));
- if (node != nullptr) {
- return node;
- }
- }
- return node;
- })
+ [](Node& n, const std::string& kind, bool recurse) {
+ return findNode(n.blocks(), Symbol::fromQualString(kind), recurse);
+ }, "Find Node", py::arg("kind"), py::arg("recurse") = true)
.def(
"findAllNodes",
- [](Node& n, const std::string& kind) {
- std::vector<Node*> ret;
- for (Block* b : n.blocks()) {
- auto nodes = findAllNodes(b, Symbol::fromQualString(kind));
- ret.insert(ret.end(), nodes.begin(), nodes.end());
- }
- return ret;
- })
+ [](Node& n, const std::string& kind, bool recurse) {
+ return findAllNodes(
+ n.blocks(), Symbol::fromQualString(kind), recurse);
+ }, "Find all nodes", py::arg("kind"), py::arg("recurse") = true)
+ .def("input", [](Node& n) { return n.input(); })
.def("output", [](Node& n) { return n.output(); })
.NS(addInput)
.NS(replaceInput)