[BYOC] Enhance partitioning and external codegen (#5310)
authorZhi <5145158+zhiics@users.noreply.github.com>
Mon, 13 Apr 2020 21:06:02 +0000 (14:06 -0700)
committerGitHub <noreply@github.com>
Mon, 13 Apr 2020 21:06:02 +0000 (14:06 -0700)
* Remove duplicated output args

* address comment

* fix codegen c

* improve comment

* VisitExprDefault_

* deduce type

src/relay/backend/contrib/codegen_c/codegen.cc
src/relay/backend/contrib/codegen_c/codegen_c.h
src/relay/backend/contrib/dnnl/codegen.cc
src/relay/transforms/partition_graph.cc
tests/python/relay/test_pass_partition_graph.py

index 500e0dc..fc93b73 100644 (file)
@@ -40,35 +40,39 @@ using namespace backend;
  * purpose. Only several binary options are covered. Users
  * may need to extend them to cover more operators.
  */
-class CodegenC : public ExprVisitor, public CodegenCBase {
+class CodegenC : public ExprFunctor<std::vector<Output>(const Expr&)>,
+                 public CodegenCBase {
  public:
   explicit CodegenC(const std::string& id) { this->ext_func_id_ = id; }
 
-  void VisitExpr_(const VarNode* node) final {
+  std::vector<Output> VisitExpr(const Expr& expr) final {
+    if (visited_.count(expr)) return visited_.at(expr);
+    std::vector<Output> output = ExprFunctor::VisitExpr(expr);
+    visited_[expr] = output;
+    return output;
+  }
+
+  std::vector<Output> VisitExprDefault_(const Object* op) final {
+    LOG(FATAL) << "C codegen doesn't support: " << op->GetTypeKey();
+    return {};
+  }
+
+  std::vector<Output> VisitExpr_(const VarNode* node) final {
     ext_func_args_.push_back(GetRef<Var>(node));
-    out_.clear();
     Output output;
     output.name = node->name_hint();
-    out_.push_back(output);
+    return {output};
   }
 
-  void VisitExpr_(const ConstantNode* cn) final {
-    Constant constant = GetRef<Constant>(cn);
-    if (visited_.count(constant)) {
-      // Note this is for demostration purpose. ConstantNode doesn't necessarily
-      // belong to calls. We need to revisit this when tuples come into play.
-      out_.push_back(visited_[constant]);
-      return;
-    }
+  std::vector<Output> VisitExpr_(const ConstantNode* cn) final {
+    // Note this is for demonstration purpose. ConstantNode doesn't necessarily
+    // belong to calls. We need to revisit this when tuples come into play.
 
     std::ostringstream decl_stream;
     std::ostringstream buf_stream;
 
-    out_.clear();
     Output output;
     output.name = "const_" + std::to_string(const_idx_++);
-    out_.push_back(output);
-    visited_[constant] = output;
 
     runtime::NDArray array = cn->data;
     const auto& shape = array.Shape();
@@ -99,9 +103,11 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
     }
     buf_stream << "};";
     ext_func_body.insert(ext_func_body.begin(), buf_stream.str());
+
+    return {output};
   }
 
-  void VisitExpr_(const CallNode* call) final {
+  std::vector<Output> VisitExpr_(const CallNode* call) final {
     std::ostringstream macro_stream;
     std::ostringstream decl_stream;
     std::ostringstream buf_stream;
@@ -138,8 +144,8 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
     bool first = true;
     decl_stream << func_name << "(";
     for (size_t i = 0; i < call->args.size(); ++i) {
-      VisitExpr(call->args[i]);
-      for (auto out : out_) {
+      auto res = VisitExpr(call->args[i]);
+      for (auto out : res) {
         if (!first) {
           decl_stream << ", ";
         }
@@ -162,13 +168,14 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
     ext_func_body.push_back(decl_stream.str());
 
     // Update output buffer
-    out_.clear();
+    // Note C codegen only handles TensorType. Therefore, we don't flatten
+    // tuples and only return a single vaule.
     Output output;
     output.name = out;
     output.dtype = dtype;
     output.need_copy = true;
     output.size = out_size;
-    out_.push_back(output);
+    return {output};
   }
 
   /*!
@@ -176,12 +183,12 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
    *
    * \return The emitted code.
    */
-  std::string JIT() {
+  std::string JIT(const std::vector<Output>& out) {
     // Write function macros
     for (auto decl : func_decl_) {
       code_stream_ << decl << "\n";
     }
-    return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out_);
+    return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out);
   }
 
  private:
@@ -202,9 +209,7 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
   /*! \brief The declaration statements of buffers. */
   std::vector<std::string> buf_decl_;
   /*! \brief The name and index pairs for output. */
-  std::vector<Output> out_;
-  /*! \brief The cached expressions. */
-  std::unordered_map<Expr, Output, ObjectHash, ObjectEqual> visited_;
+  std::unordered_map<Expr, std::vector<Output>, ObjectHash, ObjectEqual> visited_;
 };
 
 class CSourceCodegen : public CSourceModuleCodegenBase {
@@ -216,8 +221,8 @@ class CSourceCodegen : public CSourceModuleCodegenBase {
     auto sid = GetExtSymbol(func);
 
     CodegenC builder(sid);
-    builder.VisitExpr(func->body);
-    code_stream_ << builder.JIT();
+    auto out = builder.VisitExpr(func->body);
+    code_stream_ << builder.JIT(out);
   }
 
   runtime::Module CreateCSourceModule(const ObjectRef& ref) override {
index 7dfa4ba..9226386 100644 (file)
@@ -165,9 +165,11 @@ class CodegenCBase {
   /*!
    * \brief Emit the code for external runtime.
    *
+   * \param out The outputs.
+   *
    * \return The code string.
    */
-  virtual std::string JIT() = 0;
+  virtual std::string JIT(const std::vector<Output>& out) = 0;
 
   /*!
    * \brief A common interface that is used by various external runtime to
index 7f3aabf..48652fc 100644 (file)
@@ -128,42 +128,43 @@ std::vector<std::string> Add(const CallNode* call) {
 
 // TODO(@zhiics, @comaniac): This is a basic implementation. We should implement
 // all utilities and make a base class for users to implement.
-class CodegenDNNL : public ExprVisitor, public CodegenCBase {
+class CodegenDNNL : public ExprFunctor<std::vector<Output>(const Expr&)>,
+                    public CodegenCBase {
  public:
   explicit CodegenDNNL(const std::string& id) { this->ext_func_id_ = id; }
 
-  void VisitExpr_(const VarNode* node) final {
+  std::vector<Output> VisitExpr(const Expr& expr) final {
+    if (visited_.count(expr)) return visited_.at(expr);
+    std::vector<Output> output = ExprFunctor::VisitExpr(expr);
+    visited_[expr] = output;
+    return output;
+  }
+
+  std::vector<Output> VisitExprDefault_(const Object* op) final {
+    LOG(FATAL) << "DNNL codegen doesn't support: " << op->GetTypeKey();
+    return {};
+  }
+
+  std::vector<Output> VisitExpr_(const VarNode* node) final {
     ext_func_args_.push_back(GetRef<Var>(node));
-    out_.clear();
     Output output;
     output.name = node->name_hint();
-    out_.push_back(output);
+    return {output};
   }
 
-  void VisitExpr_(const TupleGetItemNode* op) final {
-    VisitExpr(op->tuple);
-    CHECK(out_.size() > static_cast<size_t>(op->index));
+  std::vector<Output> VisitExpr_(const TupleGetItemNode* op) final {
+    auto res = VisitExpr(op->tuple);
+    CHECK_GT(res.size(), static_cast<size_t>(op->index));
 
     // Only keep the item we want for the child node.
     // FIXME(@comaniac): The other items should still be requried for the primary outputs.
-    auto item = out_[op->index];
-    out_.clear();
-    out_.push_back(item);
+    return {res[op->index]};
   }
 
-  void VisitExpr_(const ConstantNode* cn) final {
-    Constant constant = GetRef<Constant>(cn);
-    if (visited_.count(constant)) {
-      out_.push_back(visited_[constant]);
-      return;
-    }
-
-    out_.clear();
+  std::vector<Output> VisitExpr_(const ConstantNode* cn) final {
     Output output;
     output.name = "const_" + std::to_string(const_idx_++);
     output.dtype = "float";
-    out_.push_back(output);
-    visited_[constant] = output;
 
     runtime::NDArray array = cn->data;
 
@@ -176,16 +177,23 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
     CHECK_EQ(GetDtypeString(type_node), "float") << "Only float is supported for now.";
 
     std::ostringstream buf_stream;
-    buf_stream << "float* " << output.name << " = (float*)std::malloc(4 * " << num_elems << ");\n";
     const float* ptr = static_cast<float*>(array.ToDLPack()->dl_tensor.data);
-    for (int64_t i = 0; i < num_elems; i++) {
-      buf_stream << "  " << output.name << "[" << i << "] = " << ptr[i] << ";\n";
+
+    // Allocate large arrays on the static section to avoid stakc overflow.
+    // Note that this would probably increase compilation time as the source
+    // file could be really large.
+    buf_stream << "static float " << output.name << "[" << num_elems <<"] = {";
+    for (int64_t i = 0; i < num_elems - 1; i++) {
+      buf_stream << ptr[i] << ",";
     }
+    if (num_elems > 0) buf_stream << ptr[num_elems - 1];
+    buf_stream << "};\n";
 
     ext_func_body.insert(ext_func_body.begin(), buf_stream.str());
+    return {output};
   }
 
-  void VisitExpr_(const CallNode* call) final {
+  std::vector<Output> VisitExpr_(const CallNode* call) final {
     GenerateBodyOutput ret;
     if (const auto* func = call->op.as<FunctionNode>()) {
       ret = GenerateCompositeFunctionCall(func, call);
@@ -193,16 +201,13 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
       ret = GenerateOpCall(call);
     }
 
-    out_.clear();
-    for (size_t i = 0; i < ret.outputs.size(); ++i) {
-      buf_decl_.push_back(ret.buffers[i]);
-      out_.push_back(ret.outputs[i]);
-    }
+    buf_decl_.insert(buf_decl_.end(), ret.buffers.begin(), ret.buffers.end());
     ext_func_body.push_back(ret.decl);
+    return ret.outputs;
   }
 
-  std::string JIT(void) {
-    return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out_);
+  std::string JIT(const std::vector<Output>& out) {
+    return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out);
   }
 
  private:
@@ -215,8 +220,8 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
   std::vector<std::string> GetArgumentNames(const CallNode* call) {
     std::vector<std::string> arg_names;
     for (size_t i = 0; i < call->args.size(); ++i) {
-      VisitExpr(call->args[i]);
-      for (auto out : out_) {
+      auto res = VisitExpr(call->args[i]);
+      for (const auto& out : res) {
         arg_names.push_back(out.name);
       }
     }
@@ -331,17 +336,15 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
    */
   int buf_idx_{0};
   /*! \brief The index of global constants. */
-  int const_idx_ = 0;
+  int const_idx_{0};
   /*! \brief The arguments used by a wrapped function that calls DNNL kernels. */
   Array<Var> ext_func_args_;
   /*! \brief statement of the function that will be compiled using DNNL kernels. */
   std::vector<std::string> ext_func_body;
   /*! \brief The declaration of intermeidate buffers. */
   std::vector<std::string> buf_decl_;
-  /*! \brief The name of the the outputs. */
-  std::vector<Output> out_;
   /*! \brief The cached expressions. */
-  std::unordered_map<Expr, Output, ObjectHash, ObjectEqual> visited_;
+  std::unordered_map<Expr, std::vector<Output>, ObjectHash, ObjectEqual> visited_;
 };
 
 /*!
@@ -361,8 +364,8 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase {
     auto sid = GetExtSymbol(func);
 
     CodegenDNNL builder(sid);
-    builder.VisitExpr(func->body);
-    code_stream_ << builder.JIT();
+    auto out = builder.VisitExpr(func->body);
+    code_stream_ << builder.JIT(out);
   }
 
   /*!
index fa9c8c4..c8367fb 100644 (file)
@@ -148,25 +148,42 @@ class Partitioner : public ExprMutator {
       CHECK_EQ(call->args.size(), 1U);
 
       // Traverse the rest graph.
-      auto input_expr = VisitExpr(call->args[0]);
+      Expr parent = call->args[0];
+      auto input_expr = VisitExpr(parent);
+
+      // Backtrace the parent to find the first ancestor node that is not a begin or end op
+      while (const auto* parent_call = parent.as<CallNode>()) {
+        if (parent_call->op == compiler_begin_op ||
+            parent_call->op == compiler_end_op) {
+          parent = parent_call->args[0];
+        } else {
+          break;
+        }
+      }
 
       AnnotatedRegion sg = GetRegion(GetRef<Call>(call));
       int index = GetArgIdx(sg, GetRef<Call>(call));
       CHECK_NE(index, -1);
-      // The type of the created variable is the same as the compiler_begin
-      // node.
-      std::string target = call->attrs.as<CompilerAttrs>()->compiler;
-      std::string varname =
-          target + "_" + std::to_string(sg->GetID()) + "_i" + std::to_string(index);
-      auto var = Var(varname, GetRef<Call>(call)->checked_type_);
-
-      auto cand = std::make_pair(var, input_expr);
-      if (std::find(region_args[sg].begin(), region_args[sg].end(), cand) ==
-          region_args[sg].end()) {
-        region_args[sg].push_back(cand);
-      }
 
-      return std::move(var);
+      if (shared_output_.count(parent) && shared_output_[parent].count(sg)) {
+        return shared_output_[parent][sg];
+      } else {
+        // The type of the created variable is the same as the compiler_begin
+        // node.
+        std::string target = call->attrs.as<CompilerAttrs>()->compiler;
+        std::string varname =
+            target + "_" + std::to_string(sg->GetID()) + "_i" + std::to_string(index);
+        auto var = Var(varname, GetRef<Call>(call)->checked_type_);
+
+        std::pair<Var, Expr> cand = std::make_pair(var, input_expr);
+
+        if (std::find(region_args[sg].begin(), region_args[sg].end(), cand) ==
+            region_args[sg].end()) {
+          region_args[sg].push_back(cand);
+        }
+        shared_output_[parent][sg] = var;
+        return std::move(var);
+      }
     } else {
       CHECK_EQ(call->op, compiler_end_op);
       // The annotation node is inserted on edge so it must have only one
@@ -474,6 +491,12 @@ class Partitioner : public ExprMutator {
    * belongs to
    */
   std::unordered_map<AnnotatedRegionSet, BaseFunc, ObjectHash, ObjectEqual> regions_sets_;
+
+  /*!\brief Cache the output that is shared by different nodes. */
+  using RegionOutputMap = std::unordered_map<AnnotatedRegion, Var, ObjectHash, ObjectEqual>;
+  std::unordered_map<Expr, RegionOutputMap, ObjectHash, ObjectEqual> shared_output_;
+
+  /*!\brief The IRModule used for partitioning. */
   IRModule module_;
 };
 
index 1d0cc5b..5148d4e 100644 (file)
@@ -300,6 +300,14 @@ def test_extern_ccompiler_single_op():
     check_result(mod, {"x": x_data, "y": y_data}, (8, 8), x_data + y_data)
 
 
+def set_func_attr(func, compile_name, symbol_name):
+    func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+    func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+    func = func.with_attr("Compiler", compile_name)
+    func = func.with_attr("global_symbol", symbol_name)
+    return func
+
+
 def test_extern_ccompiler_default_ops():
     def expected():
         mod = tvm.IRModule()
@@ -310,10 +318,7 @@ def test_extern_ccompiler_default_ops():
         add = x0 + y0
         # Function that uses C compiler
         func = relay.Function([x0, y0], add)
-        func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
-        func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func = func.with_attr("Compiler", "ccompiler")
-        func = func.with_attr("global_symbol", "ccompiler_0")
+        func = set_func_attr(func, "ccompiler", "ccompiler_0")
         glb_0 = relay.GlobalVar("ccompiler_0")
         mod[glb_0] = func
         add_call = relay.Call(glb_0, [x, y])
@@ -380,32 +385,28 @@ def test_extern_dnnl():
 
     def expected():
         data0 = relay.var("data", shape=(ishape), dtype=dtype)
-        input0 = relay.var("input0", shape=(w1shape), dtype=dtype)
-        input1 = relay.var("input1", shape=(w1shape), dtype=dtype)
+        input0 = relay.var("input", shape=(w1shape), dtype=dtype)
         depthwise_conv2d_1 = relay.nn.conv2d(data0,
                                              input0,
                                              kernel_size=(3, 3),
                                              padding=(1, 1),
                                              groups=32)
         depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
-                                             input1,
+                                             input0,
                                              kernel_size=(3, 3),
                                              padding=(1, 1),
                                              groups=32)
         out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
 
-        func = relay.Function([data0, input0, input1], out)
-        func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
-        func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func = func.with_attr("Compiler", "dnnl")
-        func = func.with_attr("global_symbol", "dnnl_0")
+        func = relay.Function([data0, input0], out)
+        func = set_func_attr(func, "dnnl", "dnnl_0")
         glb_var = relay.GlobalVar("dnnl_0")
         mod = tvm.IRModule()
         mod[glb_var] = func
 
         data = relay.var("data", shape=(ishape), dtype=dtype)
         weight = relay.var("input", shape=(w1shape), dtype=dtype)
-        main_f = relay.Function([data, weight], glb_var(data, weight, weight))
+        main_f = relay.Function([data, weight], glb_var(data, weight))
         mod["main"] = main_f
 
         return mod
@@ -444,7 +445,7 @@ def test_extern_dnnl():
     check_result(mod, {"data": i_data, "weight1": w1_data},
                  (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5)
 
-@pytest.mark.skip(reason="fix constant node before opening this case")
+
 def test_extern_dnnl_mobilenet():
     if not tvm.get_global_func("relay.ext.dnnl", True):
         print("skip because DNNL codegen is not available")
@@ -521,10 +522,7 @@ def test_function_lifting():
         bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar)
         func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar],
                                bn.astuple())
-        func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
-        func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func0 = func0.with_attr("Compiler", "test_compiler")
-        func0 = func0.with_attr("global_symbol", "test_compiler_0")
+        func0 = set_func_attr(func0, "test_compiler", "test_compiler_0")
         gv0 = relay.GlobalVar("test_compiler_0")
         mod[gv0] = func0
 
@@ -538,10 +536,7 @@ def test_function_lifting():
             channels=16,
             padding=(1, 1))
         func1 = relay.Function([data1, weight1], conv)
-        func1 = func1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
-        func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func1 = func1.with_attr("Compiler", "test_compiler")
-        func1 = func1.with_attr("global_symbol", "test_compiler_1")
+        func1 = set_func_attr(func1, "test_compiler", "test_compiler_1")
         gv1 = relay.GlobalVar("test_compiler_1")
         mod[gv1] = func1
 
@@ -610,10 +605,7 @@ def test_function_lifting_inline():
         bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar)
         func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar],
                                bn.astuple())
