protected:
void VisitDFPattern(const DFPattern& pattern) override {
- DFPatternVisitor::VisitDFPattern(pattern);
- auto node = std::make_shared<IndexedGraph<DFPattern>::Node>(pattern, index_++);
- graph_.node_map_[pattern] = node;
- graph_.topological_order_.push_back(node);
+ if (this->visited_.count(pattern.get()) == 0) {
+ DFPatternVisitor::VisitDFPattern(pattern);
+ auto node = std::make_shared<IndexedGraph<DFPattern>::Node>(pattern, index_++);
+ graph_.node_map_[pattern] = node;
+ graph_.topological_order_.push_back(node);
+ }
}
IndexedGraph<DFPattern> graph_;
size_t index_ = 0;
std::vector<Node*> outputs_;
/*! \brief The depth of the node in the dominator tree */
- size_t depth_;
+ size_t depth_ = 0;
/*! \brief The dominator parent/final user of the outputs of this node */
Node* dominator_parent_;
/*! \brief The nodes this node dominates */
return nullptr;
}
while (lhs != rhs) {
+ CHECK(lhs);
+ CHECK(rhs);
if (lhs->depth_ < rhs->depth_) {
rhs = rhs->dominator_parent_;
} else if (lhs->depth_ > rhs->depth_) {
out = rewrite(TestRewrite(), x + y)
assert sub_pattern.match(out)
+def test_nested_rewrite():
+ class PatternCallback(DFPatternCallback):
+ def __init__(self, pattern):
+ self.pattern = pattern
+
+ def callback(self, pre, post, node_map):
+ return post
+
+ def gen():
+ x = relay.var('x')
+ y = relay.var('y')
+ y_add = relay.add(y, y)
+ n0 = relay.add(x, y_add)
+ n1 = relay.add(x, n0)
+ return relay.add(n1, n0)
+
+ def pattern():
+ a = wildcard()
+ b = wildcard()
+ n0 = is_op('add')(a, b)
+ n1 = is_op('add')(n0, a)
+ return is_op('add')(n0, n1)
+
+ out = gen()
+ pat = pattern()
+ new_out = rewrite(PatternCallback(pat), out)
+
+ assert tvm.ir.structural_equal(out, new_out)
+
def test_not_fuse_multi_diamond():
# Pattern
is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
test_no_match_diamond()
test_match_fake_diamond()
test_rewrite()
+ test_nested_rewrite()
test_fuse_batchnorm()
test_no_fuse_batchnorm()
test_fuse_double_batchnorm()