const Array<tvm::relay::Type>& arg_types)>;
/*!
+ * \brief Decides whether a call to an op should be compiled using
+ * the given compiler/target.
+ *
+ * \param attrs The attribute of the original expr.
+ * \param args The arguments of the original expr.
+ *
+ * \return true if this op should be registered to invoke a specific compiler
+ * for codegen, otherwise, false.
+ */
+using FTVMAnnotateTarget = runtime::TypedPackedFunc<
+ bool(const Attrs& attrs, // NOLINT(*)
+ const Array<Expr>& args)>;
+
+/*!
* \brief Forward rewriting rule for a specific op.
*
* \param ref_call The reference old call type to be rewritten.
# operator defs
from .op import get, register, register_compute, register_gradient, \
register_pattern, register_alter_op_layout, register_legalize, \
- Op, OpPattern, OpStrategy, debug
+ Op, OpPattern, OpStrategy, debug, register_external_compiler
from . import strategy
# Operators
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import
-"""Neural network related operators."""
-from __future__ import absolute_import as _abs
+"""Contrib modules."""
+from .dnnl import *
+++ /dev/null
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#pylint: disable=invalid-name, too-many-lines
-"""Contrib operations."""
-from __future__ import absolute_import as _abs
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""DNNL library supported operators.
+There are two ways to register a function for an op to indicate whether it is
+supported by DNNL.
+
+- The first and simplest way is to use the helper so that
+users only need to provide the operator name and a boolean value to indicate if
+it is supported. For example:
+
+ .. code-block:: python
+
+ add = _register_external_op_helper("add")
+ add = _register_external_op_helper("add", True)
+ add = _register_external_op_helper("add", False)
+
+- The other way is for users to implement the function themselves to
+check the attributes of the op and decide if it should be offloaded to DNNL.
+"""
+from ... import op as reg
+
+
+def _register_external_op_helper(op_name, supported=True):
+    """Register a constant DNNL-support predicate for an operator.
+
+    Parameters
+    ----------
+    op_name : str
+        The name of the operator that will be registered.
+
+    supported : bool, optional
+        The value the registered predicate returns. Defaults to True.
+
+    Returns
+    -------
+    f : callable
+        The result of applying ``reg.register(op_name, "target.dnnl")`` to the
+        predicate.
+    """
+    def _func_wrapper(attrs, args):
+        # The answer is fixed at registration time and ignores the call site.
+        return supported
+
+    return reg.register(op_name, "target.dnnl")(_func_wrapper)
+
+
+# Operators registered under "target.dnnl" with the helper's default
+# (supported=True).
+_register_external_op_helper("nn.conv2d")
+_register_external_op_helper("nn.dense")
+_register_external_op_helper("nn.relu")
+_register_external_op_helper("add")
+_register_external_op_helper("subtract")
+_register_external_op_helper("multiply")
+
+
+@reg.register("nn.batch_norm", "target.dnnl")
+def batch_norm(attrs, args):
+    """Report that nn.batch_norm is not offloaded to the DNNL codegen.
+
+    FIXME(@zhiics, @comaniac): Turned off because multiple outputs are not
+    supported yet.
+    """
+    return False
get(op_name).set_attr("TShapeDataDependant", data_dependant, level)
return register(op_name, "FShapeFunc", shape_func, level)
+
+def register_external_compiler(op_name, fexternal=None, level=10):
+    """Register an external compiler function for an op.
+
+    Parameters
+    ----------
+    op_name : str
+        The name of the operator.
+
+    fexternal : function (attrs: Attrs, args: List[Expr], compiler: str)
+        -> new_expr: Expr
+        The function for wrapping a call expr with compiler_begin and
+        compiler_end.
+
+    level : int
+        The priority level
+
+    Returns
+    -------
+    The result of ``register`` for the "FTVMExternalCompiler" attribute.
+    """
+    attr_key = "FTVMExternalCompiler"
+    return register(op_name, attr_key, fexternal, level)
+
+
@tvm._ffi.register_func("relay.op.compiler._lower")
def _lower(name, schedule, inputs, outputs):
return lower(schedule, list(inputs) + list(outputs), name=name)
+
@tvm._ffi.register_func("relay.op.compiler._build")
def _build(lowered_funcs):
return build(lowered_funcs, target="llvm")
+
_schedule_injective = None
_schedule_reduce = None
return _transform.PartitionGraph()
+
+def AnnotateTarget(target):
+    """Annotate ops in an expression with a provided compiler/target and then
+    use it for codegen.
+
+    Parameters
+    ----------
+    target : String
+        The target compiler used for codegen.
+
+    Returns
+    -------
+    ret : tvm.relay.Pass
+        The annotation pass that wraps supported ops with compiler_begin and
+        compiler_end.
+    """
+    return _transform.AnnotateTarget(target)
+
+
def Inline():
"""Perform inlining on the given Relay IR module. The global functions that
are marked as `inline` should be always inlined. A cost model will be
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/pass/annotate_target.cc
+ * \brief Wraps a call with compiler_begin and compiler_end to indicate that
+ * the op of this call node will use external compiler.
+ */
+
+#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+namespace annotate_target {
+
+// A helper class to insert annotation boundaries for a program region that will
+// be handled by a specific compiler. Every call whose operator has a
+// "target.<target>" predicate that returns true gets each argument wrapped in
+// compiler_begin and its result wrapped in compiler_end.
+class AnnotateTargetWrapper : public ExprMutator {
+ public:
+  explicit AnnotateTargetWrapper(const std::string& target) : target_(target) {}
+
+  Expr VisitExpr_(const CallNode* cn) final {
+    // TODO(@zhiics, @comaniac) Handle composite functions.
+    // Rewrite the callee and the arguments first so that nested calls are
+    // annotated before this one.
+    Expr new_e = ExprMutator::VisitExpr_(cn);
+    Call call = Downcast<Call>(new_e);
+
+    // Only calls to operators can carry a "target.<name>" predicate. Calls to
+    // anything else (functions, global vars, ...) are left untouched instead
+    // of failing the downcast to Op.
+    if (call->op.as<OpNode>() == nullptr) {
+      return new_e;
+    }
+    Op op = Downcast<Op>(call->op);
+
+    // This must not be a function-local static: the key depends on target_,
+    // and a static would be initialized only once, from whichever instance
+    // happened to run first.
+    auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + target_);
+
+    if (!fannotate.count(op)) {
+      LOG(WARNING) << op->name << " in " << target_
+                   << " is not registered. It will be executed on CPU.";
+      return new_e;
+    }
+    if (!fannotate[op](call->attrs, call->args)) {
+      // The predicate declined to offload this particular call.
+      return new_e;
+    }
+
+    // Look the annotation constructors up once instead of once per argument.
+    const auto* begin_op =
+        runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
+    CHECK(begin_op);
+    const auto* end_op =
+        runtime::Registry::Get("relay.op.annotation._make.compiler_end");
+    CHECK(end_op);
+
+    tvm::Array<tvm::relay::Expr> compiler_begins;
+    for (const auto& arg : call->args) {
+      Expr begin = (*begin_op)(arg, target_);
+      compiler_begins.push_back(begin);
+    }
+    Expr update_call = CallNode::make(call->op, compiler_begins, call->attrs);
+    Expr end = (*end_op)(update_call, target_);
+    return end;
+  }
+
+ private:
+  // Name of the external compiler/target, e.g. the "dnnl" in "target.dnnl".
+  std::string target_;
+};
+
+// Returns `expr` rewritten by AnnotateTargetWrapper for the given target.
+Expr AnnotateTarget(const Expr& expr, const std::string& target) {
+  AnnotateTargetWrapper annotator(target);
+  return annotator.Mutate(expr);
+}
+
+} // namespace annotate_target
+
+namespace transform {
+
+// Builds the "AnnotateTarget" pass: a function pass that runs the annotation
+// rewriter for `target`, followed by InferType.
+Pass AnnotateTarget(const std::string& target) {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> annotate =
+    [target](Function func, IRModule mod, PassContext ctx) {
+      Expr annotated = relay::annotate_target::AnnotateTarget(func, target);
+      return Downcast<Function>(annotated);
+    };
+  Pass annotate_pass = CreateFunctionPass(annotate, 0, "AnnotateTargetFunc",
+                                          {tir::StringImmNode::make("InferType")});
+  return transform::Sequential({annotate_pass, InferType()}, "AnnotateTarget");
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.AnnotateTarget")
+.set_body_typed(AnnotateTarget);
+
+} // namespace transform
+
+} // namespace relay
+} // namespace tvm
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for annotating external targets."""
+import os
+import sys
+import numpy as np
+import pytest
+
+import tvm
+import tvm.relay.testing
+import tvm.relay.transform as transform
+from tvm import relay
+from tvm import runtime
+from tvm.contrib import util
+
+
+def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
+                 ctx=tvm.cpu(), params=None):
+    """Build `mod` for the VM and the graph runtime and compare both outputs
+    against `result`.
+
+    Parameters
+    ----------
+    mod : the module to compile.
+    map_inputs : dict mapping input names to input data.
+    out_shape : shape of the buffer used to fetch the graph runtime output.
+    result : expected output, compared with `tol` as both rtol and atol.
+    target, ctx, params : forwarded to the build and runtime calls.
+    """
+    if sys.platform == "win32":
+        print("Skip test on Windows for now")
+        return
+
+    def update_lib(lib):
+        """Export `lib` into a temporary shared library and load it back."""
+        this_dir = os.path.dirname(
+            os.path.realpath(os.path.expanduser(__file__)))
+        repo_root = os.path.join(this_dir, "..", "..", "..")
+        contrib_include = os.path.join(repo_root, "src", "runtime", "contrib")
+        compile_options = ["-O2", "-std=c++11", "-I" + contrib_include]
+
+        workdir = util.tempdir()
+        so_path = workdir.relpath('lib.so')
+        lib.export_library(so_path, fcompile=False, options=compile_options)
+        return runtime.load_module(so_path)
+
+    def check_vm_result():
+        """Compile for the Relay VM, reload the library, run and compare."""
+        with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+            compiled = relay.vm.compile(mod, target=target, params=params)
+        code, vm_lib = compiled.save()
+        reloaded = runtime.vm.Executable.load_exec(code, update_lib(vm_lib))
+        machine = runtime.vm.VirtualMachine(reloaded)
+        machine.init(ctx)
+        out = machine.run(**map_inputs)
+        tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
+
+    def check_graph_runtime_result():
+        """Build for the graph runtime, reload the library, run and compare."""
+        with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+            graph_json, built_lib, built_params = relay.build(
+                mod, target=target, params=params)
+        loaded_lib = update_lib(built_lib)
+        module = tvm.contrib.graph_runtime.create(graph_json, loaded_lib, ctx)
+
+        for name, data in map_inputs.items():
+            module.set_input(name, data)
+        module.set_input(**built_params)
+        module.run()
+        out = module.get_output(0, tvm.nd.empty(out_shape, ctx=ctx))
+
+        tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
+
+    check_vm_result()
+    check_graph_runtime_result()
+
+
+def test_extern_dnnl():
+    """Check DNNL annotation of two depthwise conv2d ops followed by an add,
+    both structurally and (when the DNNL codegen is available) numerically.
+    """
+    def annotated(dtype, ishape, w1shape):
+        """Build the plain input module that the passes are applied to."""
+        data = relay.var('data', shape=(ishape), dtype=dtype)
+        weight1 = relay.var('weight1', shape=(w1shape), dtype=dtype)
+        depthwise_conv2d_1 = relay.nn.conv2d(data,
+                                             weight1,
+                                             kernel_size=(3, 3),
+                                             padding=(1, 1),
+                                             groups=32)
+        depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
+                                             weight1,
+                                             kernel_size=(3, 3),
+                                             padding=(1, 1),
+                                             groups=32)
+        out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
+
+        f = relay.Function([data, weight1], out)
+
+        mod = tvm.IRModule.from_expr(f)
+        return mod
+
+    def expected(dtype, ishape, w1shape):
+        """Build the module expected after AnnotateTarget("dnnl")."""
+        data = relay.var('data', shape=(ishape), dtype=dtype)
+        weight1 = relay.var('weight1', shape=(w1shape), dtype=dtype)
+        begin0 = relay.annotation.compiler_begin(data, "dnnl")
+        begin1 = relay.annotation.compiler_begin(weight1, "dnnl")
+        depthwise_conv2d_1 = relay.nn.conv2d(begin0,
+                                             begin1,
+                                             kernel_size=(3, 3),
+                                             padding=(1, 1),
+                                             groups=32)
+        end0 = relay.annotation.compiler_end(depthwise_conv2d_1, "dnnl")
+        begin2 = relay.annotation.compiler_begin(end0, "dnnl")
+        begin3 = relay.annotation.compiler_begin(end0, "dnnl")
+        begin4 = relay.annotation.compiler_begin(weight1, "dnnl")
+        depthwise_conv2d_2 = relay.nn.conv2d(begin3,
+                                             begin4,
+                                             kernel_size=(3, 3),
+                                             padding=(1, 1),
+                                             groups=32)
+        end1 = relay.annotation.compiler_end(depthwise_conv2d_2, "dnnl")
+        begin5 = relay.annotation.compiler_begin(end1, "dnnl")
+        out = relay.add(begin2, begin5)
+        end2 = relay.annotation.compiler_end(out, "dnnl")
+        f = relay.Function([data, weight1], end2)
+        mod = tvm.IRModule.from_expr(f)
+        return mod
+
+    dtype = "float32"
+    ishape = (1, 32, 14, 14)
+    w1shape = (32, 1, 3, 3)
+
+    def test_annotate():
+        """The annotated module must match the hand-written reference."""
+        mod = annotated(dtype, ishape, w1shape)
+        mod = transform.AnnotateTarget("dnnl")(mod)
+        ref_mod = expected(dtype, ishape, w1shape)
+        assert relay.analysis.alpha_equal(mod, ref_mod)
+
+    def test_run():
+        """The annotated and partitioned module must match the plain one."""
+        if not tvm.get_global_func("relay.ext.dnnl", True):
+            print("skip because DNNL codegen is not available")
+            return
+
+        ref_mod = annotated(dtype, ishape, w1shape)
+        mod = annotated(dtype, ishape, w1shape)
+        # Annotate before partitioning; without the compiler_begin/end markers
+        # there is nothing for PartitionGraph to offload to DNNL.
+        mod = transform.AnnotateTarget("dnnl")(mod)
+        mod = transform.PartitionGraph()(mod)
+
+        i_data = np.random.uniform(0, 1, ishape).astype(dtype)
+        w1_data = np.random.uniform(0, 1, w1shape).astype(dtype)
+
+        ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu())
+        ref_res = ref_ex.evaluate()(i_data, w1_data)
+
+        check_result(mod, {"data": i_data, "weight1": w1_data},
+                     (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5)
+
+    test_annotate()
+    test_run()
+
+
+def test_extern_dnnl_mobilenet():
+    """Run MobileNet through AnnotateTarget("dnnl") and PartitionGraph and
+    compare the output with the unmodified workload.
+    """
+    if not tvm.get_global_func("relay.ext.dnnl", True):
+        print("skip because DNNL codegen is not available")
+        return
+
+    dtype = 'float32'
+    ishape = (1, 3, 224, 224)
+    mod, params = relay.testing.mobilenet.get_workload(batch_size=1, dtype=dtype)
+
+    annotated_mod = transform.AnnotateTarget("dnnl")(mod)
+    mod = transform.PartitionGraph()(annotated_mod)
+    i_data = np.random.uniform(0, 1, ishape).astype(dtype)
+
+    # A fresh copy of the workload serves as the reference.
+    ref_mod, params = relay.testing.mobilenet.get_workload(batch_size=1, dtype=dtype)
+    ref_executor = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu(0))
+    ref_res = ref_executor.evaluate()(i_data, **params)
+
+    expected_out = ref_res.asnumpy()
+    check_result(mod, {"data": i_data}, (1, 1000), expected_out, tol=1e-5,
+                 params=params)
+
+
+if __name__ == "__main__":
+ test_extern_dnnl()
+ test_extern_dnnl_mobilenet()