[Relay]Port eliminate_common_subexpr to non-recursive form (#6134)
authorZheng Jiang <jiangzbox@gmail.com>
Fri, 24 Jul 2020 16:09:50 +0000 (00:09 +0800)
committerGitHub <noreply@github.com>
Fri, 24 Jul 2020 16:09:50 +0000 (09:09 -0700)
Co-authored-by: Zheng Jiang <zhejiang@amazon.com>
src/relay/transforms/eliminate_common_subexpr.cc

index dc3f77e..92cc64d 100644 (file)
 namespace tvm {
 namespace relay {
 
-class CommonSubexprEliminator : public ExprMutator {
+class CommonSubexprEliminator : public MixedModeMutator {
  public:
   explicit CommonSubexprEliminator(runtime::TypedPackedFunc<bool(Expr)> fskip) : fskip_(fskip) {}
 
-  Expr VisitExpr_(const CallNode* call) final {
+  Expr Rewrite_(const CallNode* call, const Expr& post) final {
     static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
-    Expr new_expr = ExprMutator::VisitExpr_(call);
+    Expr new_expr = post;
     const CallNode* new_call = new_expr.as<CallNode>();
     CHECK(new_call);
     const OpNode* op = new_call->op.as<OpNode>();
@@ -80,8 +80,8 @@ class CommonSubexprEliminator : public ExprMutator {
     return new_expr;
   }
 
-  Expr VisitExpr_(const TupleGetItemNode* op) final {
-    Expr new_expr = ExprMutator::VisitExpr_(op);
+  Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final {
+    Expr new_expr = post;
     const TupleGetItemNode* new_tuple_item = new_expr.as<TupleGetItemNode>();
     CHECK(new_tuple_item);