-        func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
-        func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func0 = func0.with_attr("Compiler", "test_compiler")
-        func0 = func0.with_attr("global_symbol", "test_compiler_0")
+        func0 = set_func_attr(func0, "test_compiler", "test_compiler_0")
 
         # main function
         data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32"))
@@ -645,10 +637,7 @@ def test_constant_propagation():
         add = x0 + y0
         # Function that uses C compiler
         func = relay.Function([y0], add)
-        func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
-        func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func = func.with_attr("Compiler", "ccompiler")
-        func = func.with_attr("global_symbol", "ccompiler_0")
+        func = set_func_attr(func, "ccompiler", "ccompiler_0")
         glb_0 = relay.GlobalVar("ccompiler_0")
         mod[glb_0] = func
         add_call = relay.Call(glb_0, [y])
@@ -745,10 +734,7 @@ def test_multiple_outputs():
 
         func0 = relay.Function([data, weight, bn_gamma, bn_beta,
                                 bn_mean, bn_var], tuple_o)
-        func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
-        func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func0 = func0.with_attr("Compiler", "test_target")
-        func0 = func0.with_attr("global_symbol", "test_target_2")
+        func0 = set_func_attr(func0, "test_target", "test_target_2")
         gv0 = relay.GlobalVar("test_target_2")
         mod[gv0] = func0
 
