[relay] Relay annotation and partitioning for external compilers (#4570)
author    Zhi <5145158+zhiics@users.noreply.github.com>
          Tue, 14 Jan 2020 19:40:00 +0000 (11:40 -0800)
committer GitHub <noreply@github.com>
          Tue, 14 Jan 2020 19:40:00 +0000 (11:40 -0800)
* [relay] Relay annotation and partitioning for codegen

* Add fusion unit test

* fix comments

* Update include/tvm/relay/attrs/annotation.h

Co-Authored-By: 雾雨魔理沙 <lolisa@marisa.moe>
* rebase

* remove annotation helper

* rebase again

Co-authored-by: Cody Yu <comaniac0422@gmail.com>
Co-authored-by: 雾雨魔理沙 <lolisa@marisa.moe>
include/tvm/relay/attrs/annotation.h
include/tvm/relay/op_attr_types.h
include/tvm/relay/transform.h
python/tvm/relay/op/annotation/annotation.py
python/tvm/relay/transform.py
src/relay/backend/contrib/dnnl/codegen.cc
src/relay/op/annotation/annotation.cc
src/relay/pass/fuse_ops.cc
src/relay/pass/partition_graph.cc [new file with mode: 0644]
tests/python/relay/test_pass_partition_graph.py [new file with mode: 0644]

diff --git a/include/tvm/relay/attrs/annotation.h b/include/tvm/relay/attrs/annotation.h
index fd21db5..4481d2a 100644 (file)
@@ -57,6 +57,19 @@ struct CastHintAttrs : public tvm::AttrsNode<CastHintAttrs> {
   }
 };
 
+/*!
+ * \brief Options for the operators used to annotate a compiler region.
+ */
+struct CompilerAttrs : public tvm::AttrsNode<CompilerAttrs> {
+  /*! \brief A 3rd party compiler for code generation. */
+  std::string compiler;
+
+  TVM_DECLARE_ATTRS(CompilerAttrs, "relay.attrs.CompilerAttrs") {
+    TVM_ATTR_FIELD(compiler)
+      .describe("A 3rd party compiler used for code generation.");
+  }
+};
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_ANNOTATION_H_
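As a quick illustration (a sketch, not part of the patch) of what this struct
carries: the compiler field is reflected onto the annotation calls on the
Python side, assuming the tvm.relay.annotation re-export exercised by the new
unit tests below.

    from tvm import relay
    from tvm.relay.annotation import compiler_begin

    # compiler_begin returns a Call to annotation.compiler_begin whose attrs
    # are the CompilerAttrs declared above.
    call = compiler_begin(relay.var("x", shape=(8, 8)), "ccompiler")
    print(call.attrs.compiler)  # -> "ccompiler"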
diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h
index 9cfa755..b6221e0 100644 (file)
@@ -123,7 +123,7 @@ using FTVMSchedule = runtime::TypedPackedFunc<
  *  operator with other expressions. This function will be invoked
  *  in AlterOpLayout pass.
  * \param attrs The attribute of the original node.
- * \param inputs The input symbols of the original node.
+ * \param args The input symbols of the original node.
  * \param tinfos An array of placeholders, use for getting the inferred shape
  *               and dtype of the inputs.
  * \return new_expr The modified expression.
@@ -153,8 +153,8 @@ using FTVMConvertOpLayout = runtime::TypedPackedFunc<
  * \brief Legalizes an expression with another expression. This function will be
  *  invoked in Legalize pass. It is a target-dependent pass.
  * \param attrs The attribute of the original node.
- * \param inputs The input symbols of the original node.
- * \param tinfos An array of placeholders, use for getting the inferred shape
+ * \param args The input symbols of the original node.
+ * \param arg_types An array of placeholders, use for getting the inferred shape
  *               and dtype of the inputs.
  * \return new_expr The modified expression.
  */
diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index 294ffb9..58cfbfc 100644 (file)
@@ -310,6 +310,14 @@ TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var);
  */
 TVM_DLL Pass PrintIR(bool show_meta_data = true);
 
+/*!
+ * \brief Partition a Relay program into regions that can be executed on
+ * different backends.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass PartitionGraph();
+
 }  // namespace transform
 
 /*!
diff --git a/python/tvm/relay/op/annotation/annotation.py b/python/tvm/relay/op/annotation/annotation.py
index 2b9d4bc..9363925 100644 (file)
@@ -62,6 +62,7 @@ def stop_fusion(data):
     """
     return _make.stop_fusion(data)
 
+
 def checkpoint(data):
     """Annotate an expression to be a checkpoint for the checkpointing memory optimization.
 
@@ -78,3 +79,43 @@ def checkpoint(data):
     return _make.checkpoint(data)
 
 register_schedule("annotation.checkpoint", schedule_injective)
+
+
+def compiler_begin(data, compiler):
+    """Annotate an expression to indicate that it is the beginning of
+    a region that will be handled by the given compiler.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The expression to be annotated.
+
+    compiler : str
+        The compiler used to generate code of the annotated region.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The annotated expression.
+    """
+    return _make.compiler_begin(data, compiler)
+
+
+def compiler_end(data, compiler):
+    """Annotate an expression to indicate that it is the end of a region that
+    is handled by the provided compiler.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The expression to be annotated.
+
+    compiler : str
+        The compiler used to generate code of the annotated region.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The annotated expression.
+    """
+    return _make.compiler_end(data, compiler)
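A minimal sketch of how the two annotations are used together (condensed from
the new unit tests below; the "ccompiler" target and variable names are just
illustrative):

    from tvm import relay
    from tvm.relay.annotation import compiler_begin, compiler_end

    x = relay.var("x", shape=(8, 8))
    y = relay.var("y", shape=(8, 8))
    # Open the region on each input and close it at the output; everything in
    # between is meant to be offloaded to the external "ccompiler" backend.
    lhs = compiler_begin(x, "ccompiler")
    rhs = compiler_begin(y, "ccompiler")
    out = compiler_end(relay.add(lhs, rhs), "ccompiler")
    print(out.op.name)  # annotation.compiler_end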
diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py
index 1f91272..c4fbde6 100644 (file)
@@ -663,6 +663,18 @@ def PrintIR(show_meta_data=True):
     return _transform.PrintIR(show_meta_data)
 
 
+def PartitionGraph():
+    """Partition a Relay program into regions that can be executed on different
+    backends.
+
+    Returns
+    -------
+    ret: tvm.relay.Pass
+        The registered pass that partitions the Relay program.
+    """
+    return _transform.PartitionGraph()
+
+
 def gradient(expr, mod=None, mode='higher_order'):
     """
     Transform the input function,
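Continuing that sketch, the pass lifts the annotated region into a separate
function marked for the external compiler (again condensed from the unit tests
below; names are illustrative):

    import tvm.relay.transform as transform
    from tvm import relay
    from tvm.relay.annotation import compiler_begin, compiler_end

    x = relay.var("x", shape=(8, 8))
    y = relay.var("y", shape=(8, 8))
    out = compiler_end(relay.add(compiler_begin(x, "ccompiler"),
                                 compiler_begin(y, "ccompiler")), "ccompiler")

    mod = relay.Module()
    mod["main"] = relay.Function([x, y], out)
    # The begin/end region becomes a lifted function carrying the Primitive,
    # Compiler and ExternalSymbol attributes set by the partitioner.
    mod = transform.PartitionGraph()(mod)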
diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc
index 9c24944..fbe047d 100644 (file)
@@ -270,8 +270,8 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase {
 
     if (ref->IsInstance<FunctionNode>()) {
       GenDNNLFunc(Downcast<Function>(ref));
-    } else if (ref->IsInstance<relay::ModuleNode>()) {
-      relay::Module mod = Downcast<relay::Module>(ref);
+    } else if (ref->IsInstance<IRModuleNode>()) {
+      IRModule mod = Downcast<IRModule>(ref);
       for (const auto& it : mod->functions) {
         GenDNNLFunc(Downcast<Function>(it.second));
       }
diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc
index efcb383..3d03f88 100644 (file)
@@ -171,5 +171,55 @@ Mark a checkpoint for checkpointing memory optimization.
                          return outputs;
                        });
 
+RELAY_REGISTER_OP("annotation.compiler_begin")
+.describe(R"code(
+Beginning of a region that is handled by a given compiler.
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_support_level(10)
+.add_type_rel("Identity", IdentityRel)
+.set_attr<TOpPattern>("TOpPattern", kOpaque)
+.set_attr<TOpIsStateful>("TOpIsStateful", false)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+                               ElemwiseArbitraryLayout)
+.set_attr<FTVMCompute>("FTVMCompute",
+                       [](const Attrs& attrs, const Array<Tensor>& inputs,
+                          const Type& out_dtype, const Target& target) -> Array<Tensor> {
+                         return {topi::identity(inputs[0])};
+                       });
+
+TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_begin")
+.set_body_typed([](Expr expr, std::string compiler) {
+  auto attrs = make_object<CompilerAttrs>();
+  attrs->compiler = compiler;
+  static const Op& op = Op::Get("annotation.compiler_begin");
+  return CallNode::make(op, {expr}, Attrs(attrs), {});
+});
+
+RELAY_REGISTER_OP("annotation.compiler_end")
+.describe(R"code(
+End of a region that is handled by a given compiler.
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_support_level(10)
+.add_type_rel("Identity", IdentityRel)
+.set_attr<TOpPattern>("TOpPattern", kOpaque)
+.set_attr<TOpIsStateful>("TOpIsStateful", false)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+                               ElemwiseArbitraryLayout)
+.set_attr<FTVMCompute>("FTVMCompute",
+                       [](const Attrs& attrs, const Array<Tensor>& inputs,
+                          const Type& out_dtype, const Target& target) -> Array<Tensor> {
+                         return {topi::identity(inputs[0])};
+                       });
+
+TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_end")
+.set_body_typed([](Expr expr, std::string compiler) {
+  auto attrs = make_object<CompilerAttrs>();
+  attrs->compiler = compiler;
+  static const Op& op = Op::Get("annotation.compiler_end");
+  return CallNode::make(op, {expr}, Attrs(attrs), {});
+});
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc
index bf38a48..e18dbc2 100644 (file)
@@ -242,8 +242,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
     // Finally if the operator position is not a call node we will
     // need to call Update, as it may be an arbitrary expression.
     OpPatternKind op_pattern = kOpaque;
-    const OpNode* opnode = call->op.as<OpNode>();
-    if (opnode != nullptr && call->op != Op::Get("nn.batch_norm")) {
+    if (const OpNode* opnode = call->op.as<OpNode>()) {
       op_pattern = static_cast<OpPatternKind>(fpattern[GetRef<Op>(opnode)]);
     } else {
       this->Update(call->op, node, kOpaque);
diff --git a/src/relay/pass/partition_graph.cc b/src/relay/pass/partition_graph.cc
new file mode 100644 (file)
index 0000000..634affe
--- /dev/null
@@ -0,0 +1,386 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*
+ * \file src/relay/pass/partition_graph.cc
+ *
+ * \brief Partition an input function into multiple functions based on the
+ * inserted annotation nodes (i.e. compiler_begin and compiler_end). These
+ * nodes are used as boundaries to partition the Relay function into multiple
+ * regions that can be offloaded to different accelerators/backends.
+ *
+ * Each of these partitioned functions, a.k.a. subgraphs, will be viewed as an
+ * external function, and it will use the provided compiler for codegen.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/expr.h>
+#include <tvm/ir/error.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+namespace partitioning {
+
+// Cache the compiler_begin and compiler_end annotation ops for equivalence
+// checks to reduce registry lookup overhead.
+static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
+static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
+
+/*!
+ * \brief The subgraph properties for partitioning.
+ */
+struct Subgraph {
+  /*! \brief The subgraph ID. */
+  int id;
+
+  /*! \brief The input arguments of this subgraph. */
+  std::vector<std::pair<Var, Expr>> args;
+
+  /*! \brief Nodes in this subgraph. */
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes;
+};
+
+/*!
+ * \brief The checker that verifies if a Relay program is annotated correctly
+ * for partitioning.
+ */
+class AnnotationChecker : public ExprVisitor {
+ public:
+  bool Check() {
+    if (!found_start_ && !found_end_) {
+      LOG(WARNING) << "No compiler annotation found";
+    } else if (!found_start_) {
+      LOG(ERROR) << "compiler_begin annotation is missing";
+      return false;
+    } else if (!found_end_) {
+      LOG(ERROR) << "compiler_end annotation is missing";
+      return false;
+    }
+    return true;
+  }
+
+  void VisitExpr_(const CallNode* call) final {
+    auto op_node = call->op.as<OpNode>();
+    if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
+      return;
+    } else if (call->op == compiler_begin_op) {
+      found_start_ = true;
+    } else if (call->op == compiler_end_op) {
+      found_end_ = true;
+    }
+  }
+
+ private:
+  bool found_start_{false};
+  bool found_end_{false};
+};
+
+/*! \brief This class partitions an expression labeled with begin and end
+ * annotations into a function containing multiple regions. Each region is
+ * labeled with a compiler attribute so that it can be handled by compilers
+ * outside of the TVM stack.
+ *
+ * TODO(@zhiics) This following algorithm is not adequate to handle all cases,
+ * i.e. multiple `compiler_end` nodes.
+ */
+class Partitioner : public ExprMutator {
+ public:
+  std::shared_ptr<Subgraph> GetSubgraph(const Expr node) {
+    for (auto candidate : this->subgraphs_) {
+      if (candidate->nodes.find(node) != candidate->nodes.end()) {
+        return candidate;
+      }
+    }
+    return nullptr;
+  }
+
+  void MergeSubgraph(std::shared_ptr<Subgraph> subgraph1,
+                     std::shared_ptr<Subgraph> subgraph2) {
+    if (subgraph1 == subgraph2) {
+      return;
+    }
+
+    // Merge subgraph 2 to subgraph 1 and erase subgraph 2.
+    subgraph1->nodes.insert(subgraph2->nodes.begin(), subgraph2->nodes.end());
+    for (auto arg : subgraph2->args) {
+      subgraph1->args.push_back(arg);
+    }
+    this->subgraphs_.erase(subgraph2);
+  }
+
+  void AddToSubgraph(std::shared_ptr<Subgraph> subgraph, const Expr expr) {
+    auto subgraph2 = GetSubgraph(expr);
+    if (subgraph2) {
+      MergeSubgraph(subgraph, subgraph2);
+    } else {
+      subgraph->nodes.insert(expr);
+    }
+  }
+
+  Expr VisitExpr_(const CallNode* call) final {
+    auto op_node = call->op.as<OpNode>();
+
+    if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
+      // Propagate the subgraph to its arguments.
+      auto subgraph = GetSubgraph(GetRef<Call>(call));
+      if (subgraph) {
+        for (auto arg : call->args) {
+          AddToSubgraph(subgraph, arg);
+        }
+      }
+      return ExprMutator::VisitExpr_(call);
+    } else if (call->op == compiler_begin_op) {
+      // The annotation node is inserted on an edge, so it must have exactly one argument.
+      CHECK_EQ(call->args.size(), 1U);
+
+      // Traverse the rest of the graph.
+      auto input_expr = VisitExpr(call->args[0]);
+
+      // Replace the begin annotation with an external call input variable.
+      auto compiler_attrs = call->attrs.as<CompilerAttrs>();
+      auto var = VarNode::make(compiler_attrs->compiler + "_input" + std::to_string(var_id_++),
+                               input_expr->checked_type_);
+
+      // Find the corresponding subgraph and add the argument.
+      auto subgraph = GetSubgraph(GetRef<Call>(call));
+      if (!subgraph) {
+        throw Error(ErrorBuilder()
+                    << "Cannot find the corresponding subgraph for start annotation:\n"
+                    << AsText(GetRef<Call>(call), false));
+      }
+      subgraph->args.push_back({var, input_expr});
+      return std::move(var);
+    } else {
+      CHECK_EQ(call->op, compiler_end_op);
+      // The annotation node is inserted on an edge, so it must have exactly one argument.
+      CHECK_EQ(call->args.size(), 1U);
+
+      auto compiler_attrs = call->attrs.as<CompilerAttrs>();
+
+      // Check if the argument already belongs to an existing subgraph.
+      auto subgraph = GetSubgraph(call->args[0]);
+      if (!subgraph) {
+        auto ret = this->subgraphs_.emplace(std::make_shared<Subgraph>());
+        subgraph = *ret.first;
+        subgraph->nodes.insert(call->args[0]);
+        subgraph->id = this->subgraph_id_++;
+      }
+      subgraph->nodes.insert(GetRef<Call>(call));
+
+      // Traverse subgraph inputs.
+      auto input = VisitExpr(call->args[0]);
+      Array<Var> params;
+      Array<Expr> args;
+
+      // The subgraph may have been merged, so we need to fetch it again.
+      subgraph = GetSubgraph(GetRef<Call>(call));
+      CHECK(subgraph);
+
+      for (auto pair : subgraph->args) {
+        params.push_back(pair.first);
+        args.push_back(pair.second);
+      }
+
+      auto subgraph_func =
+          FunctionNode::make(params, input, call->args[0]->checked_type_, {}, Attrs());
+
+      Expr arg0 = call->args[0];
+      std::string name = compiler_attrs->compiler + "_" + std::to_string(subgraph->id);
+      subgraph_func =
+          FunctionSetAttr(subgraph_func, attr::kExternalSymbol, tvm::ir::StringImmNode::make(name));
+      subgraph_func = FunctionSetAttr(subgraph_func, attr::kPrimitive, tvm::Integer(1));
+      subgraph_func = FunctionSetAttr(subgraph_func, attr::kCompiler,
+                                      tvm::ir::StringImmNode::make(compiler_attrs->compiler));
+      return CallNode::make(subgraph_func, args);
+    }
+  }
+
+  Expr VisitExpr_(const TupleNode* op) final {
+    auto subgraph = GetSubgraph(GetRef<Tuple>(op));
+    if (!subgraph) {
+      return ExprMutator::VisitExpr_(op);
+    } else {
+      for (auto field : op->fields) {
+        AddToSubgraph(subgraph, field);
+      }
+      Array<Expr> fields;
+      for (auto field : op->fields) {
+        fields.push_back(VisitExpr(field));
+      }
+      return TupleNode::make(fields);
+    }
+  }
+
+  Expr VisitExpr_(const TupleGetItemNode* g) final {
+    auto subgraph = GetSubgraph(GetRef<TupleGetItem>(g));
+    if (!subgraph) {
+      return ExprMutator::VisitExpr_(g);
+    } else {
+      AddToSubgraph(subgraph, g->tuple);
+      auto t = VisitExpr(g->tuple);
+      return TupleGetItemNode::make(t, g->index);
+    }
+  }
+
+  Expr VisitExpr_(const FunctionNode* op) final {
+    auto subgraph = GetSubgraph(GetRef<Function>(op));
+    if (!subgraph) {
+      return ExprMutator::VisitExpr_(op);
+    } else {
+      Array<Var> params;
+      for (auto param : op->params) {
+        AddToSubgraph(subgraph, param);
+      }
+      for (auto param : op->params) {
+        Var new_param = Downcast<Var>(VisitExpr(param));
+        params.push_back(new_param);
+      }
+      auto body = VisitExpr(op->body);
+      return FunctionNode::make(params, body, op->ret_type, op->type_params, op->attrs);
+    }
+  }
+
+  Expr VisitExpr_(const LetNode* op) final {
+    auto subgraph = GetSubgraph(GetRef<Let>(op));
+    if (!subgraph) {
+      return ExprMutator::VisitExpr_(op);
+    } else {
+      AddToSubgraph(subgraph, op->var);
+      AddToSubgraph(subgraph, op->value);
+      AddToSubgraph(subgraph, op->body);
+      Var var = Downcast<Var>(VisitExpr(op->var));
+      auto value = VisitExpr(op->value);
+      auto body = VisitExpr(op->body);
+
+      return LetNode::make(var, value, body);
+    }
+  }
+
+  Expr VisitExpr_(const IfNode* op) final {
+    auto subgraph = GetSubgraph(GetRef<If>(op));
+    if (!subgraph) {
+      return ExprMutator::VisitExpr_(op);
+    } else {
+      AddToSubgraph(subgraph, op->cond);
+      AddToSubgraph(subgraph, op->true_branch);
+      AddToSubgraph(subgraph, op->false_branch);
+      auto guard = VisitExpr(op->cond);
+      auto true_b = VisitExpr(op->true_branch);
+      auto false_b = VisitExpr(op->false_branch);
+      return IfNode::make(guard, true_b, false_b);
+    }
+  }
+
+  Expr VisitExpr_(const RefCreateNode* op) final {
+    auto subgraph = GetSubgraph(GetRef<RefCreate>(op));
+    if (!subgraph) {
+      return ExprMutator::VisitExpr_(op);
+    } else {
+      AddToSubgraph(subgraph, op->value);
+      Expr value = VisitExpr(op->value);
+      return RefCreateNode::make(value);
+    }
+  }
+
+  Expr VisitExpr_(const RefReadNode* op) final {
+    auto subgraph = GetSubgraph(GetRef<RefRead>(op));
+    if (!subgraph) {
+      return ExprMutator::VisitExpr_(op);
+    } else {
+      AddToSubgraph(subgraph, op->ref);
+      Expr ref = VisitExpr(op->ref);
+      return RefReadNode::make(ref);
+    }
+  }
+
+  Expr VisitExpr_(const RefWriteNode* op) final {
+    auto subgraph = GetSubgraph(GetRef<RefWrite>(op));
+    if (!subgraph) {
+      return ExprMutator::VisitExpr_(op);
+    } else {
+      AddToSubgraph(subgraph, op->ref);
+      Expr ref = VisitExpr(op->ref);
+      Expr value = VisitExpr(op->value);
+      return RefWriteNode::make(ref, value);
+    }
+  }
+
+ private:
+  int var_id_{0};
+  int subgraph_id_{0};
+  std::unordered_set<std::shared_ptr<Subgraph>> subgraphs_;
+};
+
+/*!
+ * \brief TODO(@zhiics, @comaniac) Combine parallel regions that belong to
+ * the same codegen backend. This reduces round trips between TVM and external
+ * backends. Likely we can borrow some ideas from operator fusion.
+ *
+ * For example, sg1 and sg2 should be combined if they belong to the same
+ * codegen tool in the following case.
+ *
+ *      op1
+ *     /   \
+ *   sg1   sg2
+ *
+ *       |
+ *      \|/
+ *
+ *      op1
+ *       |
+ *    sg1_sg2
+ *
+ * where the return type of the new subgraph sg1_sg2 is a tuple, and op1 has two
+ * inputs that are obtained from the tuple.
+ */
+
+Expr PartitionGraph(const Expr& expr) {
+  Partitioner part;
+  return part.Mutate(expr);
+}
+
+}  // namespace partitioning
+
+namespace transform {
+
+Pass PartitionGraph() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> part_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        return Downcast<Function>(partitioning::PartitionGraph(f));
+      };
+  auto partitioned = CreateFunctionPass(part_func, 0, "PartitionGraph", {});
+  return Sequential({partitioned, InferType()});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.PartitionGraph")
+.set_body_typed(transform::PartitionGraph);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
new file mode 100644 (file)
index 0000000..4ffb373
--- /dev/null
@@ -0,0 +1,434 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for graph partitioning."""
+import os
+import sys
+import numpy as np
+import pytest
+
+import tvm
+import tvm.relay.testing
+import tvm.relay.transform as transform
+from tvm import relay
+from tvm.contrib import util
+from tvm.relay.annotation import compiler_begin, compiler_end
+from tvm.relay.expr_functor import ExprMutator
+
+# Leverage the pass manager to write a simple whitelist-based annotator.
+@transform.function_pass(opt_level=0)
+class WhiteListAnnotator:
+    def __init__(self, op_list, compiler):
+        assert isinstance(op_list, (list, tuple, set))
+        self.op_list = op_list
+        self.compiler = compiler
+
+    def transform_function(self, func, mod, ctx):
+
+        annotator = self
+        class Annotator(tvm.relay.ExprMutator):
+            def visit_call(self, call):
+                op_name = call.op.name
+                if op_name in annotator.op_list:
+                    new_args = []
+                    for arg in call.args:
+                        ann = compiler_begin(super().visit(arg),
+                                             annotator.compiler)
+                        new_args.append(ann)
+                    new_call = relay.Call(call.op, new_args, call.attrs,
+                                          call.type_args)
+                    return compiler_end(new_call, annotator.compiler)
+                else:
+                    return super().visit_call(call)
+        return Annotator().visit(func)
+
+
+class CcompilerAnnotator(ExprMutator):
+    """
+    A simple annotator that creates the following program:
+           |
+      -- begin --
+           |
+          add
+           |
+        subtract
+           |
+        multiply
+           |
+       -- end --
+           |
+    """
+
+    def __init__(self):
+        super(CcompilerAnnotator, self).__init__()
+        self.in_compiler = 0
+
+    def visit_call(self, call):
+        if call.op.name == "add":  # Annotate begin at args
+            if self.in_compiler == 1:
+                lhs = compiler_begin(super().visit(call.args[0]), "ccompiler")
+                rhs = compiler_begin(super().visit(call.args[1]), "ccompiler")
+                op = relay.add(lhs, rhs)
+                self.in_compiler = 2
+                return op
+        elif call.op.name == "subtract":
+            if self.in_compiler == 1:
+                lhs = super().visit(call.args[0])
+                rhs = super().visit(call.args[1])
+                if isinstance(lhs, relay.expr.Var):
+                    lhs = compiler_begin(lhs, "ccompiler")
+                if isinstance(rhs, relay.expr.Var):
+                    rhs = compiler_begin(rhs, "ccompiler")
+                return relay.subtract(lhs, rhs)
+        elif call.op.name == "multiply":  # Annotate end at output
+            self.in_compiler = 1
+            lhs = super().visit(call.args[0])
+            rhs = super().visit(call.args[1])
+            if isinstance(lhs, relay.expr.Var):
+                lhs = compiler_begin(lhs, "ccompiler")
+            if isinstance(rhs, relay.expr.Var):
+                rhs = compiler_begin(rhs, "ccompiler")
+            op = relay.multiply(lhs, rhs)
+            if self.in_compiler == 2:
+                op = compiler_end(op, "ccompiler")
+            self.in_compiler = 0
+            return op
+        return super().visit_call(call)
+
+
+class WholeGraphAnnotator(ExprMutator):
+    """
+    An annotator that marks an entire graph as a single region for one compiler.
+    """
+
+    def __init__(self, compiler):
+        super(WholeGraphAnnotator, self).__init__()
+        self.compiler = compiler
+        self.last_call = True
+
+    def visit_call(self, call):
+        curr_last = self.last_call
+        self.last_call = False
+
+        params = []
+        for arg in call.args:
+            param = super().visit(arg)
+            if isinstance(param, relay.expr.Var):
+                param = compiler_begin(param, self.compiler)
+            params.append(param)
+
+        new_call = relay.Call(call.op, params, call.attrs)
+        if curr_last:
+            new_call = compiler_end(new_call, self.compiler)
+        return new_call
+
+
+class MobileNetAnnotator(ExprMutator):
+    """
+    Annotate mobilenet until global_avg_pool.
+    """
+
+    def __init__(self, compiler):
+        super(MobileNetAnnotator, self).__init__()
+        self.compiler = compiler
+        self.compiler_open = False
+
+    def visit_call(self, call):
+
+        if call.op.name == 'nn.global_avg_pool2d':
+            self.compiler_open = True
+        compiler_open = self.compiler_open
+
+        params = []
+        for arg in call.args:
+            param = super().visit(arg)
+            if call.op.name == 'nn.global_avg_pool2d':
+                param = compiler_end(param, self.compiler)
+            if compiler_open and isinstance(param, relay.expr.Var):
+                param = compiler_begin(param, self.compiler)
+            params.append(param)
+
+        new_call = relay.Call(call.op, params, call.attrs)
+        return new_call
+
+
+def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
+                 ctx=tvm.cpu(), params=None):
+    if sys.platform == "win32":
+        print("Skip test on Windows for now")
+        return
+
+    def update_lib(lib):
+        test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
+        source_dir = os.path.join(test_dir, "..", "..", "..")
+        contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")
+
+        kwargs = {}
+        kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path]
+        tmp_path = util.tempdir()
+        lib_name = 'lib.so'
+        lib_path = tmp_path.relpath(lib_name)
+        lib.export_library(lib_path, fcompile=False, **kwargs)
+        lib = tvm.module.load(lib_path)
+
+        return lib
+
+    def check_vm_result():
+        with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+            exe = relay.vm.compile(mod, target=target, params=params)
+        code, lib = exe.save()
+        lib = update_lib(lib)
+        exe = relay.vm.Executable.load_exec(code, lib)
+        vm = relay.vm.VirtualMachine(exe)
+        vm.init(ctx)
+        out = vm.run(**map_inputs)
+        tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
+
+    def check_graph_runtime_result():
+        with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+            json, lib, param = relay.build(mod, target=target, params=params)
+        lib = update_lib(lib)
+        rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
+
+        for name, data in map_inputs.items():
+            rt_mod.set_input(name, data)
+        rt_mod.set_input(**param)
+        rt_mod.run()
+        out = tvm.nd.empty(out_shape, ctx=ctx)
+        out = rt_mod.get_output(0, out)
+
+        tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
+
+    check_vm_result()
+    check_graph_runtime_result()
+
+
+def test_multi_node_compiler():
+    x = relay.var('x', shape=(10, 10))
+    w0 = relay.var('w0', shape=(10, 10))
+    w1 = relay.var('w1', shape=(10, 10))
+    w2 = relay.var('w2', shape=(10, 10))
+    w3 = relay.var('w3', shape=(10, 10))
+    w4 = relay.var('w4', shape=(10, 10))
+    w5 = relay.var('w5', shape=(10, 10))
+    w6 = relay.var('w6', shape=(10, 10))
+    w7 = relay.var('w7', shape=(10, 10))
+
+    # C compiler
+    # FIXME: We generate two subgraphs for this case, but they should be merged
+    # into one due to the common input (x).
+    z0 = relay.add(x, w0)
+    p0 = relay.subtract(z0, w1)
+    q0 = relay.multiply(p0, w2)
+
+    z1 = relay.add(x, w3)
+    p1 = relay.subtract(z1, w4)
+    q1 = relay.multiply(p1, w5)
+
+    # Other parts on TVM
+    z2 = relay.add(x, w6)
+    q2 = relay.subtract(z2, w7)
+
+    r = relay.concatenate((q0, q1, q2), axis=0)
+    f = relay.Function([x, w0, w1, w2, w3, w4, w5, w6, w7], r)
+    mod = relay.Module()
+    ann = CcompilerAnnotator()
+    mod["main"] = ann.visit(f)
+    mod = transform.PartitionGraph()(mod)
+    mod = transform.InferType()(mod)
+
+    x_data = np.random.rand(10, 10).astype('float32')
+    w_data = []
+    for _ in range(8):
+        w_data.append(np.random.rand(10, 10).astype('float32'))
+
+    map_inputs = {"w{}".format(i): w_data[i] for i in range(8)}
+    map_inputs["x"] = x_data
+    check_result(
+        mod, map_inputs, (30, 10),
+        np.concatenate((((x_data + w_data[0]) - w_data[1]) * w_data[2],
+                        ((x_data + w_data[3]) - w_data[4]) * w_data[5],
+                        x_data + w_data[6] - w_data[7]),
+                       axis=0))
+
+
+def test_extern_ccompiler_single_op():
+    @transform.function_pass(opt_level=0)
+    class MyAnnotator:
+        def transform_function(self, func, mod, ctx):
+            class Annotator(tvm.relay.ExprMutator):
+                def visit_call(self, call):
+                    new_args = []
+                    for arg in call.args:
+                        ann = compiler_begin(self.visit(arg), "ccompiler")
+                        new_args.append(ann)
+                    new_call = relay.Call(call.op, new_args)
+                    return compiler_end(new_call, "ccompiler")
+            return Annotator().visit(func)
+
+    x = relay.var('x', shape=(8, 8))
+    y = relay.var('y', shape=(8, 8))
+    z = x + y
+    f = relay.Function([x, y], z)
+    x_data = np.random.rand(8, 8).astype('float32')
+    y_data = np.random.rand(8, 8).astype('float32')
+    mod = relay.Module()
+    mod["main"] = f
+    mod = MyAnnotator()(mod)
+    mod = transform.PartitionGraph()(mod)
+
+    check_result(mod, {"x": x_data, "y": y_data}, (8, 8), x_data + y_data)
+
+
+def test_extern_ccompiler_default_ops():
+    def expected():
+        x = relay.var("x", shape=(8, 8))
+        y = relay.var("y", shape=(8, 8))
+        x0 = relay.var("x0", shape=(8, 8))
+        y0 = relay.var("y0", shape=(8, 8))
+        add = x0 + y0
+        # Function that uses C compiler
+        func = relay.Function([x0, y0], add)
+        func = func.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))
+        func = func.set_attribute("Compiler",
+                                  tvm.expr.StringImm("ccompiler"))
+        func = func.set_attribute("ExternalSymbol",
+                                  tvm.expr.StringImm("ccompiler_0"))
+        add_call = relay.Call(func, [x, y])
+        # Function that uses default compiler. Ops are fused in this function.
+        p0 = relay.var("p0", shape=(8, 8))
+        log = relay.log(p0)
+        exp = relay.exp(p0)
+        concat = relay.concatenate([log, exp], axis=0)
+        fused_func = relay.Function([p0], concat)
+        fused_func = fused_func.set_attribute("Primitive",
+                                              tvm.expr.IntImm("int32", 1))
+        fused_call = relay.Call(fused_func, [add_call])
+        main = relay.Function([x, y], fused_call)
+        mod = relay.Module()
+        mod["main"] = main
+        return mod
+
+    x = relay.var("x", shape=(8, 8))
+    y = relay.var("y", shape=(8, 8))
+    add = x + y
+    log = relay.log(add)
+    exp = relay.exp(add)
+    concat = relay.concatenate([log, exp], axis=0)
+    f = relay.Function([x, y], concat)
+    mod = relay.Module()
+    mod["main"] = f
+    mod = WhiteListAnnotator(["add", "subtract", "multiply"], "ccompiler")(mod)
+    mod = transform.PartitionGraph()(mod)
+
+    fused_mod = transform.FuseOps(2)(mod)
+    expected_mod = expected()
+    assert relay.alpha_equal(fused_mod, expected_mod)
+
+    x_data = np.random.rand(8, 8).astype('float32')
+    y_data = np.random.rand(8, 8).astype('float32')
+    np_add = x_data + y_data
+    res = np.concatenate([np.log(np_add), np.exp(np_add)])
+    check_result(mod, {"x": x_data, "y": y_data}, (16, 8), res)
+
+
+def test_extern_ccompiler():
+    x = relay.var('x', shape=(2, 2))
+    y = relay.var('y', shape=(2, 2))
+    z = x + x
+    p = y * y
+    f = relay.Function([x, y], p - z)
+    x_data = np.random.rand(2, 2).astype('float32')
+    y_data = np.random.rand(2, 2).astype('float32')
+    mod = relay.Module()
+    mod["main"] = f
+    mod = WhiteListAnnotator(["add", "subtract", "multiply"], "ccompiler")(mod)
+    mod = transform.PartitionGraph()(mod)
+
+    check_result(mod, {"x": x_data, "y": y_data}, (2, 2), (y_data * y_data) - (x_data + x_data))
+
+
+def test_extern_dnnl():
+    if not tvm.get_global_func("relay.ext.dnnl", True):
+        print("skip because DNNL codegen is not available")
+        return
+
+    dtype = 'float32'
+    ishape = (1, 32, 14, 14)
+    w1shape = (32, 1, 3, 3)
+    data = relay.var('data', shape=(ishape), dtype=dtype)
+    weight1 = relay.var('weight1', shape=(w1shape), dtype=dtype)
+    depthwise_conv2d_1 = relay.nn.conv2d(data,
+                                         weight1,
+                                         kernel_size=(3, 3),
+                                         padding=(1, 1),
+                                         groups=32)
+    depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
+                                         weight1,
+                                         kernel_size=(3, 3),
+                                         padding=(1, 1),
+                                         groups=32)
+    out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
+
+    f = relay.Function([data, weight1], out)
+
+    mod = relay.Module()
+    mod['main'] = WholeGraphAnnotator('dnnl').visit(f)
+    mod = transform.PartitionGraph()(mod)
+
+    ref_mod = relay.Module()
+    ref_mod['main'] = f
+
+    i_data = np.random.uniform(0, 1, ishape).astype(dtype)
+    w1_data = np.random.uniform(0, 1, w1shape).astype(dtype)
+
+    ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu())
+    ref_res = ref_ex.evaluate()(i_data, w1_data)
+    check_result(mod, {"data": i_data, "weight1": w1_data},
+                 (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5)
+
+
+def test_extern_dnnl_mobilenet():
+    if not tvm.get_global_func("relay.ext.dnnl", True):
+        print("skip because DNNL codegen is not available")
+        return
+
+    dtype = 'float32'
+    ishape = (1, 3, 224, 224)
+    mod, params = relay.testing.mobilenet.get_workload(
+        batch_size=1, dtype='float32')
+
+    op_list = ["nn.conv2d", "nn.dense", "nn.relu", "add"]
+    mod = WhiteListAnnotator(op_list, "dnnl")(mod)
+    mod = transform.PartitionGraph()(mod)
+    i_data = np.random.uniform(0, 1, ishape).astype(dtype)
+
+    ref_mod, params = relay.testing.mobilenet.get_workload(batch_size=1,
+                                                           dtype='float32')
+    ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu(0))
+    ref_res = ref_ex.evaluate()(i_data, **params)
+
+    check_result(mod, {"data": i_data},
+                 (1, 1000), ref_res.asnumpy(), tol=1e-5, params=params)
+
+
+if __name__ == "__main__":
+    test_multi_node_compiler()
+    test_extern_ccompiler_single_op()
+    test_extern_ccompiler_default_ops()
+    test_extern_ccompiler()
+    test_extern_dnnl()
+    test_extern_dnnl_mobilenet()