From: Matthew Brookhart Date: Fri, 26 Jun 2020 14:22:43 +0000 (-0700) Subject: Add TupleGetItem to CSE (#5931) X-Git-Tag: upstream/0.7.0~490 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c9203c7e2825fb2da93da4f9b88c164c31f1650e;p=platform%2Fupstream%2Ftvm.git Add TupleGetItem to CSE (#5931) * Add TupleGetItem to CSE * rename a local variable --- diff --git a/src/relay/transforms/eliminate_common_subexpr.cc b/src/relay/transforms/eliminate_common_subexpr.cc index 8f7375c..dc3f77e 100644 --- a/src/relay/transforms/eliminate_common_subexpr.cc +++ b/src/relay/transforms/eliminate_common_subexpr.cc @@ -58,27 +58,52 @@ class CommonSubexprEliminator : public ExprMutator { auto it = expr_map_.find(new_call->op); if (it != expr_map_.end()) { - for (const CallNode* candidate : it->second) { - bool is_equivalent = true; - if (!attrs_equal(new_call->attrs, candidate->attrs)) { - continue; + for (const Expr& candidate_expr : it->second) { + if (const CallNode* candidate = candidate_expr.as()) { + bool is_equivalent = true; + if (!attrs_equal(new_call->attrs, candidate->attrs)) { + continue; + } + for (size_t i = 0; i < new_call->args.size(); i++) { + if (!new_call->args[i].same_as(candidate->args[i]) && + !IsEqualScalar(new_call->args[i], candidate->args[i])) { + is_equivalent = false; + break; + } + } + if (!is_equivalent) continue; + return GetRef(candidate); } - for (size_t i = 0; i < new_call->args.size(); i++) { - if (!new_call->args[i].same_as(candidate->args[i]) && - !IsEqualScalar(new_call->args[i], candidate->args[i])) { - is_equivalent = false; - break; + } + } + expr_map_[new_call->op].push_back(new_expr); + return new_expr; + } + + Expr VisitExpr_(const TupleGetItemNode* op) final { + Expr new_expr = ExprMutator::VisitExpr_(op); + const TupleGetItemNode* new_tuple_item = new_expr.as(); + CHECK(new_tuple_item); + + if (fskip_ != nullptr && fskip_(new_expr)) { + return new_expr; + } + + auto it = expr_map_.find(new_tuple_item->tuple); + if (it != expr_map_.end()) { + for (const Expr& candidate_expr : it->second) { + if (const TupleGetItemNode* candidate = candidate_expr.as()) { + if (new_tuple_item->index == candidate->index) { + return GetRef(candidate); } } - if (!is_equivalent) continue; - return GetRef(candidate); } } - expr_map_[new_call->op].push_back(new_call); + expr_map_[new_tuple_item->tuple].push_back(new_expr); return new_expr; } - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> expr_map_; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> expr_map_; runtime::TypedPackedFunc fskip_; }; diff --git a/tests/python/relay/test_pass_eliminate_common_subexpr.py b/tests/python/relay/test_pass_eliminate_common_subexpr.py index 7af524d..45d21a4 100644 --- a/tests/python/relay/test_pass_eliminate_common_subexpr.py +++ b/tests/python/relay/test_pass_eliminate_common_subexpr.py @@ -84,6 +84,35 @@ def test_callback(): z = run_opt_pass(z, transform.EliminateCommonSubexpr(fskip)) assert tvm.ir.structural_equal(z, expected()) +def test_tuple_get_time(): + def before(): + x = relay.var('x', shape=(1, 16, 1, 1)) + var = relay.var('var', shape=(16,)) + mean = relay.var('mean', shape=(16,)) + beta = relay.var('beta', shape=(16,)) + gamma = relay.var('gamma', shape=(16,)) + BN = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5) + T1 = BN[0] + T2 = BN[0] + add = T1 + T2 + f = relay.Function([x, var, mean, beta, gamma], add) + return f + + def expected(): + x = relay.var('x', shape=(1, 16, 1, 1)) + var = relay.var('var', shape=(16,)) + mean = relay.var('mean', shape=(16,)) + beta = relay.var('beta', shape=(16,)) + gamma = relay.var('gamma', shape=(16,)) + BN = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5) + T1 = BN[0] + add = T1 + T1 + f = relay.Function([x, var, mean, beta, gamma], add) + return run_opt_pass(f, transform.InferType()) + + z = before() + z = run_opt_pass(z, transform.EliminateCommonSubexpr()) + assert tvm.ir.structural_equal(z, expected()) if __name__ == "__main__": test_simple()