auto it = expr_map_.find(new_call->op);
if (it != expr_map_.end()) {
- for (const CallNode* candidate : it->second) {
- bool is_equivalent = true;
- if (!attrs_equal(new_call->attrs, candidate->attrs)) {
- continue;
+ for (const Expr& candidate_expr : it->second) {
+ if (const CallNode* candidate = candidate_expr.as<CallNode>()) {
+ bool is_equivalent = true;
+ if (!attrs_equal(new_call->attrs, candidate->attrs)) {
+ continue;
+ }
+ for (size_t i = 0; i < new_call->args.size(); i++) {
+ if (!new_call->args[i].same_as(candidate->args[i]) &&
+ !IsEqualScalar(new_call->args[i], candidate->args[i])) {
+ is_equivalent = false;
+ break;
+ }
+ }
+ if (!is_equivalent) continue;
+ return GetRef<Call>(candidate);
}
- for (size_t i = 0; i < new_call->args.size(); i++) {
- if (!new_call->args[i].same_as(candidate->args[i]) &&
- !IsEqualScalar(new_call->args[i], candidate->args[i])) {
- is_equivalent = false;
- break;
+ }
+ }
+ expr_map_[new_call->op].push_back(new_expr);
+ return new_expr;
+ }
+
+ Expr VisitExpr_(const TupleGetItemNode* op) final {
+ Expr new_expr = ExprMutator::VisitExpr_(op);
+ const TupleGetItemNode* new_tuple_item = new_expr.as<TupleGetItemNode>();
+ CHECK(new_tuple_item);
+
+ if (fskip_ != nullptr && fskip_(new_expr)) {
+ return new_expr;
+ }
+
+ auto it = expr_map_.find(new_tuple_item->tuple);
+ if (it != expr_map_.end()) {
+ for (const Expr& candidate_expr : it->second) {
+ if (const TupleGetItemNode* candidate = candidate_expr.as<TupleGetItemNode>()) {
+ if (new_tuple_item->index == candidate->index) {
+ return GetRef<Expr>(candidate);
}
}
- if (!is_equivalent) continue;
- return GetRef<Call>(candidate);
}
}
- expr_map_[new_call->op].push_back(new_call);
+ expr_map_[new_tuple_item->tuple].push_back(new_expr);
return new_expr;
}
- std::unordered_map<Expr, std::vector<const CallNode*>, ObjectPtrHash, ObjectPtrEqual> expr_map_;
+ std::unordered_map<Expr, std::vector<Expr>, ObjectPtrHash, ObjectPtrEqual> expr_map_;
runtime::TypedPackedFunc<bool(Expr)> fskip_;
};
z = run_opt_pass(z, transform.EliminateCommonSubexpr(fskip))
assert tvm.ir.structural_equal(z, expected())
+def test_tuple_get_time():
+ def before():
+ x = relay.var('x', shape=(1, 16, 1, 1))
+ var = relay.var('var', shape=(16,))
+ mean = relay.var('mean', shape=(16,))
+ beta = relay.var('beta', shape=(16,))
+ gamma = relay.var('gamma', shape=(16,))
+ BN = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)
+ T1 = BN[0]
+ T2 = BN[0]
+ add = T1 + T2
+ f = relay.Function([x, var, mean, beta, gamma], add)
+ return f
+
+ def expected():
+ x = relay.var('x', shape=(1, 16, 1, 1))
+ var = relay.var('var', shape=(16,))
+ mean = relay.var('mean', shape=(16,))
+ beta = relay.var('beta', shape=(16,))
+ gamma = relay.var('gamma', shape=(16,))
+ BN = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)
+ T1 = BN[0]
+ add = T1 + T1
+ f = relay.Function([x, var, mean, beta, gamma], add)
+ return run_opt_pass(f, transform.InferType())
+
+ z = before()
+ z = run_opt_pass(z, transform.EliminateCommonSubexpr())
+ assert tvm.ir.structural_equal(z, expected())
if __name__ == "__main__":
test_simple()