[Relay] Target annotation for external codegen (#4933)
author    Zhi <5145158+zhiics@users.noreply.github.com>
          Tue, 3 Mar 2020 09:30:28 +0000 (01:30 -0800)
committer GitHub <noreply@github.com>
          Tue, 3 Mar 2020 09:30:28 +0000 (18:30 +0900)
* Op-based external compiler annotation

* Use TVM register directly

* Small fix

* Test graph

Co-authored-by: Cody Yu <comaniac0422@gmail.com>
include/tvm/relay/op_attr_types.h
python/tvm/relay/op/__init__.py
python/tvm/relay/op/contrib/__init__.py
python/tvm/relay/op/contrib/contrib.py [deleted file]
python/tvm/relay/op/contrib/dnnl.py [new file with mode: 0644]
python/tvm/relay/op/op.py
python/tvm/relay/transform.py
src/relay/pass/annotate_target.cc [new file with mode: 0644]
tests/python/relay/test_annotate_target.py [new file with mode: 0644]

index 1a2263e..5b2fdd3 100644 (file)
@@ -180,6 +180,20 @@ using FTVMLegalize = runtime::TypedPackedFunc<
        const Array<tvm::relay::Type>& arg_types)>;
 
 /*!
+ * \brief Checks an expression to decide whether its op should be compiled
+ * using the given compiler/target.
+ *
+ * \param attrs The attribute of the original expr.
+ * \param args The arguments of the original expr.
+ *
+ * \return true if this op should be offloaded to the given compiler for
+ * codegen, false otherwise.
+ */
+using FTVMAnnotateTarget = runtime::TypedPackedFunc<
+  bool(const Attrs& attrs,  // NOLINT(*)
+       const Array<Expr>& args)>;
+
+/*!
  * \brief Forward rewriting rule for a specific op.
  *
  * \param ref_call The reference old call type to be rewritten.
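The new FTVMAnnotateTarget attribute is the per-op predicate that a backend registers under the key "target.<compiler>"; the AnnotateTarget pass added later in this patch looks it up to decide which calls to wrap. A minimal sketch of such a registration from Python, assuming a hypothetical compiler name "mytarget" and an illustrative layout condition (the real DNNL registrations appear in python/tvm/relay/op/contrib/dnnl.py below):

    from tvm.relay import op as reg

    @reg.register("nn.conv2d", "target.mytarget")
    def _conv2d_supported(attrs, args):
        # Offload only NCHW convolutions to the hypothetical "mytarget" backend.
        return attrs.data_layout == "NCHW"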
index 4a4823d..1a1d0d3 100644 (file)
@@ -19,7 +19,7 @@
 # operator defs
 from .op import get, register, register_compute, register_gradient, \
     register_pattern, register_alter_op_layout, register_legalize, \
-    Op, OpPattern, OpStrategy, debug
+    Op, OpPattern, OpStrategy, debug, register_external_compiler
 from . import strategy
 
 # Operators
index c6e086a..4b6acce 100644 (file)
@@ -15,5 +15,5 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=wildcard-import
-"""Neural network related operators."""
-from __future__ import absolute_import as _abs
+"""Contrib modules."""
+from .dnnl import *
diff --git a/python/tvm/relay/op/contrib/contrib.py b/python/tvm/relay/op/contrib/contrib.py
deleted file mode 100644 (file)
index cb7e5d4..0000000
+++ /dev/null
@@ -1,19 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#pylint: disable=invalid-name, too-many-lines
-"""Contrib operations."""
-from __future__ import absolute_import as _abs
diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
new file mode 100644 (file)
index 0000000..1aa7192
--- /dev/null
@@ -0,0 +1,72 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""DNNL library supported operators.
+There are two ways to register a function for an op to indicate whether it is
+supported by DNNL.
+
+- The first and simplest way is to use the helper `_register_external_op_helper`
+so that users only need to provide the operator name and a boolean value to
+indicate whether it is supported. For example:
+
+    .. code-block:: python
+
+      add = _register_external_op_helper("add")
+      add = _register_external_op_helper("add", True)
+      add = _register_external_op_helper("add", False)
+
+- The other way is to implement the function yourself to check the attributes
+of the op and decide whether it should be offloaded to DNNL.
+"""
+from ... import op as reg
+
+
+def _register_external_op_helper(op_name, supported=True):
+    """The helper function to indicate that a given operator can be supported
+    by DNNL.
+
+    Parameters
+    ----------
+    op_name : str
+        The name of the operator that will be registered.
+
+    supported : bool
+        Whether the operator is supported by DNNL. Defaults to True.
+
+    Returns
+    -------
+    f : callable
+        A function that returns if the operator is supported by DNNL.
+    """
+    @reg.register(op_name, "target.dnnl")
+    def _func_wrapper(attrs, args):
+        return supported
+
+    return _func_wrapper
+
+
+_register_external_op_helper("nn.conv2d")
+_register_external_op_helper("nn.dense")
+_register_external_op_helper("nn.relu")
+_register_external_op_helper("add")
+_register_external_op_helper("subtract")
+_register_external_op_helper("multiply")
+
+
+@reg.register("nn.batch_norm", "target.dnnl")
+def batch_norm(attrs, args):
+    """Check if the external DNNL codegen should be used.
+    FIXME(@zhiics, @comaniac): Turn off due to not support of multiple outputs.
+    """
+    return False
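Besides the unconditional helper registrations above, the module docstring also describes a second style in which the check inspects the call's attributes. A short sketch of that style, assuming a hypothetical nn.max_pool2d check that this patch does not actually register:

    from tvm.relay import op as reg

    @reg.register("nn.max_pool2d", "target.dnnl")
    def _max_pool2d_supported(attrs, args):
        # Hypothetical attribute-based check: offload only NCHW pooling to DNNL.
        return attrs.layout == "NCHW"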
index 6be7d4d..4cd4b2a 100644 (file)
@@ -453,14 +453,36 @@ def register_shape_func(op_name, data_dependant, shape_func=None, level=10):
     get(op_name).set_attr("TShapeDataDependant", data_dependant, level)
     return register(op_name, "FShapeFunc", shape_func, level)
 
+
+def register_external_compiler(op_name, fexternal=None, level=10):
+    """Register the external compiler for an op.
+
+    Parameters
+    ----------
+    op_name : str
+        The name of the operator.
+
+    fexternal : function (attrs: Attrs, args: List[Expr], compiler: str)
+              -> new_expr: Expr
+        The function for wrapping a call expr with compiler_begin and
+        compiler_end.
+
+    level : int
+        The priority level.
+    """
+    return register(op_name, "FTVMExternalCompiler", fexternal, level)
+
+
 @tvm._ffi.register_func("relay.op.compiler._lower")
 def _lower(name, schedule, inputs, outputs):
     return lower(schedule, list(inputs) + list(outputs), name=name)
 
+
 @tvm._ffi.register_func("relay.op.compiler._build")
 def _build(lowered_funcs):
     return build(lowered_funcs, target="llvm")
 
+
 _schedule_injective = None
 _schedule_reduce = None
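For illustration, a sketch of an FTVMExternalCompiler hook matching the signature documented in register_external_compiler above. The nn.relu choice is arbitrary, and nothing in this patch registers or consumes such a hook; the AnnotateTarget pass below performs the equivalent wrapping in C++:

    from tvm import relay
    from tvm.relay.op import register_external_compiler

    def _wrap_relu(attrs, args, compiler):
        # Wrap each input with compiler_begin and the rebuilt call with compiler_end.
        new_args = [relay.annotation.compiler_begin(arg, compiler) for arg in args]
        new_call = relay.nn.relu(*new_args)
        return relay.annotation.compiler_end(new_call, compiler)

    register_external_compiler("nn.relu", _wrap_relu)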
 
index c54e4c8..b2565f3 100644 (file)
@@ -552,6 +552,25 @@ def PartitionGraph():
     return _transform.PartitionGraph()
 
 
+
+def AnnotateTarget(target):
+    """Annotate ops in an experession with a provied compiler/target and then
+    use it for codegen.
+
+    Parameters
+    ----------
+    target : String
+        The target compiler used for codegen.
+
+    Returns
+    -------
+    ret : tvm.relay.Pass
+        The annotated pass that wraps ops with compiler_begin and
+        compiler_end.
+    """
+    return _transform.AnnotateTarget(target)
+
+
 def Inline():
     """Perform inlining on the given Relay IR module. The global functions that
     are marked as `inline` should be always inlined. A cost model will be
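A minimal end-to-end sketch of the new AnnotateTarget pass above, mirroring the flow exercised by the unit test added in this patch (the small add-only function is illustrative; the DNNL checks are assumed to be registered by importing the contrib module):

    import tvm
    import tvm.relay.transform as transform
    from tvm import relay
    from tvm.relay.op import contrib  # importing this registers the "target.dnnl" checks

    x = relay.var("x", shape=(1, 32, 14, 14), dtype="float32")
    y = relay.var("y", shape=(1, 32, 14, 14), dtype="float32")
    func = relay.Function([x, y], relay.add(x, y))

    mod = tvm.IRModule.from_expr(func)
    mod = transform.AnnotateTarget("dnnl")(mod)  # wrap supported calls with compiler_begin/compiler_end
    mod = transform.PartitionGraph()(mod)        # group annotated regions into external functions
    print(mod.astext(show_meta_data=False))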
diff --git a/src/relay/pass/annotate_target.cc b/src/relay/pass/annotate_target.cc
new file mode 100644 (file)
index 0000000..7322069
--- /dev/null
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/pass/annotate_target.cc
+ * \brief Wraps a call with compiler_begin and compiler_end to indicate that
+ * the op of this call node will use external compiler.
+ */
+
+#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+namespace annotate_target {
+
+// A helper class to insert annotation boundaries for a program region that will
+// be handled by a specific compiler.
+class AnnotateTargetWrapper : public ExprMutator {
+ public:
+  explicit AnnotateTargetWrapper(const std::string& target) : target_(target) {}
+
+  Expr VisitExpr_(const CallNode* cn) {
+    // TODO(@zhiics, @comaniac) Handle composite functions.
+    auto new_e = ExprMutator::VisitExpr_(cn);
+
+    Call call = Downcast<Call>(new_e);
+    static auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + target_);
+    Op op = Downcast<Op>(call->op);
+    CHECK(op.defined());
+
+    if (fannotate.count(op)) {
+      bool external = fannotate[op](call->attrs, call->args);
+      if (external) {
+        tvm::Array<tvm::relay::Expr> compiler_begins;
+        for (const auto& it : call->args) {
+          const auto* begin_op =
+            runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
+          CHECK(begin_op);
+          Expr begin = (*begin_op)(it, target_);
+          compiler_begins.push_back(begin);
+        }
+        Expr update_call = CallNode::make(call->op, compiler_begins, call->attrs);
+        const auto* end_op =
+          runtime::Registry::Get("relay.op.annotation._make.compiler_end");
+        CHECK(end_op);
+        Expr end = (*end_op)(update_call, target_);
+        return end;
+      }
+    } else {
+      LOG(WARNING) << op->name << " in " << target_
+                   << " is not registered. It will be executed on CPU.";
+    }
+    return new_e;
+  }
+
+ private:
+  std::string target_;
+};
+
+Expr AnnotateTarget(const Expr& expr, const std::string& target) {
+  return AnnotateTargetWrapper(target).Mutate(expr);
+}
+
+}  // namespace annotate_target
+
+namespace transform {
+
+Pass AnnotateTarget(const std::string& target) {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        return Downcast<Function>(relay::annotate_target::AnnotateTarget(f, target));
+      };
+  auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc",
+                                      {tir::StringImmNode::make("InferType")});
+  return transform::Sequential({func_pass, InferType()}, "AnnotateTarget");
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.AnnotateTarget")
+.set_body_typed(AnnotateTarget);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/relay/test_annotate_target.py b/tests/python/relay/test_annotate_target.py
new file mode 100644 (file)
index 0000000..f4e602a
--- /dev/null
@@ -0,0 +1,188 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for annotating external targets."""
+import os
+import sys
+import numpy as np
+import pytest
+
+import tvm
+import tvm.relay.testing
+import tvm.relay.transform as transform
+from tvm import relay
+from tvm import runtime
+from tvm.contrib import util
+
+
+def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
+                 ctx=tvm.cpu(), params=None):
+    if sys.platform == "win32":
+        print("Skip test on Windows for now")
+        return
+
+    def update_lib(lib):
+        test_dir = os.path.dirname(
+            os.path.realpath(os.path.expanduser(__file__)))
+        source_dir = os.path.join(test_dir, "..", "..", "..")
+        contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")
+
+        kwargs = {}
+        kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path]
+        tmp_path = util.tempdir()
+        lib_name = 'lib.so'
+        lib_path = tmp_path.relpath(lib_name)
+        lib.export_library(lib_path, fcompile=False, **kwargs)
+        lib = runtime.load_module(lib_path)
+
+        return lib
+
+    def check_vm_result():
+        with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+            exe = relay.vm.compile(mod, target=target, params=params)
+        code, lib = exe.save()
+        lib = update_lib(lib)
+        exe = runtime.vm.Executable.load_exec(code, lib)
+        vm = runtime.vm.VirtualMachine(exe)
+        vm.init(ctx)
+        out = vm.run(**map_inputs)
+        tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
+
+    def check_graph_runtime_result():
+        with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+            json, lib, param = relay.build(mod, target=target, params=params)
+        lib = update_lib(lib)
+        rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
+
+        for name, data in map_inputs.items():
+            rt_mod.set_input(name, data)
+        rt_mod.set_input(**param)
+        rt_mod.run()
+        out = tvm.nd.empty(out_shape, ctx=ctx)
+        out = rt_mod.get_output(0, out)
+
+        tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
+
+    check_vm_result()
+    check_graph_runtime_result()
+
+
+def test_extern_dnnl():
+    def annotated(dtype, ishape, w1shape):
+        data = relay.var('data', shape=(ishape), dtype=dtype)
+        weight1 = relay.var('weight1', shape=(w1shape), dtype=dtype)
+        depthwise_conv2d_1 = relay.nn.conv2d(data,
+                                             weight1,
+                                             kernel_size=(3, 3),
+                                             padding=(1, 1),
+                                             groups=32)
+        depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
+                                             weight1,
+                                             kernel_size=(3, 3),
+                                             padding=(1, 1),
+                                             groups=32)
+        out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
+
+        f = relay.Function([data, weight1], out)
+
+        mod = tvm.IRModule.from_expr(f)
+        return mod
+
+    def expected(dtype, ishape, w1shape):
+        data = relay.var('data', shape=(ishape), dtype=dtype)
+        weight1 = relay.var('weight1', shape=(w1shape), dtype=dtype)
+        begin0 = relay.annotation.compiler_begin(data, "dnnl")
+        begin1 = relay.annotation.compiler_begin(weight1, "dnnl")
+        depthwise_conv2d_1 = relay.nn.conv2d(begin0,
+                                             begin1,
+                                             kernel_size=(3, 3),
+                                             padding=(1, 1),
+                                             groups=32)
+        end0 = relay.annotation.compiler_end(depthwise_conv2d_1, "dnnl")
+        begin2 = relay.annotation.compiler_begin(end0, "dnnl")
+        begin3 = relay.annotation.compiler_begin(end0, "dnnl")
+        begin4 = relay.annotation.compiler_begin(weight1, "dnnl")
+        depthwise_conv2d_2 = relay.nn.conv2d(begin3,
+                                             begin4,
+                                             kernel_size=(3, 3),
+                                             padding=(1, 1),
+                                             groups=32)
+        end1 = relay.annotation.compiler_end(depthwise_conv2d_2, "dnnl")
+        begin5 = relay.annotation.compiler_begin(end1, "dnnl")
+        out = relay.add(begin2, begin5)
+        end2 = relay.annotation.compiler_end(out, "dnnl")
+        f = relay.Function([data, weight1], end2)
+        mod = tvm.IRModule.from_expr(f)
+        return mod
+
+    dtype = "float32"
+    ishape = (1, 32, 14, 14)
+    w1shape = (32, 1, 3, 3)
+
+    def test_annotate():
+        mod = annotated(dtype, ishape, w1shape)
+        mod = transform.AnnotateTarget("dnnl")(mod)
+        ref_mod = expected(dtype, ishape, w1shape)
+        assert relay.analysis.alpha_equal(mod, ref_mod)
+
+    def test_run():
+        if not tvm.get_global_func("relay.ext.dnnl", True):
+            print("skip because DNNL codegen is not available")
+            return
+
+        ref_mod = annotated(dtype, ishape, w1shape)
+        mod = annotated(dtype, ishape, w1shape)
+        mod = transform.PartitionGraph()(mod)
+
+        i_data = np.random.uniform(0, 1, ishape).astype(dtype)
+        w1_data = np.random.uniform(0, 1, w1shape).astype(dtype)
+
+        ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu())
+        ref_res = ref_ex.evaluate()(i_data, w1_data)
+
+        check_result(mod, {"data": i_data, "weight1": w1_data},
+                     (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5)
+
+    test_annotate()
+    test_run()
+
+
+def test_extern_dnnl_mobilenet():
+    if not tvm.get_global_func("relay.ext.dnnl", True):
+        print("skip because DNNL codegen is not available")
+        return
+
+    dtype = 'float32'
+    ishape = (1, 3, 224, 224)
+    mod, params = relay.testing.mobilenet.get_workload(
+        batch_size=1, dtype='float32')
+
+    mod = transform.AnnotateTarget("dnnl")(mod)
+    mod = transform.PartitionGraph()(mod)
+    i_data = np.random.uniform(0, 1, ishape).astype(dtype)
+
+    ref_mod, params = relay.testing.mobilenet.get_workload(batch_size=1,
+                                                           dtype='float32')
+    ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu(0))
+    ref_res = ref_ex.evaluate()(i_data, **params)
+
+    check_result(mod, {"data": i_data},
+                 (1, 1000), ref_res.asnumpy(), tol=1e-5, params=params)
+
+
+if __name__ == "__main__":
+    test_extern_dnnl()
+    test_extern_dnnl_mobilenet()