From 98b1759052c2dacb38b6d3e0bbdba38002bbef75 Mon Sep 17 00:00:00 2001
From: Zhi <5145158+zhiics@users.noreply.github.com>
Date: Tue, 3 Mar 2020 01:30:28 -0800
Subject: [PATCH] [Relay] Target annotation for external codegen (#4933)

* op based external compiler annotation

* Use TVM register directly

* Small fix

* test graph

Co-authored-by: Cody Yu
---
 include/tvm/relay/op_attr_types.h          |  14 +++
 python/tvm/relay/op/__init__.py            |   2 +-
 python/tvm/relay/op/contrib/__init__.py    |   4 +-
 python/tvm/relay/op/contrib/contrib.py     |  19 ---
 python/tvm/relay/op/contrib/dnnl.py        |  72 ++++++++++
 python/tvm/relay/op/op.py                  |  22 ++++
 python/tvm/relay/transform.py              |  19 +++
 src/relay/pass/annotate_target.cc          | 103 +++++++++++++
 tests/python/relay/test_annotate_target.py | 188 +++++++++++++++++++++
 9 files changed, 421 insertions(+), 22 deletions(-)
 delete mode 100644 python/tvm/relay/op/contrib/contrib.py
 create mode 100644 python/tvm/relay/op/contrib/dnnl.py
 create mode 100644 src/relay/pass/annotate_target.cc
 create mode 100644 tests/python/relay/test_annotate_target.py

diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h
index 1a2263e..5b2fdd3 100644
--- a/include/tvm/relay/op_attr_types.h
+++ b/include/tvm/relay/op_attr_types.h
@@ -180,6 +180,20 @@ using FTVMLegalize = runtime::TypedPackedFunc<
       const Array<tvm::relay::Type>& arg_types)>;
 
 /*!
+ * \brief Annotates an expression to indicate if an op should be compiled using
+ * the given compiler/target.
+ *
+ * \param attrs The attributes of the original expr.
+ * \param args The arguments of the original expr.
+ *
+ * \return true if this op should be registered to invoke a specific compiler
+ * for codegen, otherwise false.
+ */
+using FTVMAnnotateTarget = runtime::TypedPackedFunc<
+    bool(const Attrs& attrs,  // NOLINT(*)
+         const Array<Expr>& args)>;
+
+/*!
  * \brief Forward rewriting rule for a specific op.
  *
  * \param ref_call The reference old call type to be rewritten.
diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py
index 4a4823d..1a1d0d3 100644
--- a/python/tvm/relay/op/__init__.py
+++ b/python/tvm/relay/op/__init__.py
@@ -19,7 +19,7 @@
 # operator defs
 from .op import get, register, register_compute, register_gradient, \
     register_pattern, register_alter_op_layout, register_legalize, \
-    Op, OpPattern, OpStrategy, debug
+    Op, OpPattern, OpStrategy, debug, register_external_compiler
 from . import strategy
 
 # Operators
diff --git a/python/tvm/relay/op/contrib/__init__.py b/python/tvm/relay/op/contrib/__init__.py
index c6e086a..4b6acce 100644
--- a/python/tvm/relay/op/contrib/__init__.py
+++ b/python/tvm/relay/op/contrib/__init__.py
@@ -15,5 +15,5 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=wildcard-import
-"""Neural network related operators."""
-from __future__ import absolute_import as _abs
+"""Contrib modules."""
+from .dnnl import *
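On the Python side, an FTVMAnnotateTarget function is attached to an op under
the attribute key "target.<compiler>". A minimal sketch of such a registration,
assuming a hypothetical compiler name "mycodegen" for illustration (dnnl.py
below does the same for DNNL):

    .. code-block:: python

        from tvm import relay

        # "mycodegen" is a hypothetical compiler name, not part of this patch.
        @relay.op.register("nn.relu", "target.mycodegen")
        def _relu_supported(attrs, args):
            # Returning True offloads this call to the external codegen.
            return True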
diff --git a/python/tvm/relay/op/contrib/contrib.py b/python/tvm/relay/op/contrib/contrib.py
deleted file mode 100644
index cb7e5d4..0000000
--- a/python/tvm/relay/op/contrib/contrib.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#pylint: disable=invalid-name, too-many-lines
-"""Contrib operations."""
-from __future__ import absolute_import as _abs
diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
new file mode 100644
index 0000000..1aa7192
--- /dev/null
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -0,0 +1,72 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""DNNL library supported operators.
+There are two ways to register a function for an op to indicate whether it is
+supported by DNNL.
+
+- The first and simplest way is to use the helper _register_external_op_helper,
+so that users only need to provide the operator name and a boolean value to
+indicate whether it is supported. For example:
+
+    .. code-block:: python
+
+      add = _register_external_op_helper("add")
+      add = _register_external_op_helper("add", True)
+      add = _register_external_op_helper("add", False)
+
+- The other way is to implement the function yourself to check the attributes
+of the op and decide whether it should be offloaded to DNNL.
+"""
+from ... import op as reg
+
+
+def _register_external_op_helper(op_name, supported=True):
+    """The helper function to indicate that a given operator can be supported
+    by DNNL.
+
+    Parameters
+    ----------
+    op_name : str
+        The name of the operator that will be registered.
+
+    Returns
+    -------
+    f : callable
+        A function that returns whether the operator is supported by DNNL.
+    """
+    @reg.register(op_name, "target.dnnl")
+    def _func_wrapper(attrs, args):
+        return supported
+
+    return _func_wrapper
+
+
+_register_external_op_helper("nn.conv2d")
+_register_external_op_helper("nn.dense")
+_register_external_op_helper("nn.relu")
+_register_external_op_helper("add")
+_register_external_op_helper("subtract")
+_register_external_op_helper("multiply")
+
+
+@reg.register("nn.batch_norm", "target.dnnl")
+def batch_norm(attrs, args):
+    """Check if the external DNNL codegen should be used.
+    FIXME(@zhiics, @comaniac): Turned off because multiple outputs are not
+    supported yet.
+    """
+    return False
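As noted in the module docstring, the second way is an attribute-aware check.
A sketch of such a predicate, again under a hypothetical compiler name
"mycodegen" so it does not collide with the DNNL registrations above:

    .. code-block:: python

        from tvm import relay

        # Hypothetical target name; a real backend substitutes its own.
        @relay.op.register("nn.conv2d", "target.mycodegen")
        def _conv2d_supported(attrs, args):
            # Offload only grouped convolutions; everything else stays on
            # the default target.
            return int(attrs.groups) > 1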
+ """ + return False diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 6be7d4d..4cd4b2a 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -453,14 +453,36 @@ def register_shape_func(op_name, data_dependant, shape_func=None, level=10): get(op_name).set_attr("TShapeDataDependant", data_dependant, level) return register(op_name, "FShapeFunc", shape_func, level) + +def register_external_compiler(op_name, fexternal=None, level=10): + """Register the external compiler for an op. + + Parameters + ---------- + op_name : str + The name of the operator. + + fexternal : function (attrs: Attrs, args: List[Expr], compiler: str) + -> new_expr: Expr + The function for wrapping a call expr with compiler_begin and + compiler_end. + + level : int + The priority level + """ + return register(op_name, "FTVMExternalCompiler", fexternal, level) + + @tvm._ffi.register_func("relay.op.compiler._lower") def _lower(name, schedule, inputs, outputs): return lower(schedule, list(inputs) + list(outputs), name=name) + @tvm._ffi.register_func("relay.op.compiler._build") def _build(lowered_funcs): return build(lowered_funcs, target="llvm") + _schedule_injective = None _schedule_reduce = None diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index c54e4c8..b2565f3 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -552,6 +552,25 @@ def PartitionGraph(): return _transform.PartitionGraph() + +def AnnotateTarget(target): + """Annotate ops in an experession with a provied compiler/target and then + use it for codegen. + + Parameters + ---------- + target : String + The target compiler used for codegen. + + Returns + ------- + ret : tvm.relay.Pass + The annotated pass that wrapps ops with subgraph_start and + subgraph_end. + """ + return _transform.AnnotateTarget(target) + + def Inline(): """Perform inlining on the given Relay IR module. The global functions that are marked as `inline` should be always inlined. A cost model will be diff --git a/src/relay/pass/annotate_target.cc b/src/relay/pass/annotate_target.cc new file mode 100644 index 0000000..7322069 --- /dev/null +++ b/src/relay/pass/annotate_target.cc @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/pass/annotate_target.cc + * \brief Wraps a call with compiler_begin and compiler_end to indicate that + * the op of this call node will use external compiler. + */ + +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace annotate_target { + +// A helper class to insert annotation boundaries for a program region that will +// be handled by a specific compiler. 
diff --git a/src/relay/pass/annotate_target.cc b/src/relay/pass/annotate_target.cc
new file mode 100644
index 0000000..7322069
--- /dev/null
+++ b/src/relay/pass/annotate_target.cc
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/pass/annotate_target.cc
+ * \brief Wraps a call with compiler_begin and compiler_end to indicate that
+ * the op of this call node will use the external compiler.
+ */
+
+#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+namespace annotate_target {
+
+// A helper class to insert annotation boundaries for a program region that
+// will be handled by a specific compiler.
+class AnnotateTargetWrapper : public ExprMutator {
+ public:
+  explicit AnnotateTargetWrapper(const std::string& target) : target_(target) {}
+
+  Expr VisitExpr_(const CallNode* cn) {
+    // TODO(@zhiics, @comaniac) Handle composite functions.
+    auto new_e = ExprMutator::VisitExpr_(cn);
+
+    Call call = Downcast<Call>(new_e);
+    static auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + target_);
+    Op op = Downcast<Op>(call->op);
+    CHECK(op.defined());
+
+    if (fannotate.count(op)) {
+      bool external = fannotate[op](call->attrs, call->args);
+      if (external) {
+        tvm::Array<Expr> compiler_begins;
+        for (const auto& it : call->args) {
+          const auto* begin_op =
+            runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
+          CHECK(begin_op);
+          Expr begin = (*begin_op)(it, target_);
+          compiler_begins.push_back(begin);
+        }
+        Expr update_call = CallNode::make(call->op, compiler_begins, call->attrs);
+        const auto* end_op =
+          runtime::Registry::Get("relay.op.annotation._make.compiler_end");
+        CHECK(end_op);
+        Expr end = (*end_op)(update_call, target_);
+        return end;
+      }
+    } else {
+      LOG(WARNING) << op->name << " is not registered for " << target_
+                   << ". It will be executed on CPU.";
+    }
+    return new_e;
+  }
+
+ private:
+  std::string target_;
+};
+
+Expr AnnotateTarget(const Expr& expr, const std::string& target) {
+  return AnnotateTargetWrapper(target).Mutate(expr);
+}
+
+}  // namespace annotate_target
+
+namespace transform {
+
+Pass AnnotateTarget(const std::string& target) {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+    [=](Function f, IRModule m, PassContext pc) {
+      return Downcast<Function>(relay::annotate_target::AnnotateTarget(f, target));
+    };
+  auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc",
+                                      {tir::StringImmNode::make("InferType")});
+  return transform::Sequential({func_pass, InferType()}, "AnnotateTarget");
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.AnnotateTarget")
+.set_body_typed(AnnotateTarget);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
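For a single supported call, the mutator above rewrites the expression as if
the annotations had been inserted by hand; a sketch of the equivalent manual
construction (mirroring the expected() helper in the test below):

    .. code-block:: python

        import tvm
        from tvm import relay

        x = relay.var("x", shape=(1, 32), dtype="float32")
        y = relay.var("y", shape=(1, 32), dtype="float32")

        # Every argument is wrapped with compiler_begin, and the call itself
        # with compiler_end, both carrying the compiler name.
        bx = relay.annotation.compiler_begin(x, "dnnl")
        by = relay.annotation.compiler_begin(y, "dnnl")
        out = relay.annotation.compiler_end(relay.add(bx, by), "dnnl")
        func = relay.Function([x, y], out)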
+"""Unit tests for annotating external targets.""" +import os +import sys +import numpy as np +import pytest + +import tvm +import tvm.relay.testing +import tvm.relay.transform as transform +from tvm import relay +from tvm import runtime +from tvm.contrib import util + + +def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", + ctx=tvm.cpu(), params=None): + if sys.platform == "win32": + print("Skip test on Windows for now") + return + + def update_lib(lib): + test_dir = os.path.dirname( + os.path.realpath(os.path.expanduser(__file__))) + source_dir = os.path.join(test_dir, "..", "..", "..") + contrib_path = os.path.join(source_dir, "src", "runtime", "contrib") + + kwargs = {} + kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path] + tmp_path = util.tempdir() + lib_name = 'lib.so' + lib_path = tmp_path.relpath(lib_name) + lib.export_library(lib_path, fcompile=False, **kwargs) + lib = runtime.load_module(lib_path) + + return lib + + def check_vm_result(): + with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]): + exe = relay.vm.compile(mod, target=target, params=params) + code, lib = exe.save() + lib = update_lib(lib) + exe = runtime.vm.Executable.load_exec(code, lib) + vm = runtime.vm.VirtualMachine(exe) + vm.init(ctx) + out = vm.run(**map_inputs) + tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol) + + def check_graph_runtime_result(): + with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]): + json, lib, param = relay.build(mod, target=target, params=params) + lib = update_lib(lib) + rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx) + + for name, data in map_inputs.items(): + rt_mod.set_input(name, data) + rt_mod.set_input(**param) + rt_mod.run() + out = tvm.nd.empty(out_shape, ctx=ctx) + out = rt_mod.get_output(0, out) + + tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol) + + check_vm_result() + check_graph_runtime_result() + + +def test_extern_dnnl(): + def annotated(dtype, ishape, w1shape): + data = relay.var('data', shape=(ishape), dtype=dtype) + weight1 = relay.var('weight1', shape=(w1shape), dtype=dtype) + depthwise_conv2d_1 = relay.nn.conv2d(data, + weight1, + kernel_size=(3, 3), + padding=(1, 1), + groups=32) + depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1, + weight1, + kernel_size=(3, 3), + padding=(1, 1), + groups=32) + out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2) + + f = relay.Function([data, weight1], out) + + mod = tvm.IRModule.from_expr(f) + return mod + + def expected(dtype, ishape, w1shape): + data = relay.var('data', shape=(ishape), dtype=dtype) + weight1 = relay.var('weight1', shape=(w1shape), dtype=dtype) + begin0 = relay.annotation.compiler_begin(data, "dnnl") + begin1 = relay.annotation.compiler_begin(weight1, "dnnl") + depthwise_conv2d_1 = relay.nn.conv2d(begin0, + begin1, + kernel_size=(3, 3), + padding=(1, 1), + groups=32) + end0 = relay.annotation.compiler_end(depthwise_conv2d_1, "dnnl") + begin2 = relay.annotation.compiler_begin(end0, "dnnl") + begin3 = relay.annotation.compiler_begin(end0, "dnnl") + begin4 = relay.annotation.compiler_begin(weight1, "dnnl") + depthwise_conv2d_2 = relay.nn.conv2d(begin3, + begin4, + kernel_size=(3, 3), + padding=(1, 1), + groups=32) + end1 = relay.annotation.compiler_end(depthwise_conv2d_2, "dnnl") + begin5 = relay.annotation.compiler_begin(end1, "dnnl") + out = relay.add(begin2, begin5) + end2 = relay.annotation.compiler_end(out, "dnnl") + f = relay.Function([data, weight1], end2) + mod = 
+
+    dtype = "float32"
+    ishape = (1, 32, 14, 14)
+    w1shape = (32, 1, 3, 3)
+
+    def test_annotate():
+        mod = annotated(dtype, ishape, w1shape)
+        mod = transform.AnnotateTarget("dnnl")(mod)
+        ref_mod = expected(dtype, ishape, w1shape)
+        assert relay.analysis.alpha_equal(mod, ref_mod)
+
+    def test_run():
+        if not tvm.get_global_func("relay.ext.dnnl", True):
+            print("skip because DNNL codegen is not available")
+            return
+
+        ref_mod = annotated(dtype, ishape, w1shape)
+        mod = annotated(dtype, ishape, w1shape)
+        mod = transform.AnnotateTarget("dnnl")(mod)
+        mod = transform.PartitionGraph()(mod)
+
+        i_data = np.random.uniform(0, 1, ishape).astype(dtype)
+        w1_data = np.random.uniform(0, 1, w1shape).astype(dtype)
+
+        ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu())
+        ref_res = ref_ex.evaluate()(i_data, w1_data)
+
+        check_result(mod, {"data": i_data, "weight1": w1_data},
+                     (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5)
+
+    test_annotate()
+    test_run()
+
+
+def test_extern_dnnl_mobilenet():
+    if not tvm.get_global_func("relay.ext.dnnl", True):
+        print("skip because DNNL codegen is not available")
+        return
+
+    dtype = 'float32'
+    ishape = (1, 3, 224, 224)
+    mod, params = relay.testing.mobilenet.get_workload(
+        batch_size=1, dtype='float32')
+
+    mod = transform.AnnotateTarget("dnnl")(mod)
+    mod = transform.PartitionGraph()(mod)
+    i_data = np.random.uniform(0, 1, ishape).astype(dtype)
+
+    ref_mod, params = relay.testing.mobilenet.get_workload(batch_size=1,
+                                                           dtype='float32')
+    ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu(0))
+    ref_res = ref_ex.evaluate()(i_data, **params)
+
+    check_result(mod, {"data": i_data},
+                 (1, 1000), ref_res.asnumpy(), tol=1e-5, params=params)
+
+
+if __name__ == "__main__":
+    test_extern_dnnl()
+    test_extern_dnnl_mobilenet()
-- 
2.7.4