--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*
+ * \file src/relay/pass/partition_graph.cc
+ *
+ * \brief Partition an input function into multiple functions based on the
+ * inserted annotation nodes (i.e. compiler_begin and compiler_end).
+ * These nodes are used as boundaries to partition the Relay function into
+ * multiple regions that can be offloaded to different accelerators/backends.
+ *
+ * Each of these partitioned functions, a.k.a. subgraphs, will be viewed as
+ * external functions, and they will use the provided compiler for codegen.
+ */
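+
+/*
+ * Conceptually, an annotated region such as
+ *
+ *   compiler_end(add(compiler_begin(x), compiler_begin(y)))
+ *
+ * (with compiler="ccompiler" on the annotations) is replaced by a call to a
+ * newly created function that carries the kCompiler, kPrimitive, and
+ * kExternalSymbol attributes, roughly:
+ *
+ *   (fn (%ccompiler_input0, %ccompiler_input1) {
+ *      add(%ccompiler_input0, %ccompiler_input1)
+ *    })(x, y)
+ */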
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/expr.h>
+#include <tvm/ir/error.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+namespace partitioning {
+
+// Cache compiler_begin and compiler_end annotation ops for equivalence check to
+// reduce registry lookup overhead.
+static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
+static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
+
+/*!
+ * \brief The subgraph properties for partitioning.
+ */
+struct Subgraph {
+ /*! \brief The subgraph ID. */
+ int id;
+
+ /*! \brief The input arguments of this subgraph. */
+ std::vector<std::pair<Var, Expr>> args;
+
+ /*! \brief Nodes in this subgraph. */
+ std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes;
+};
+
+/*!
+ * \brief The checker that verifies if a Relay program is annotated correctly
+ * for partitioning.
+ */
+class AnnotationChecker : public ExprVisitor {
+ public:
+ bool Check() {
+ if (!found_start_ && !found_end_) {
+ LOG(WARNING) << "No compiler annotation found";
+ } else if (!found_start_) {
+ LOG(ERROR) << "compiler_begin annotation is missing";
+ return false;
+ } else if (!found_end_) {
+ LOG(ERROR) << "compiler_end annotation is missing";
+ return false;
+ }
+ return true;
+ }
+
+ void VisitExpr_(const CallNode* call) final {
+ auto op_node = call->op.as<OpNode>();
+ if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
+ return;
+ } else if (call->op == compiler_begin_op) {
+ found_start_ = true;
+ } else if (call->op == compiler_end_op) {
+ found_end_ = true;
+ }
+ }
+
+ private:
+ bool found_start_{false};
+ bool found_end_{false};
+};
+
+/*! \brief This class partitions the expr labeled with begin and end
+ * annotations into a function containing multiple regions. Each region is
+ * labeled with a compiler attribute so that it can be handled by a compiler
+ * outside of the TVM stack.
+ *
+ * TODO(@zhiics) The following algorithm is not adequate to handle all cases,
+ * e.g. multiple `compiler_end` nodes.
+ */
+class Partitioner : public ExprMutator {
+ public:
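+  /*!
+   * \brief Return the subgraph that contains the given expression, or
+   * nullptr if the expression has not been assigned to a subgraph yet.
+   */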
+ std::shared_ptr<Subgraph> GetSubgraph(const Expr node) {
+ for (auto candidate : this->subgraphs_) {
+ if (candidate->nodes.find(node) != candidate->nodes.end()) {
+ return candidate;
+ }
+ }
+ return nullptr;
+ }
+
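+  /*!
+   * \brief Merge subgraph2 into subgraph1 by taking the union of their nodes
+   * and arguments, then erase subgraph2.
+   */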
+ void MergeSubgraph(std::shared_ptr<Subgraph> subgraph1,
+ std::shared_ptr<Subgraph> subgraph2) {
+ if (subgraph1 == subgraph2) {
+ return;
+ }
+
+ // Merge subgraph 2 to subgraph 1 and erase subgraph 2.
+ subgraph1->nodes.insert(subgraph2->nodes.begin(), subgraph2->nodes.end());
+ for (auto arg : subgraph2->args) {
+ subgraph1->args.push_back(arg);
+ }
+ this->subgraphs_.erase(subgraph2);
+ }
+
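+  /*!
+   * \brief Assign an expression to the given subgraph. If the expression has
+   * already been assigned to another subgraph, the two subgraphs are merged.
+   */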
+ void AddToSubgraph(std::shared_ptr<Subgraph> subgraph, const Expr expr) {
+ auto subgraph2 = GetSubgraph(expr);
+ if (subgraph2) {
+ MergeSubgraph(subgraph, subgraph2);
+ } else {
+ subgraph->nodes.insert(expr);
+ }
+ }
+
+ Expr VisitExpr_(const CallNode* call) final {
+ auto op_node = call->op.as<OpNode>();
+
+ if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
+      // Propagate the subgraph to the arguments.
+ auto subgraph = GetSubgraph(GetRef<Call>(call));
+ if (subgraph) {
+ for (auto arg : call->args) {
+ AddToSubgraph(subgraph, arg);
+ }
+ }
+ return ExprMutator::VisitExpr_(call);
+ } else if (call->op == compiler_begin_op) {
+      // The annotation node is inserted on an edge, so it must have exactly one argument.
+ CHECK_EQ(call->args.size(), 1U);
+
+      // Traverse the rest of the graph.
+ auto input_expr = VisitExpr(call->args[0]);
+
+ // Replace the begin annotation with an external call input variable.
+ auto compiler_attrs = call->attrs.as<CompilerAttrs>();
+ auto var = VarNode::make(compiler_attrs->compiler + "_input" + std::to_string(var_id_++),
+ input_expr->checked_type_);
+
+ // Find the corresponding subgraph and add the argument.
+ auto subgraph = GetSubgraph(GetRef<Call>(call));
+ if (!subgraph) {
+ throw Error(ErrorBuilder()
+ << "Cannot find the corresponding subgraph for start annotation:\n"
+ << AsText(GetRef<Call>(call), false));
+ }
+ subgraph->args.push_back({var, input_expr});
+ return std::move(var);
+ } else {
+ CHECK_EQ(call->op, compiler_end_op);
+      // The annotation node is inserted on an edge, so it must have exactly one argument.
+ CHECK_EQ(call->args.size(), 1U);
+
+ auto compiler_attrs = call->attrs.as<CompilerAttrs>();
+
+      // Check if the argument already belongs to an existing subgraph.
+ auto subgraph = GetSubgraph(call->args[0]);
+ if (!subgraph) {
+ auto ret = this->subgraphs_.emplace(std::make_shared<Subgraph>());
+ subgraph = *ret.first;
+ subgraph->nodes.insert(call->args[0]);
+ subgraph->id = this->subgraph_id_++;
+ }
+ subgraph->nodes.insert(GetRef<Call>(call));
+
+ // Traverse subgraph inputs.
+ auto input = VisitExpr(call->args[0]);
+ Array<Var> params;
+ Array<Expr> args;
+
+      // The subgraph may have been merged, so update the reference again.
+ subgraph = GetSubgraph(GetRef<Call>(call));
+ CHECK(subgraph);
+
+ for (auto pair : subgraph->args) {
+ params.push_back(pair.first);
+ args.push_back(pair.second);
+ }
+
+ auto subgraph_func =
+ FunctionNode::make(params, input, call->args[0]->checked_type_, {}, Attrs());
+
+ Expr arg0 = call->args[0];
+ std::string name = compiler_attrs->compiler + "_" + std::to_string(subgraph->id);
+ subgraph_func =
+ FunctionSetAttr(subgraph_func, attr::kExternalSymbol, tvm::ir::StringImmNode::make(name));
+ subgraph_func = FunctionSetAttr(subgraph_func, attr::kPrimitive, tvm::Integer(1));
+ subgraph_func = FunctionSetAttr(subgraph_func, attr::kCompiler,
+ tvm::ir::StringImmNode::make(compiler_attrs->compiler));
+ return CallNode::make(subgraph_func, args);
+ }
+ }
+
+ Expr VisitExpr_(const TupleNode* op) final {
+ auto subgraph = GetSubgraph(GetRef<Tuple>(op));
+ if (!subgraph) {
+ return ExprMutator::VisitExpr_(op);
+ } else {
+ for (auto field : op->fields) {
+ AddToSubgraph(subgraph, field);
+ }
+ Array<Expr> fields;
+ for (auto field : op->fields) {
+ fields.push_back(VisitExpr(field));
+ }
+ return TupleNode::make(fields);
+ }
+ }
+
+ Expr VisitExpr_(const TupleGetItemNode* g) final {
+ auto subgraph = GetSubgraph(GetRef<TupleGetItem>(g));
+ if (!subgraph) {
+ return ExprMutator::VisitExpr_(g);
+ } else {
+ AddToSubgraph(subgraph, g->tuple);
+ auto t = VisitExpr(g->tuple);
+ return TupleGetItemNode::make(t, g->index);
+ }
+ }
+
+ Expr VisitExpr_(const FunctionNode* op) final {
+ auto subgraph = GetSubgraph(GetRef<Function>(op));
+ if (!subgraph) {
+ return ExprMutator::VisitExpr_(op);
+ } else {
+ Array<Var> params;
+ for (auto param : op->params) {
+ AddToSubgraph(subgraph, param);
+ }
+ for (auto param : op->params) {
+ Var new_param = Downcast<Var>(VisitExpr(param));
+ params.push_back(new_param);
+ }
+ auto body = VisitExpr(op->body);
+ return FunctionNode::make(params, body, op->ret_type, op->type_params, op->attrs);
+ }
+ }
+
+ Expr VisitExpr_(const LetNode* op) final {
+ auto subgraph = GetSubgraph(GetRef<Let>(op));
+ if (!subgraph) {
+ return ExprMutator::VisitExpr_(op);
+ } else {
+ AddToSubgraph(subgraph, op->var);
+ AddToSubgraph(subgraph, op->value);
+ AddToSubgraph(subgraph, op->body);
+ Var var = Downcast<Var>(VisitExpr(op->var));
+ auto value = VisitExpr(op->value);
+ auto body = VisitExpr(op->body);
+
+ return LetNode::make(var, value, body);
+ }
+ }
+
+ Expr VisitExpr_(const IfNode* op) final {
+ auto subgraph = GetSubgraph(GetRef<If>(op));
+ if (!subgraph) {
+ return ExprMutator::VisitExpr_(op);
+ } else {
+ AddToSubgraph(subgraph, op->cond);
+ AddToSubgraph(subgraph, op->true_branch);
+ AddToSubgraph(subgraph, op->false_branch);
+ auto guard = VisitExpr(op->cond);
+ auto true_b = VisitExpr(op->true_branch);
+ auto false_b = VisitExpr(op->false_branch);
+ return IfNode::make(guard, true_b, false_b);
+ }
+ }
+
+ Expr VisitExpr_(const RefCreateNode* op) final {
+ auto subgraph = GetSubgraph(GetRef<RefCreate>(op));
+ if (!subgraph) {
+ return ExprMutator::VisitExpr_(op);
+ } else {
+ AddToSubgraph(subgraph, op->value);
+ Expr value = VisitExpr(op->value);
+ return RefCreateNode::make(value);
+ }
+ }
+
+ Expr VisitExpr_(const RefReadNode* op) final {
+ auto subgraph = GetSubgraph(GetRef<RefRead>(op));
+ if (!subgraph) {
+ return ExprMutator::VisitExpr_(op);
+ } else {
+ AddToSubgraph(subgraph, op->ref);
+ Expr ref = VisitExpr(op->ref);
+ return RefReadNode::make(ref);
+ }
+ }
+
+ Expr VisitExpr_(const RefWriteNode* op) final {
+ auto subgraph = GetSubgraph(GetRef<RefWrite>(op));
+ if (!subgraph) {
+ return ExprMutator::VisitExpr_(op);
+ } else {
+ AddToSubgraph(subgraph, op->ref);
+ Expr ref = VisitExpr(op->ref);
+ Expr value = VisitExpr(op->value);
+ return RefWriteNode::make(ref, value);
+ }
+ }
+
+ private:
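+  /*! \brief The counter used to name the input vars of external functions. */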
+ int var_id_{0};
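+  /*! \brief The counter used to assign subgraph IDs. */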
+ int subgraph_id_{0};
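+  /*! \brief The set of subgraphs created so far. */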
+ std::unordered_set<std::shared_ptr<Subgraph>> subgraphs_;
+};
+
+/*!
+ * \brief TODO(@zhiics, @comaniac) Combine parallel regions that belong to
+ * the same codegen backend. This reduces round trips between TVM and external
+ * backends. Likely we can borrow some ideas from operator fusion.
+ *
+ * For example, sg1 and sg2 should be combined if they belong to the same
+ * codegen tool in the following case.
+ *
+ * op1
+ * / \
+ * sg1 sg2
+ *
+ * |
+ * \|/
+ *
+ * op1
+ * |
+ * sg1_sg2
+ *
+ * where the return type of the new subgraph sg1_sg2 is a tuple, and op1 has
+ * two inputs that are obtained from the tuple.
+ */
+
+Expr PartitionGraph(const Expr& expr) {
+ Partitioner part;
+ return part.Mutate(expr);
+}
+
+} // namespace partitioning
+
+namespace transform {
+
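+/*!
+ * \brief Create a pass that partitions annotated functions into external
+ * functions and then re-runs type inference on the result.
+ */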
+Pass PartitionGraph() {
+ runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> part_func =
+ [=](Function f, IRModule m, PassContext pc) {
+ return Downcast<Function>(partitioning::PartitionGraph(f));
+ };
+ auto partitioned = CreateFunctionPass(part_func, 0, "PartitionGraph", {});
+ return Sequential({partitioned, InferType()});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.PartitionGraph")
+.set_body_typed(transform::PartitionGraph);
+
+} // namespace transform
+
+} // namespace relay
+} // namespace tvm
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for graph partitioning."""
+import os
+import sys
+import numpy as np
+import pytest
+
+import tvm
+import tvm.relay.testing
+import tvm.relay.transform as transform
+from tvm import relay
+from tvm.contrib import util
+from tvm.relay.annotation import compiler_begin, compiler_end
+from tvm.relay.expr_functor import ExprMutator
+
+# Leverage the pass manager to write a simple whitelist-based annotator
+@transform.function_pass(opt_level=0)
+class WhiteListAnnotator:
+ def __init__(self, op_list, compiler):
+ assert isinstance(op_list, (list, tuple, set))
+ self.op_list = op_list
+ self.compiler = compiler
+
+ def transform_function(self, func, mod, ctx):
+
+ annotator = self
+ class Annotator(tvm.relay.ExprMutator):
+ def visit_call(self, call):
+ op_name = call.op.name
+ if op_name in annotator.op_list:
+ new_args = []
+ for arg in call.args:
+ ann = compiler_begin(super().visit(arg),
+ annotator.compiler)
+ new_args.append(ann)
+ new_call = relay.Call(call.op, new_args, call.attrs,
+ call.type_args)
+ return compiler_end(new_call, annotator.compiler)
+ else:
+ return super().visit_call(call)
+ return Annotator().visit(func)
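+
+# Example usage (as in the tests below):
+#   mod = WhiteListAnnotator(["add", "subtract", "multiply"], "ccompiler")(mod)
+#   mod = transform.PartitionGraph()(mod)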
+
+
+class CcompilerAnnotator(ExprMutator):
+ """
+ A simple annotator that creates the following program:
+ |
+ -- begin --
+ |
+ add
+ |
+ subtract
+ |
+ multiply
+ |
+ -- end --
+ |
+ """
+
+ def __init__(self):
+ super(CcompilerAnnotator, self).__init__()
+ self.in_compiler = 0
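+        # in_compiler encodes where we are in the add -> subtract -> multiply
+        # pattern while traversing from the output towards the inputs:
+        #   0 - outside a candidate region
+        #   1 - inside a region rooted at a multiply (looking for the add)
+        #   2 - the add was found and its args annotated with compiler_begin,
+        #       so the multiply is annotated with compiler_end when unwinding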
+
+ def visit_call(self, call):
+ if call.op.name == "add": # Annotate begin at args
+ if self.in_compiler == 1:
+ lhs = compiler_begin(super().visit(call.args[0]), "ccompiler")
+ rhs = compiler_begin(super().visit(call.args[1]), "ccompiler")
+ op = relay.add(lhs, rhs)
+ self.in_compiler = 2
+ return op
+ elif call.op.name == "subtract":
+ if self.in_compiler == 1:
+ lhs = super().visit(call.args[0])
+ rhs = super().visit(call.args[1])
+ if isinstance(lhs, relay.expr.Var):
+ lhs = compiler_begin(lhs, "ccompiler")
+ if isinstance(rhs, relay.expr.Var):
+ rhs = compiler_begin(rhs, "ccompiler")
+ return relay.subtract(lhs, rhs)
+ elif call.op.name == "multiply": # Annotate end at output
+ self.in_compiler = 1
+ lhs = super().visit(call.args[0])
+ rhs = super().visit(call.args[1])
+ if isinstance(lhs, relay.expr.Var):
+ lhs = compiler_begin(lhs, "ccompiler")
+ if isinstance(rhs, relay.expr.Var):
+ rhs = compiler_begin(rhs, "ccompiler")
+ op = relay.multiply(lhs, rhs)
+ if self.in_compiler == 2:
+ op = compiler_end(op, "ccompiler")
+ self.in_compiler = 0
+ return op
+ return super().visit_call(call)
+
+
+class WholeGraphAnnotator(ExprMutator):
+ """
+    An annotator that marks an entire graph to be offloaded to one compiler.
+ """
+
+ def __init__(self, compiler):
+ super(WholeGraphAnnotator, self).__init__()
+ self.compiler = compiler
+ self.last_call = True
+
+ def visit_call(self, call):
+ curr_last = self.last_call
+ self.last_call = False
+
+ params = []
+ for arg in call.args:
+ param = super().visit(arg)
+ if isinstance(param, relay.expr.Var):
+ param = compiler_begin(param, self.compiler)
+ params.append(param)
+
+ new_call = relay.Call(call.op, params, call.attrs)
+ if curr_last:
+ new_call = compiler_end(new_call, self.compiler)
+ return new_call
+
+
+class MobileNetAnnotator(ExprMutator):
+ """
+    Annotate MobileNet up to (but not including) global_avg_pool2d.
+ """
+
+ def __init__(self, compiler):
+ super(MobileNetAnnotator, self).__init__()
+ self.compiler = compiler
+ self.compiler_open = False
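+        # compiler_open becomes True once global_avg_pool2d is visited; from
+        # then on, input Vars are wrapped with compiler_begin while the
+        # arguments of global_avg_pool2d are wrapped with compiler_end, so
+        # everything before global_avg_pool2d is offloaded.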
+
+ def visit_call(self, call):
+
+ if call.op.name == 'nn.global_avg_pool2d':
+ self.compiler_open = True
+ compiler_open = self.compiler_open
+
+ params = []
+ for arg in call.args:
+ param = super().visit(arg)
+ if call.op.name == 'nn.global_avg_pool2d':
+ param = compiler_end(param, self.compiler)
+ if compiler_open and isinstance(param, relay.expr.Var):
+ param = compiler_begin(param, self.compiler)
+ params.append(param)
+
+ new_call = relay.Call(call.op, params, call.attrs)
+ return new_call
+
+
+def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
+ ctx=tvm.cpu(), params=None):
+ if sys.platform == "win32":
+ print("Skip test on Windows for now")
+ return
+
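+    # Re-export the compiled library with the headers under src/runtime/contrib
+    # on the include path so the code emitted by the external codegen can be
+    # compiled, then reload it as a runtime module.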
+ def update_lib(lib):
+ test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
+ source_dir = os.path.join(test_dir, "..", "..", "..")
+ contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")
+
+ kwargs = {}
+ kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path]
+ tmp_path = util.tempdir()
+ lib_name = 'lib.so'
+ lib_path = tmp_path.relpath(lib_name)
+ lib.export_library(lib_path, fcompile=False, **kwargs)
+ lib = tvm.module.load(lib_path)
+
+ return lib
+
+ def check_vm_result():
+ with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+ exe = relay.vm.compile(mod, target=target, params=params)
+ code, lib = exe.save()
+ lib = update_lib(lib)
+ exe = relay.vm.Executable.load_exec(code, lib)
+ vm = relay.vm.VirtualMachine(exe)
+ vm.init(ctx)
+ out = vm.run(**map_inputs)
+ tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
+
+ def check_graph_runtime_result():
+ with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+ json, lib, param = relay.build(mod, target=target, params=params)
+ lib = update_lib(lib)
+ rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
+
+ for name, data in map_inputs.items():
+ rt_mod.set_input(name, data)
+ rt_mod.set_input(**param)
+ rt_mod.run()
+ out = tvm.nd.empty(out_shape, ctx=ctx)
+ out = rt_mod.get_output(0, out)
+
+ tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
+
+ check_vm_result()
+ check_graph_runtime_result()
+
+
+def test_multi_node_compiler():
+ x = relay.var('x', shape=(10, 10))
+ w0 = relay.var('w0', shape=(10, 10))
+ w1 = relay.var('w1', shape=(10, 10))
+ w2 = relay.var('w2', shape=(10, 10))
+ w3 = relay.var('w3', shape=(10, 10))
+ w4 = relay.var('w4', shape=(10, 10))
+ w5 = relay.var('w5', shape=(10, 10))
+ w6 = relay.var('w6', shape=(10, 10))
+ w7 = relay.var('w7', shape=(10, 10))
+
+ # C compiler
+    # FIXME: We generate two subgraphs for this case, but they should be merged
+    # into one due to the common input (x).
+ z0 = relay.add(x, w0)
+ p0 = relay.subtract(z0, w1)
+ q0 = relay.multiply(p0, w2)
+
+ z1 = relay.add(x, w3)
+ p1 = relay.subtract(z1, w4)
+ q1 = relay.multiply(p1, w5)
+
+    # Other parts run on TVM
+ z2 = relay.add(x, w6)
+ q2 = relay.subtract(z2, w7)
+
+ r = relay.concatenate((q0, q1, q2), axis=0)
+ f = relay.Function([x, w0, w1, w2, w3, w4, w5, w6, w7], r)
+ mod = relay.Module()
+ ann = CcompilerAnnotator()
+ mod["main"] = ann.visit(f)
+ mod = transform.PartitionGraph()(mod)
+ mod = transform.InferType()(mod)
+
+ x_data = np.random.rand(10, 10).astype('float32')
+ w_data = []
+ for _ in range(8):
+ w_data.append(np.random.rand(10, 10).astype('float32'))
+
+ map_inputs = {"w{}".format(i): w_data[i] for i in range(8)}
+ map_inputs["x"] = x_data
+ check_result(
+ mod, map_inputs, (30, 10),
+ np.concatenate((((x_data + w_data[0]) - w_data[1]) * w_data[2],
+ ((x_data + w_data[3]) - w_data[4]) * w_data[5],
+ x_data + w_data[6] - w_data[7]),
+ axis=0))
+
+
+def test_extern_ccompiler_single_op():
+ @transform.function_pass(opt_level=0)
+ class MyAnnotator:
+ def transform_function(self, func, mod, ctx):
+ class Annotator(tvm.relay.ExprMutator):
+ def visit_call(self, call):
+ new_args = []
+ for arg in call.args:
+ ann = compiler_begin(self.visit(arg), "ccompiler")
+ new_args.append(ann)
+ new_call = relay.Call(call.op, new_args)
+ return compiler_end(new_call, "ccompiler")
+ return Annotator().visit(func)
+
+ x = relay.var('x', shape=(8, 8))
+ y = relay.var('y', shape=(8, 8))
+ z = x + y
+ f = relay.Function([x, y], z)
+ x_data = np.random.rand(8, 8).astype('float32')
+ y_data = np.random.rand(8, 8).astype('float32')
+ mod = relay.Module()
+ mod["main"] = f
+ mod = MyAnnotator()(mod)
+ mod = transform.PartitionGraph()(mod)
+
+ check_result(mod, {"x": x_data, "y": y_data}, (8, 8), x_data + y_data)
+
+
+def test_extern_ccompiler_default_ops():
+ def expected():
+ x = relay.var("x", shape=(8, 8))
+ y = relay.var("y", shape=(8, 8))
+ x0 = relay.var("x0", shape=(8, 8))
+ y0 = relay.var("y0", shape=(8, 8))
+ add = x0 + y0
+ # Function that uses C compiler
+ func = relay.Function([x0, y0], add)
+ func = func.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))
+ func = func.set_attribute("Compiler",
+ tvm.expr.StringImm("ccompiler"))
+ func = func.set_attribute("ExternalSymbol",
+ tvm.expr.StringImm("ccompiler_0"))
+ add_call = relay.Call(func, [x, y])
+ # Function that uses default compiler. Ops are fused in this function.
+ p0 = relay.var("p0", shape=(8, 8))
+ log = relay.log(p0)
+ exp = relay.exp(p0)
+ concat = relay.concatenate([log, exp], axis=0)
+ fused_func = relay.Function([p0], concat)
+ fused_func = fused_func.set_attribute("Primitive",
+ tvm.expr.IntImm("int32", 1))
+ fused_call = relay.Call(fused_func, [add_call])
+ main = relay.Function([x, y], fused_call)
+ mod = relay.Module()
+ mod["main"] = main
+ return mod
+
+ x = relay.var("x", shape=(8, 8))
+ y = relay.var("y", shape=(8, 8))
+ add = x + y
+ log = relay.log(add)
+ exp = relay.exp(add)
+ concat = relay.concatenate([log, exp], axis=0)
+ f = relay.Function([x, y], concat)
+ mod = relay.Module()
+ mod["main"] = f
+ mod = WhiteListAnnotator(["add", "subtract", "multiply"], "ccompiler")(mod)
+ mod = transform.PartitionGraph()(mod)
+
+ fused_mod = transform.FuseOps(2)(mod)
+ expected_mod = expected()
+ assert relay.alpha_equal(fused_mod, expected_mod)
+
+ x_data = np.random.rand(8, 8).astype('float32')
+ y_data = np.random.rand(8, 8).astype('float32')
+ np_add = x_data + y_data
+ res = np.concatenate([np.log(np_add), np.exp(np_add)])
+ check_result(mod, {"x": x_data, "y": y_data}, (16, 8), res)
+
+
+def test_extern_ccompiler():
+ x = relay.var('x', shape=(2, 2))
+ y = relay.var('y', shape=(2, 2))
+ z = x + x
+ p = y * y
+ f = relay.Function([x, y], p - z)
+ x_data = np.random.rand(2, 2).astype('float32')
+ y_data = np.random.rand(2, 2).astype('float32')
+ mod = relay.Module()
+ mod["main"] = f
+ mod = WhiteListAnnotator(["add", "subtract", "multiply"], "ccompiler")(mod)
+ mod = transform.PartitionGraph()(mod)
+
+ check_result(mod, {"x": x_data, "y": y_data}, (2, 2), (y_data * y_data) - (x_data + x_data))
+
+
+def test_extern_dnnl():
+ if not tvm.get_global_func("relay.ext.dnnl", True):
+ print("skip because DNNL codegen is not available")
+ return
+
+ dtype = 'float32'
+ ishape = (1, 32, 14, 14)
+ w1shape = (32, 1, 3, 3)
+ data = relay.var('data', shape=(ishape), dtype=dtype)
+ weight1 = relay.var('weight1', shape=(w1shape), dtype=dtype)
+ depthwise_conv2d_1 = relay.nn.conv2d(data,
+ weight1,
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ groups=32)
+ depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
+ weight1,
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ groups=32)
+ out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
+
+ f = relay.Function([data, weight1], out)
+
+ mod = relay.Module()
+ mod['main'] = WholeGraphAnnotator('dnnl').visit(f)
+ mod = transform.PartitionGraph()(mod)
+
+ ref_mod = relay.Module()
+ ref_mod['main'] = f
+
+ i_data = np.random.uniform(0, 1, ishape).astype(dtype)
+ w1_data = np.random.uniform(0, 1, w1shape).astype(dtype)
+
+ ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu())
+ ref_res = ref_ex.evaluate()(i_data, w1_data)
+ check_result(mod, {"data": i_data, "weight1": w1_data},
+ (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5)
+
+
+def test_extern_dnnl_mobilenet():
+ if not tvm.get_global_func("relay.ext.dnnl", True):
+ print("skip because DNNL codegen is not available")
+ return
+
+ dtype = 'float32'
+ ishape = (1, 3, 224, 224)
+ mod, params = relay.testing.mobilenet.get_workload(
+ batch_size=1, dtype='float32')
+
+ op_list = ["nn.conv2d", "nn.dense", "nn.relu", "add"]
+ mod = WhiteListAnnotator(op_list, "dnnl")(mod)
+ mod = transform.PartitionGraph()(mod)
+ i_data = np.random.uniform(0, 1, ishape).astype(dtype)
+
+ ref_mod, params = relay.testing.mobilenet.get_workload(batch_size=1,
+ dtype='float32')
+ ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu(0))
+ ref_res = ref_ex.evaluate()(i_data, **params)
+
+ check_result(mod, {"data": i_data},
+ (1, 1000), ref_res.asnumpy(), tol=1e-5, params=params)
+
+
+if __name__ == "__main__":
+ test_multi_node_compiler()
+ test_extern_ccompiler_single_op()
+ test_extern_ccompiler_default_ops()
+ test_extern_ccompiler()
+ test_extern_dnnl()
+ test_extern_dnnl_mobilenet()