@@ -810,11 +796,7 @@ def test_mixed_single_multiple_outputs():
         f1_O_2 = relay.nn.relu(f1_O_1)
         f1_out = relay.Tuple((f1_O_2, f1_O_1))
         func1 = relay.Function([f1_cb1], f1_out)
-
-        func1 = func1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
-        func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func1 = func1.with_attr("Compiler", "test_target")
-        func1 = func1.with_attr("global_symbol", "test_target_1")
+        func1 = set_func_attr(func1, "test_target", "test_target_1")
         gv1 = relay.GlobalVar("test_target_1")
         mod[gv1] = func1
 
@@ -823,11 +805,7 @@ def test_mixed_single_multiple_outputs():
         f2_cb4 = relay.var('test_target_0_i1', shape=(10, 10))
         f2_O_3 = relay.add(f2_cb3, f2_cb4)
         func0 = relay.Function([f2_cb3, f2_cb4], f2_O_3)
-
-        func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
-        func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func0 = func0.with_attr("Compiler", "test_target")
-        func0 = func0.with_attr("global_symbol", "test_target_0")
+        func0 = set_func_attr(func0, "test_target", "test_target_0")
         gv0 = relay.GlobalVar("test_target_0")
         mod[gv0] = func0
 
@@ -967,10 +945,96 @@ def test_dnnl_fuse():
     ref_mod, ref_params = tvm.relay.testing.create_workload(net)
     test_exec(mod, params, ref_mod, ref_params, (1, 8, 224, 224))
 
