fix pattern topological order (#5612)
authorMatthew Brookhart <mbrookhart@octoml.ai>
Mon, 18 May 2020 17:12:49 +0000 (10:12 -0700)
committerGitHub <noreply@github.com>
Mon, 18 May 2020 17:12:49 +0000 (10:12 -0700)
src/relay/ir/indexed_graph.cc
src/relay/ir/indexed_graph.h
tests/python/relay/test_dataflow_pattern.py

index 79ec574..7f7a5ff 100644 (file)
@@ -191,10 +191,12 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const DFPattern& pattern) {
 
    protected:
     void VisitDFPattern(const DFPattern& pattern) override {
-      DFPatternVisitor::VisitDFPattern(pattern);
-      auto node = std::make_shared<IndexedGraph<DFPattern>::Node>(pattern, index_++);
-      graph_.node_map_[pattern] = node;
-      graph_.topological_order_.push_back(node);
+      if (this->visited_.count(pattern.get()) == 0) {
+        DFPatternVisitor::VisitDFPattern(pattern);
+        auto node = std::make_shared<IndexedGraph<DFPattern>::Node>(pattern, index_++);
+        graph_.node_map_[pattern] = node;
+        graph_.topological_order_.push_back(node);
+      }
     }
     IndexedGraph<DFPattern> graph_;
     size_t index_ = 0;
index d252434..022eb3b 100644 (file)
@@ -69,7 +69,7 @@ class IndexedGraph {
     std::vector<Node*> outputs_;
 
     /*! \brief The depth of the node in the dominator tree */
-    size_t depth_;
+    size_t depth_ = 0;
     /*! \brief The dominator parent/final user of the outputs of this node */
     Node* dominator_parent_;
     /*! \brief The nodes this node dominates */
@@ -115,6 +115,8 @@ class IndexedGraph {
       return nullptr;
     }
     while (lhs != rhs) {
+      CHECK(lhs);
+      CHECK(rhs);
       if (lhs->depth_ < rhs->depth_) {
         rhs = rhs->dominator_parent_;
       } else if (lhs->depth_ > rhs->depth_) {
index a93a39b..41b3d6d 100644 (file)
@@ -425,6 +425,35 @@ def test_rewrite():
     out = rewrite(TestRewrite(), x + y)
     assert sub_pattern.match(out)
 
+def test_nested_rewrite():
+    class PatternCallback(DFPatternCallback):
+        def __init__(self, pattern):
+            self.pattern = pattern
+
+        def callback(self, pre, post, node_map):
+            return post
+
+    def gen():
+        x = relay.var('x')
+        y = relay.var('y')
+        y_add = relay.add(y, y)
+        n0 = relay.add(x, y_add)
+        n1 = relay.add(x, n0)
+        return relay.add(n1, n0)
+
+    def pattern():
+        a = wildcard()
+        b = wildcard()
+        n0 = is_op('add')(a, b)
+        n1 = is_op('add')(n0, a)
+        return is_op('add')(n0, n1)
+
+    out = gen()
+    pat = pattern()
+    new_out = rewrite(PatternCallback(pat), out)
+
+    assert tvm.ir.structural_equal(out, new_out)
+
 def test_not_fuse_multi_diamond():
     # Pattern
     is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
@@ -838,6 +867,7 @@ if __name__ == "__main__":
     test_no_match_diamond()
     test_match_fake_diamond()
     test_rewrite()
+    test_nested_rewrite()
     test_fuse_batchnorm()
     test_no_fuse_batchnorm()
     test_fuse_double_batchnorm()