* [Relay] Rewrite pass.
This pass transforms an expression to other expression.
This pass has many usecases
* Replace a expr to another expr, if the other expr has faster performance.
* For ASICs, we might want to modify the inputs to adapt to the HW support.
* Alter op layout can work in conjunction with this pass.
The supporting usecase is the Intel i8 x i8 conv. Intel HW supports u8 x i8 conv
in HW. Using this pass, we can replace an i8 x i8 conv to a sequence of
operators where one of the operators is now u8 x i8 conv. This will also help
automatic quantizaion performance.
* Better API name.
* Removing the conv2d legalization for x86. Will send a separate PR.
* Test name changes.
* Registering one funtion to register FTVMLegalize.
* Better comments.
*/
/*!
- * \file nnvm/compiler/op_attr_types.h
+ * \file tvm/relay/op_attr_types.h
* \brief The Expr and related elements in DataFlow construction.
*/
#ifndef TVM_RELAY_OP_ATTR_TYPES_H_
const Array<Tensor>& tinfos)>;
/*!
+ * \brief Legalizes an expression with another expression. This function will be
+ * invoked in Legalize pass. It is a target-dependent pass.
+ * \param attrs The attribute of the original node.
+ * \param inputs The input symbols of the original node.
+ * \param tinfos An array of placeholders, use for getting the inferred shape
+ * and dtype of the inputs.
+ * \return new_expr The modified expression.
+ */
+using FTVMLegalize = runtime::TypedPackedFunc<
+ Expr(const Attrs& attrs,
+ const Array<Expr>& args,
+ const Array<tvm::relay::Type>& arg_types)>;
+
+/*!
* \brief Forward rewriting rule for a specific op.
*
* \param ref_call The reference old call type to be rewritten.
TVM_DLL Pass AlterOpLayout();
/*!
+ * \brief Legalizes an expr with another expression.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass Legalize();
+
+/*!
* \brief Canonicalize cast expressions to make operator fusion more efficient.
*
* \return The pass.
"""Relay core operators."""
# operator defs
from .op import get, register, register_schedule, register_compute, register_gradient, \
- register_pattern, register_alter_op_layout, schedule_injective, Op, OpPattern, debug
+ register_pattern, register_alter_op_layout, register_legalize, \
+ schedule_injective, Op, OpPattern, debug
# Operators
from .reduce import *
from ... import op
return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)
+# A placeholder to have at least one invocation of register legalize to register FTVMLegalize.
+@reg.register_legalize("nn.conv2d")
+def legalize_conv2d(attrs, inputs, arg_dtypes):
+ return None
reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
return register(op_name, "FTVMAlterOpLayout", alter_layout, level)
+def register_legalize(op_name, legal_op=None, level=10):
+ """Register legal transformation function for an op
+
+ Parameters
+ ----------
+ op_name : str
+ The name of the operator
+
+ legal_op: function (attrs: Attrs, inputs: List[Expr]) -> new_expr: Expr
+ The function for transforming an expr to another expr.
+
+ level : int
+ The priority level
+ """
+ return register(op_name, "FTVMLegalize", legal_op, level)
+
+
def register_pattern(op_name, pattern, level=10):
"""Register operator pattern for an op.
return _transform.AlterOpLayout()
+def Legalize():
+ """Legalizes an expression with another expression.
+ This pass can be used to replace an expr with another expr for target
+ dependent optimizations. For example, one expr, though semnatically
+ equivalent to the other, can have better performance on a target. This pass
+ can be used to legalize the expr in a target-dependent manner.
+
+ Returns
+ -------
+ ret : tvm.relay.Pass
+ The registered pass that rewrites an expr.
+ """
+ return _transform.Legalize()
+
+
def RewriteAnnotatedOps(fallback_device):
"""Rewrite the annotated program where annotation operators, e.g.
`on_deivce`, mark which device an expression should be scheduled to.
pass_seqs.push_back(transform::CanonicalizeCast());
pass_seqs.push_back(transform::CanonicalizeOps());
+ // Legalize pass is restricted to homogeneous execution for now.
+ if (targets.size() == 1) {
+ pass_seqs.push_back(transform::Legalize());
+ }
+
// Alter layout transformation is only applied to homogeneous execution yet.
if (targets.size() == 1) {
pass_seqs.push_back(transform::AlterOpLayout());
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file legalize.cc
+ * \brief Converts an expr to another expr. This pass can be used to transform an op based on its
+ * shape, dtype or layout to another op or a sequence of ops.
+ */
+
+#include <tvm/operation.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+namespace legalize {
+
+// Call registered FTVMLegalize of an op
+// Returns the legalized expression
+Expr Legalizer(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
+ static auto fop_legalize = Op::GetAttr<FTVMLegalize>("FTVMLegalize");
+ Op op = Downcast<Op>(ref_call->op);
+
+ Expr new_e;
+ bool modified = false;
+ if (fop_legalize.count(op)) {
+ tvm::Array<tvm::relay::Type> arg_types;
+ for (auto& expr : ref_call->args) {
+ arg_types.push_back(expr->checked_type());
+ }
+ Expr legalized_value = fop_legalize[op](ref_call->attrs, new_args, arg_types);
+ if (legalized_value.defined()) {
+ new_e = legalized_value;
+ modified = true;
+ }
+ }
+ if (!modified) {
+ new_e = CallNode::make(ref_call->op, new_args, ref_call->attrs);
+ }
+
+ const CallNode* new_call = new_e.as<CallNode>();
+ CHECK(new_call) << "Can only replace the original operator with another call node";
+ return GetRef<Call>(new_call);
+}
+
+Expr Legalize(const Expr& expr) { return ForwardRewrite(expr, Legalizer, nullptr); }
+
+} // namespace legalize
+
+namespace transform {
+
+Pass Legalize() {
+ runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+ [=](Function f, Module m, PassContext pc) {
+ return Downcast<Function>(relay::legalize::Legalize(f));
+ };
+ return CreateFunctionPass(pass_func, 3, "Legalize", {ir::StringImm::make("InferType")});
+}
+
+TVM_REGISTER_API("relay._transform.Legalize").set_body_typed(Legalize);
+
+} // namespace transform
+
+} // namespace relay
+} // namespace tvm
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test legalize pass"""
+import tvm
+
+from tvm import relay
+from tvm.relay.op import register_legalize
+from tvm.relay import transform, analysis
+
+
+def run_opt_pass(expr, passes):
+ passes = passes if isinstance(passes, list) else [passes]
+ mod = relay.Module.from_expr(expr)
+ seq = transform.Sequential(passes)
+ with transform.PassContext(opt_level=3):
+ mod = seq(mod)
+ entry = mod["main"]
+ return entry if isinstance(expr, relay.Function) else entry.body
+
+def test_legalize():
+ """Test directly replacing an operator with a new one"""
+ def before():
+ x = relay.var("x", shape=(1, 64, 56, 56))
+ weight = relay.var('weight', shape=(64, 64, 3, 3))
+ y = relay.nn.conv2d(x, weight,
+ channels=64,
+ kernel_size=(3, 3),
+ padding=(1, 1))
+ y = relay.nn.relu(y)
+ y = relay.Function([x, weight], y)
+ return y
+
+ @register_legalize("nn.conv2d", level=100)
+ def legalize_conv2d(attrs, inputs, arg_types):
+ data, weight = inputs
+ weight = relay.multiply(weight, relay.const(2.0, "float32"))
+ return relay.nn.conv2d(data, weight, **attrs)
+
+ def expected():
+ x = relay.var("x", shape=(1, 64, 56, 56))
+ weight = relay.var('weight', shape=(64, 64, 3, 3))
+ y = relay.nn.conv2d(x, relay.multiply(weight, relay.const(2.0, "float32")),
+ channels=64,
+ kernel_size=(3, 3),
+ padding=(1, 1))
+ y = relay.nn.relu(y)
+ y = relay.Function([x, weight], y)
+ return y
+
+ a = before()
+ a = run_opt_pass(a, transform.Legalize())
+ b = run_opt_pass(expected(), transform.InferType())
+
+ assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+def test_legalize_none():
+ """Test doing nothing by returning 'None' """
+ def before():
+ x = relay.var("x", shape=(1, 64, 56, 56))
+ y = relay.nn.global_max_pool2d(x)
+ y = relay.Function([x], y)
+ return y
+
+ called = [False]
+
+ @register_legalize("nn.global_max_pool2d", level=101)
+ def legalize_conv2d(attrs, inputs, arg_types):
+ called[0] = True
+ return None
+
+ a = before()
+ a = run_opt_pass(a, transform.Legalize())
+
+ b = before()
+ b = run_opt_pass(b, transform.InferType())
+ assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert(called[0])
+
+def test_legalize_multi_input():
+ """Test directly replacing an operator with a new one"""
+ def before():
+ x = relay.var("x", shape=(1, 64, 56, 56))
+ y = relay.var("y", shape=(1, 64, 56, 20))
+ z = relay.var("z", shape=(1, 64, 56, 10))
+ func = relay.concatenate([x, y, z], axis=3)
+ func = relay.Function([x, y, z], func)
+ return func
+
+ @register_legalize("concatenate", level=100)
+ def legalize_concatenate(attrs, inputs, arg_types):
+ # Check that the correct multi-input case is handled.
+ assert len(inputs) == 1
+ assert isinstance(inputs[0], tvm.relay.expr.Tuple)
+ assert len(arg_types) == 1
+ assert isinstance(arg_types[0], tvm.relay.ty.TupleType)
+ return None
+
+ def expected():
+ x = relay.var("x", shape=(1, 64, 56, 56))
+ y = relay.var("y", shape=(1, 64, 56, 20))
+ z = relay.var("z", shape=(1, 64, 56, 10))
+ func = relay.concatenate([x, y, z], axis=3)
+ func = relay.Function([x, y, z], func)
+ return func
+
+ a = before()
+ a = run_opt_pass(a, transform.Legalize())
+ b = run_opt_pass(expected(), transform.InferType())
+
+ assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+
+if __name__ == "__main__":
+ test_legalize()
+ test_legalize_none()
+ test_legalize_multi_input()