-    # exec test on mobilenet is not possible due to manually inlined constants
-    # mod, params = relay.testing.mobilenet.get_workload()
-    # ref_mod, ref_params = relay.testing.mobilenet.get_workload()
-    # test_exec(mod, params, ref_mod, ref_params, (1, 1000))
+    mod, params = relay.testing.mobilenet.get_workload()
+    ref_mod, ref_params = relay.testing.mobilenet.get_workload()
+    test_exec(mod, params, ref_mod, ref_params, (1, 1000))
+
+
+def test_multiple_use_of_an_output():
+    def expected_same_output_region():
+        mod = tvm.IRModule()
+        x = relay.var("x", shape=(8, 8))
+        y = relay.var("y", shape=(8, 8))
+        z = relay.var("z", shape=(8, 8))
+        x0 = relay.var("x0", shape=(8, 8))
+        y0 = relay.var("y0", shape=(8, 8))
+        log = relay.log(x0)
+        sub = x0 - y0
+        mul = log * sub
+        # The partitioned graph contains log, subtract, and multiply
+        func = relay.Function([x0, y0], mul)
+        func = set_func_attr(func, "ccompiler", "ccompiler_0")
+        glb_0 = relay.GlobalVar("ccompiler_0")
+        mod[glb_0] = func
+        add = x + y
+        call = relay.Call(glb_0, [add, z])
+        main = relay.Function([x, y, z], call)
+        mod["main"] = main
+        return mod
+
+    def expected_different_output_region():
+        mod = tvm.IRModule()
+        x = relay.var("x", shape=(8, 8))
+        y = relay.var("y", shape=(8, 8))
+        z = relay.var("z", shape=(8, 8))
+
+        # The partitioned graph contains log
+        i0 = relay.var("i0", shape=(8, 8))
+        log = relay.log(i0)
+        func = relay.Function([i0], log)
+        func = set_func_attr(func, "ccompiler", "ccompiler_0")
+        glb_0 = relay.GlobalVar("ccompiler_0")
+        mod[glb_0] = func
+
+        # The partitioned graph contains subtract
+        x0 = relay.var("x0", shape=(8, 8))
+        y0 = relay.var("y0", shape=(8, 8))
+        sub = x0 - y0
+        func = relay.Function([x0, y0], sub)
+        func = set_func_attr(func, "ccompiler", "ccompiler_1")
+        glb_1 = relay.GlobalVar("ccompiler_1")
+        mod[glb_1] = func
+
+        add = x + y
+        call_log = relay.Call(glb_0, [add])
+        call_sub = relay.Call(glb_1, [add, z])
+        main = relay.Function([x, y, z], call_log * call_sub)
+        mod["main"] = main
+        return mod
+
+    def get_mod():
+        x = relay.var("x", shape=(8, 8))
+        y = relay.var("y", shape=(8, 8))
+        z = relay.var("z", shape=(8, 8))
+        add = x + y
+        sub = add - z
+        log = relay.log(add)
+        sub1 = log * sub
+        f = relay.Function([x, y, z], sub1)
+        mod = tvm.IRModule()
+        mod["main"] = f
+        return mod
+
+    def test_same_output_region():
+        mod = get_mod()
+        mod = WhiteListAnnotator(["subtract", "log", "multiply"], "ccompiler")(mod)
+        mod = transform.MergeCompilerRegions()(mod)
+        mod = transform.PartitionGraph()(mod)
+
+        expected_mod = expected_same_output_region()
+        assert tvm.ir.structural_equal(mod, expected_mod, map_free_vars=True)
+
+    def test_different_output_region():
+        mod = get_mod()
+        mod = WhiteListAnnotator(["subtract", "log"], "ccompiler")(mod)
+        mod = transform.MergeCompilerRegions()(mod)
+        mod = transform.PartitionGraph()(mod)
+
+        expected_mod = expected_different_output_region()
+        assert tvm.ir.structural_equal(mod, expected_mod, map_free_vars=True)
+
+    test_same_output_region()
+    test_different_output_region()
 
 
 if __name__ == "__main__":
@@ -979,11 +1043,11 @@ if __name__ == "__main__":
     test_extern_ccompiler_default_ops()
     test_extern_ccompiler()
     test_extern_dnnl()
-    # TODO(@comaniac, @zhiics): Fix constant node and re-open this case.
-    #test_extern_dnnl_mobilenet()
+    test_extern_dnnl_mobilenet()
     test_function_lifting()
     test_function_lifting_inline()
     test_constant_propagation()
     test_multiple_outputs()
     test_mixed_single_multiple_outputs()
     test_dnnl_fuse()
+    test_multiple_use_of_an_output()