From 79922bd39fe3ca9e4f88a9b494e92bea69a5fe68 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 6 Aug 2019 15:23:41 -0700 Subject: [PATCH] [Relay] Legalize pass (#3672) * [Relay] Rewrite pass. This pass transforms an expression to other expression. This pass has many usecases * Replace a expr to another expr, if the other expr has faster performance. * For ASICs, we might want to modify the inputs to adapt to the HW support. * Alter op layout can work in conjunction with this pass. The supporting usecase is the Intel i8 x i8 conv. Intel HW supports u8 x i8 conv in HW. Using this pass, we can replace an i8 x i8 conv to a sequence of operators where one of the operators is now u8 x i8 conv. This will also help automatic quantizaion performance. * Better API name. * Removing the conv2d legalization for x86. Will send a separate PR. * Test name changes. * Registering one funtion to register FTVMLegalize. * Better comments. --- include/tvm/relay/op_attr_types.h | 16 +++- include/tvm/relay/transform.h | 7 ++ python/tvm/relay/op/__init__.py | 3 +- python/tvm/relay/op/nn/_nn.py | 4 + python/tvm/relay/op/op.py | 17 ++++ python/tvm/relay/transform.py | 15 ++++ src/relay/backend/build_module.cc | 5 ++ src/relay/pass/legalize.cc | 83 ++++++++++++++++++++ tests/python/relay/test_pass_legalize.py | 130 +++++++++++++++++++++++++++++++ 9 files changed, 278 insertions(+), 2 deletions(-) create mode 100644 src/relay/pass/legalize.cc create mode 100644 tests/python/relay/test_pass_legalize.py diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h index 7709a79..c1a0f83 100644 --- a/include/tvm/relay/op_attr_types.h +++ b/include/tvm/relay/op_attr_types.h @@ -18,7 +18,7 @@ */ /*! - * \file nnvm/compiler/op_attr_types.h + * \file tvm/relay/op_attr_types.h * \brief The Expr and related elements in DataFlow construction. */ #ifndef TVM_RELAY_OP_ATTR_TYPES_H_ @@ -128,6 +128,20 @@ using FTVMAlterOpLayout = runtime::TypedPackedFunc< const Array& tinfos)>; /*! + * \brief Legalizes an expression with another expression. This function will be + * invoked in Legalize pass. It is a target-dependent pass. + * \param attrs The attribute of the original node. + * \param inputs The input symbols of the original node. + * \param tinfos An array of placeholders, use for getting the inferred shape + * and dtype of the inputs. + * \return new_expr The modified expression. + */ +using FTVMLegalize = runtime::TypedPackedFunc< + Expr(const Attrs& attrs, + const Array& args, + const Array& arg_types)>; + +/*! * \brief Forward rewriting rule for a specific op. * * \param ref_call The reference old call type to be rewritten. diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 56a5a3b..4bd5930 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -521,6 +521,13 @@ TVM_DLL Pass CanonicalizeOps(); TVM_DLL Pass AlterOpLayout(); /*! + * \brief Legalizes an expr with another expression. + * + * \return The pass. + */ +TVM_DLL Pass Legalize(); + +/*! * \brief Canonicalize cast expressions to make operator fusion more efficient. * * \return The pass. diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index a27ab1d..b8ef4df 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -18,7 +18,8 @@ """Relay core operators.""" # operator defs from .op import get, register, register_schedule, register_compute, register_gradient, \ - register_pattern, register_alter_op_layout, schedule_injective, Op, OpPattern, debug + register_pattern, register_alter_op_layout, register_legalize, \ + schedule_injective, Op, OpPattern, debug # Operators from .reduce import * diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 0c374b8..46ab69c 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -204,6 +204,10 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos): from ... import op return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op) +# A placeholder to have at least one invocation of register legalize to register FTVMLegalize. +@reg.register_legalize("nn.conv2d") +def legalize_conv2d(attrs, inputs, arg_dtypes): + return None reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 906bf25..e07d153 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -170,6 +170,23 @@ def register_alter_op_layout(op_name, alter_layout=None, level=10): return register(op_name, "FTVMAlterOpLayout", alter_layout, level) +def register_legalize(op_name, legal_op=None, level=10): + """Register legal transformation function for an op + + Parameters + ---------- + op_name : str + The name of the operator + + legal_op: function (attrs: Attrs, inputs: List[Expr]) -> new_expr: Expr + The function for transforming an expr to another expr. + + level : int + The priority level + """ + return register(op_name, "FTVMLegalize", legal_op, level) + + def register_pattern(op_name, pattern, level=10): """Register operator pattern for an op. diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index 2e64d14..46543af 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -437,6 +437,21 @@ def AlterOpLayout(): return _transform.AlterOpLayout() +def Legalize(): + """Legalizes an expression with another expression. + This pass can be used to replace an expr with another expr for target + dependent optimizations. For example, one expr, though semnatically + equivalent to the other, can have better performance on a target. This pass + can be used to legalize the expr in a target-dependent manner. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass that rewrites an expr. + """ + return _transform.Legalize() + + def RewriteAnnotatedOps(fallback_device): """Rewrite the annotated program where annotation operators, e.g. `on_deivce`, mark which device an expression should be scheduled to. diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 5281532..f757dad 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -304,6 +304,11 @@ class RelayBuildModule : public runtime::ModuleNode { pass_seqs.push_back(transform::CanonicalizeCast()); pass_seqs.push_back(transform::CanonicalizeOps()); + // Legalize pass is restricted to homogeneous execution for now. + if (targets.size() == 1) { + pass_seqs.push_back(transform::Legalize()); + } + // Alter layout transformation is only applied to homogeneous execution yet. if (targets.size() == 1) { pass_seqs.push_back(transform::AlterOpLayout()); diff --git a/src/relay/pass/legalize.cc b/src/relay/pass/legalize.cc new file mode 100644 index 0000000..c041cb9 --- /dev/null +++ b/src/relay/pass/legalize.cc @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file legalize.cc + * \brief Converts an expr to another expr. This pass can be used to transform an op based on its + * shape, dtype or layout to another op or a sequence of ops. + */ + +#include +#include +#include + +namespace tvm { +namespace relay { + +namespace legalize { + +// Call registered FTVMLegalize of an op +// Returns the legalized expression +Expr Legalizer(const Call& ref_call, const Array& new_args, const NodeRef& ctx) { + static auto fop_legalize = Op::GetAttr("FTVMLegalize"); + Op op = Downcast(ref_call->op); + + Expr new_e; + bool modified = false; + if (fop_legalize.count(op)) { + tvm::Array arg_types; + for (auto& expr : ref_call->args) { + arg_types.push_back(expr->checked_type()); + } + Expr legalized_value = fop_legalize[op](ref_call->attrs, new_args, arg_types); + if (legalized_value.defined()) { + new_e = legalized_value; + modified = true; + } + } + if (!modified) { + new_e = CallNode::make(ref_call->op, new_args, ref_call->attrs); + } + + const CallNode* new_call = new_e.as(); + CHECK(new_call) << "Can only replace the original operator with another call node"; + return GetRef(new_call); +} + +Expr Legalize(const Expr& expr) { return ForwardRewrite(expr, Legalizer, nullptr); } + +} // namespace legalize + +namespace transform { + +Pass Legalize() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(relay::legalize::Legalize(f)); + }; + return CreateFunctionPass(pass_func, 3, "Legalize", {ir::StringImm::make("InferType")}); +} + +TVM_REGISTER_API("relay._transform.Legalize").set_body_typed(Legalize); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_pass_legalize.py b/tests/python/relay/test_pass_legalize.py new file mode 100644 index 0000000..364d6b4 --- /dev/null +++ b/tests/python/relay/test_pass_legalize.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test legalize pass""" +import tvm + +from tvm import relay +from tvm.relay.op import register_legalize +from tvm.relay import transform, analysis + + +def run_opt_pass(expr, passes): + passes = passes if isinstance(passes, list) else [passes] + mod = relay.Module.from_expr(expr) + seq = transform.Sequential(passes) + with transform.PassContext(opt_level=3): + mod = seq(mod) + entry = mod["main"] + return entry if isinstance(expr, relay.Function) else entry.body + +def test_legalize(): + """Test directly replacing an operator with a new one""" + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight = relay.var('weight', shape=(64, 64, 3, 3)) + y = relay.nn.conv2d(x, weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1)) + y = relay.nn.relu(y) + y = relay.Function([x, weight], y) + return y + + @register_legalize("nn.conv2d", level=100) + def legalize_conv2d(attrs, inputs, arg_types): + data, weight = inputs + weight = relay.multiply(weight, relay.const(2.0, "float32")) + return relay.nn.conv2d(data, weight, **attrs) + + def expected(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight = relay.var('weight', shape=(64, 64, 3, 3)) + y = relay.nn.conv2d(x, relay.multiply(weight, relay.const(2.0, "float32")), + channels=64, + kernel_size=(3, 3), + padding=(1, 1)) + y = relay.nn.relu(y) + y = relay.Function([x, weight], y) + return y + + a = before() + a = run_opt_pass(a, transform.Legalize()) + b = run_opt_pass(expected(), transform.InferType()) + + assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + +def test_legalize_none(): + """Test doing nothing by returning 'None' """ + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + y = relay.nn.global_max_pool2d(x) + y = relay.Function([x], y) + return y + + called = [False] + + @register_legalize("nn.global_max_pool2d", level=101) + def legalize_conv2d(attrs, inputs, arg_types): + called[0] = True + return None + + a = before() + a = run_opt_pass(a, transform.Legalize()) + + b = before() + b = run_opt_pass(b, transform.InferType()) + assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + assert(called[0]) + +def test_legalize_multi_input(): + """Test directly replacing an operator with a new one""" + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + y = relay.var("y", shape=(1, 64, 56, 20)) + z = relay.var("z", shape=(1, 64, 56, 10)) + func = relay.concatenate([x, y, z], axis=3) + func = relay.Function([x, y, z], func) + return func + + @register_legalize("concatenate", level=100) + def legalize_concatenate(attrs, inputs, arg_types): + # Check that the correct multi-input case is handled. + assert len(inputs) == 1 + assert isinstance(inputs[0], tvm.relay.expr.Tuple) + assert len(arg_types) == 1 + assert isinstance(arg_types[0], tvm.relay.ty.TupleType) + return None + + def expected(): + x = relay.var("x", shape=(1, 64, 56, 56)) + y = relay.var("y", shape=(1, 64, 56, 20)) + z = relay.var("z", shape=(1, 64, 56, 10)) + func = relay.concatenate([x, y, z], axis=3) + func = relay.Function([x, y, z], func) + return func + + a = before() + a = run_opt_pass(a, transform.Legalize()) + b = run_opt_pass(expected(), transform.InferType()) + + assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + + +if __name__ == "__main__": + test_legalize() + test_legalize_none() + test_legalize_multi_input() -- 2.7.4