[Relay] Fix dataflow_pattern.rewrite() hang if Match in IR (#5680)
authorlixiaoquan <radioheads@163.com>
Fri, 29 May 2020 10:37:05 +0000 (18:37 +0800)
committerGitHub <noreply@github.com>
Fri, 29 May 2020 10:37:05 +0000 (19:37 +0900)
rewrite() quits only if graph stop changing, but ExprMutator
  always creates new Match node. This patch fixes this.

src/relay/ir/expr_functor.cc
tests/python/relay/test_dataflow_pattern.py

index 684dae7..c3d4781 100644 (file)
@@ -346,16 +346,28 @@ Expr ExprMutator::VisitExpr_(const RefWriteNode* op) {
 Expr ExprMutator::VisitExpr_(const ConstructorNode* c) { return GetRef<Expr>(c); }
 
 Expr ExprMutator::VisitExpr_(const MatchNode* m) {
+  bool unchanged = true;
   std::vector<Clause> clauses;
   for (const Clause& p : m->clauses) {
-    clauses.push_back(VisitClause(p));
+    Clause c = VisitClause(p);
+    clauses.push_back(c);
+    unchanged &= c.same_as(p);
   }
-  return Match(Mutate(m->data), clauses, m->complete);
+  Expr data = Mutate(m->data);
+  unchanged &= data.same_as(m->data);
+  if (unchanged) {
+    return GetRef<Expr>(m);
+  }
+  return Match(data, clauses, m->complete);
 }
 
 Clause ExprMutator::VisitClause(const Clause& c) {
   Pattern p = VisitPattern(c->lhs);
-  return Clause(p, Mutate(c->rhs));
+  Expr rhs = Mutate(c->rhs);
+  if (p.same_as(c->lhs) && rhs.same_as(c->rhs)) {
+    return c;
+  }
+  return Clause(p, rhs);
 }
 
 Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; }
index 5d91dcb..467e30b 100644 (file)
@@ -1140,6 +1140,18 @@ def test_partition_option():
     assert pattern2.match(relu)
     assert tvm.ir.structural_equal(func(x, w, b), pattern2.partition(relu))
 
+def test_match_match():
+    add_pattern = is_op('add')(wildcard(), wildcard())
+    class TestRewrite(DFPatternCallback):
+        def __init__(self):
+            self.pattern = add_pattern
+        def callback(self, pre, post, node_map):
+            return post.args[0] - post.args[1]
+    mod = tvm.IRModule({})
+    tvm.relay.prelude.Prelude(mod)
+    # Apply rewrite on IR including relay.Match
+    out = rewrite(TestRewrite(), mod['tensor_concatenate_int64'])
+    assert tvm.ir.structural_equal(mod['tensor_concatenate_int64'], out)
 
 if __name__ == "__main__":
     test_expr_pattern()
@@ -1196,3 +1208,4 @@ if __name__ == "__main__":
     test_partition_check()
     test_partition_check_types()
     test_partition_option()
+    test_match_match()