Expr ExprMutator::VisitExpr_(const ConstructorNode* c) { return GetRef<Expr>(c); }
Expr ExprMutator::VisitExpr_(const MatchNode* m) {
+ bool unchanged = true;
std::vector<Clause> clauses;
for (const Clause& p : m->clauses) {
- clauses.push_back(VisitClause(p));
+ Clause c = VisitClause(p);
+ clauses.push_back(c);
+ unchanged &= c.same_as(p);
}
- return Match(Mutate(m->data), clauses, m->complete);
+ Expr data = Mutate(m->data);
+ unchanged &= data.same_as(m->data);
+ if (unchanged) {
+ return GetRef<Expr>(m);
+ }
+ return Match(data, clauses, m->complete);
}
Clause ExprMutator::VisitClause(const Clause& c) {
Pattern p = VisitPattern(c->lhs);
- return Clause(p, Mutate(c->rhs));
+ Expr rhs = Mutate(c->rhs);
+ if (p.same_as(c->lhs) && rhs.same_as(c->rhs)) {
+ return c;
+ }
+ return Clause(p, rhs);
}
Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; }
assert pattern2.match(relu)
assert tvm.ir.structural_equal(func(x, w, b), pattern2.partition(relu))
+def test_match_match():
+ add_pattern = is_op('add')(wildcard(), wildcard())
+ class TestRewrite(DFPatternCallback):
+ def __init__(self):
+ self.pattern = add_pattern
+ def callback(self, pre, post, node_map):
+ return post.args[0] - post.args[1]
+ mod = tvm.IRModule({})
+ tvm.relay.prelude.Prelude(mod)
+ # Apply rewrite on IR including relay.Match
+ out = rewrite(TestRewrite(), mod['tensor_concatenate_int64'])
+ assert tvm.ir.structural_equal(mod['tensor_concatenate_int64'], out)
if __name__ == "__main__":
test_expr_pattern()
test_partition_check()
test_partition_check_types()
test_partition_option()
+ test_match_match()