[Relay] Legalize pass (#3672)
authorAnimesh Jain <anijain@umich.edu>
Tue, 6 Aug 2019 22:23:41 +0000 (15:23 -0700)
committerYizhi Liu <liuyizhi@apache.org>
Tue, 6 Aug 2019 22:23:41 +0000 (15:23 -0700)
* [Relay] Rewrite pass.

This pass transforms an expression to other expression.

This pass has many usecases
 * Replace a expr to another expr, if the other expr has faster performance.
 * For ASICs, we might want to modify the inputs to adapt to the HW support.
 * Alter op layout can work in conjunction with this pass.

The supporting usecase is the Intel i8 x i8 conv. Intel HW supports u8 x i8 conv
in HW. Using this pass, we can replace an i8 x i8 conv to a sequence of
operators where one of the operators is now u8 x i8 conv. This will also help
automatic quantizaion performance.

* Better API name.

* Removing the conv2d legalization for x86. Will send a separate PR.

* Test name changes.

* Registering one funtion to register FTVMLegalize.

* Better comments.

include/tvm/relay/op_attr_types.h
include/tvm/relay/transform.h
python/tvm/relay/op/__init__.py
python/tvm/relay/op/nn/_nn.py
python/tvm/relay/op/op.py
python/tvm/relay/transform.py
src/relay/backend/build_module.cc
src/relay/pass/legalize.cc [new file with mode: 0644]
tests/python/relay/test_pass_legalize.py [new file with mode: 0644]

index 7709a79..c1a0f83 100644 (file)
@@ -18,7 +18,7 @@
  */
 
 /*!
- * \file nnvm/compiler/op_attr_types.h
+ * \file tvm/relay/op_attr_types.h
  * \brief The Expr and related elements in DataFlow construction.
  */
 #ifndef TVM_RELAY_OP_ATTR_TYPES_H_
@@ -128,6 +128,20 @@ using FTVMAlterOpLayout = runtime::TypedPackedFunc<
        const Array<Tensor>& tinfos)>;
 
 /*!
+ * \brief Legalizes an expression with another expression. This function will be
+ *  invoked in Legalize pass. It is a target-dependent pass.
+ * \param attrs The attribute of the original node.
+ * \param inputs The input symbols of the original node.
+ * \param tinfos An array of placeholders, use for getting the inferred shape
+ *               and dtype of the inputs.
+ * \return new_expr The modified expression.
+ */
+using FTVMLegalize = runtime::TypedPackedFunc<
+  Expr(const Attrs& attrs,
+       const Array<Expr>& args,
+       const Array<tvm::relay::Type>& arg_types)>;
+
+/*!
  * \brief Forward rewriting rule for a specific op.
  *
  * \param ref_call The reference old call type to be rewritten.
index 56a5a3b..4bd5930 100644 (file)
@@ -521,6 +521,13 @@ TVM_DLL Pass CanonicalizeOps();
 TVM_DLL Pass AlterOpLayout();
 
 /*!
+ * \brief Legalizes an expr with another expression.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass Legalize();
+
+/*!
  * \brief Canonicalize cast expressions to make operator fusion more efficient.
  *
  * \return The pass.
index a27ab1d..b8ef4df 100644 (file)
@@ -18,7 +18,8 @@
 """Relay core operators."""
 # operator defs
 from .op import get, register, register_schedule, register_compute, register_gradient, \
-    register_pattern, register_alter_op_layout, schedule_injective, Op, OpPattern, debug
+    register_pattern, register_alter_op_layout, register_legalize, \
+    schedule_injective, Op, OpPattern, debug
 
 # Operators
 from .reduce import *
index 0c374b8..46ab69c 100644 (file)
@@ -204,6 +204,10 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos):
     from ... import op
     return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)
 
+# A placeholder to have at least one invocation of register legalize to register FTVMLegalize.
+@reg.register_legalize("nn.conv2d")
+def legalize_conv2d(attrs, inputs, arg_dtypes):
+    return None
 
 reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
index 906bf25..e07d153 100644 (file)
@@ -170,6 +170,23 @@ def register_alter_op_layout(op_name, alter_layout=None, level=10):
     return register(op_name, "FTVMAlterOpLayout", alter_layout, level)
 
 
+def register_legalize(op_name, legal_op=None, level=10):
+    """Register legal transformation function for an op
+
+    Parameters
+    ----------
+    op_name : str
+        The name of the operator
+
+    legal_op: function (attrs: Attrs, inputs: List[Expr]) -> new_expr: Expr
+        The function for transforming an expr to another expr.
+
+    level : int
+        The priority level
+    """
+    return register(op_name, "FTVMLegalize", legal_op, level)
+
+
 def register_pattern(op_name, pattern, level=10):
     """Register operator pattern for an op.
 
index 2e64d14..46543af 100644 (file)
@@ -437,6 +437,21 @@ def AlterOpLayout():
     return _transform.AlterOpLayout()
 
 
+def Legalize():
+    """Legalizes an expression with another expression.
+    This pass can be used to replace an expr with another expr for target
+    dependent optimizations. For example, one expr, though semnatically
+    equivalent to the other, can have better performance on a target. This pass
+    can be used to legalize the expr in a target-dependent manner.
+
+    Returns
+    -------
+    ret : tvm.relay.Pass
+        The registered pass that rewrites an expr.
+    """
+    return _transform.Legalize()
+
+
 def RewriteAnnotatedOps(fallback_device):
     """Rewrite the annotated program where annotation operators, e.g.
     `on_deivce`, mark which device an expression should be scheduled to.
index 5281532..f757dad 100644 (file)
@@ -304,6 +304,11 @@ class RelayBuildModule : public runtime::ModuleNode {
     pass_seqs.push_back(transform::CanonicalizeCast());
     pass_seqs.push_back(transform::CanonicalizeOps());
 
+    // Legalize pass is restricted to homogeneous execution for now.
+    if (targets.size() == 1) {
+      pass_seqs.push_back(transform::Legalize());
+    }
+
     // Alter layout transformation is only applied to homogeneous execution yet.
     if (targets.size() == 1) {
       pass_seqs.push_back(transform::AlterOpLayout());
diff --git a/src/relay/pass/legalize.cc b/src/relay/pass/legalize.cc
new file mode 100644 (file)
index 0000000..c041cb9
--- /dev/null
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file legalize.cc
+ * \brief Converts an expr to another expr. This pass can be used to transform an op based on its
+ * shape, dtype or layout to another op or a sequence of ops.
+ */
+
+#include <tvm/operation.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+namespace legalize {
+
+// Call registered FTVMLegalize of an op
+// Returns the legalized expression
+Expr Legalizer(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
+  static auto fop_legalize = Op::GetAttr<FTVMLegalize>("FTVMLegalize");
+  Op op = Downcast<Op>(ref_call->op);
+
+  Expr new_e;
+  bool modified = false;
+  if (fop_legalize.count(op)) {
+    tvm::Array<tvm::relay::Type> arg_types;
+    for (auto& expr : ref_call->args) {
+      arg_types.push_back(expr->checked_type());
+    }
+    Expr legalized_value = fop_legalize[op](ref_call->attrs, new_args, arg_types);
+    if (legalized_value.defined()) {
+      new_e = legalized_value;
+      modified = true;
+    }
+  }
+  if (!modified) {
+    new_e = CallNode::make(ref_call->op, new_args, ref_call->attrs);
+  }
+
+  const CallNode* new_call = new_e.as<CallNode>();
+  CHECK(new_call) << "Can only replace the original operator with another call node";
+  return GetRef<Call>(new_call);
+}
+
+Expr Legalize(const Expr& expr) { return ForwardRewrite(expr, Legalizer, nullptr); }
+
+}  // namespace legalize
+
+namespace transform {
+
+Pass Legalize() {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+      [=](Function f, Module m, PassContext pc) {
+        return Downcast<Function>(relay::legalize::Legalize(f));
+      };
+  return CreateFunctionPass(pass_func, 3, "Legalize", {ir::StringImm::make("InferType")});
+}
+
+TVM_REGISTER_API("relay._transform.Legalize").set_body_typed(Legalize);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/relay/test_pass_legalize.py b/tests/python/relay/test_pass_legalize.py
new file mode 100644 (file)
index 0000000..364d6b4
--- /dev/null
@@ -0,0 +1,130 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test legalize pass"""
+import tvm
+
+from tvm import relay
+from tvm.relay.op import register_legalize
+from tvm.relay import transform, analysis
+
+
+def run_opt_pass(expr, passes):
+    passes = passes if isinstance(passes, list) else [passes]
+    mod = relay.Module.from_expr(expr)
+    seq = transform.Sequential(passes)
+    with transform.PassContext(opt_level=3):
+        mod = seq(mod)
+    entry = mod["main"]
+    return entry if isinstance(expr, relay.Function) else entry.body
+
+def test_legalize():
+    """Test directly replacing an operator with a new one"""
+    def before():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight = relay.var('weight', shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(x, weight,
+                            channels=64,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        y = relay.nn.relu(y)
+        y = relay.Function([x, weight], y)
+        return y
+
+    @register_legalize("nn.conv2d", level=100)
+    def legalize_conv2d(attrs, inputs, arg_types):
+        data, weight = inputs
+        weight = relay.multiply(weight, relay.const(2.0, "float32"))
+        return relay.nn.conv2d(data, weight, **attrs)
+
+    def expected():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight = relay.var('weight', shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(x, relay.multiply(weight, relay.const(2.0, "float32")),
+                            channels=64,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        y = relay.nn.relu(y)
+        y = relay.Function([x, weight], y)
+        return y
+
+    a = before()
+    a = run_opt_pass(a, transform.Legalize())
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+def test_legalize_none():
+    """Test doing nothing by returning 'None' """
+    def before():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        y = relay.nn.global_max_pool2d(x)
+        y = relay.Function([x], y)
+        return y
+
+    called = [False]
+
+    @register_legalize("nn.global_max_pool2d", level=101)
+    def legalize_conv2d(attrs, inputs, arg_types):
+        called[0] = True
+        return None
+
+    a = before()
+    a = run_opt_pass(a, transform.Legalize())
+
+    b = before()
+    b = run_opt_pass(b, transform.InferType())
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert(called[0])
+
+def test_legalize_multi_input():
+    """Test directly replacing an operator with a new one"""
+    def before():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        y = relay.var("y", shape=(1, 64, 56, 20))
+        z = relay.var("z", shape=(1, 64, 56, 10))
+        func = relay.concatenate([x, y, z], axis=3)
+        func = relay.Function([x, y, z], func)
+        return func
+
+    @register_legalize("concatenate", level=100)
+    def legalize_concatenate(attrs, inputs, arg_types):
+        # Check that the correct multi-input case is handled.
+        assert len(inputs) == 1
+        assert isinstance(inputs[0], tvm.relay.expr.Tuple)
+        assert len(arg_types) == 1
+        assert isinstance(arg_types[0], tvm.relay.ty.TupleType)
+        return None
+
+    def expected():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        y = relay.var("y", shape=(1, 64, 56, 20))
+        z = relay.var("z", shape=(1, 64, 56, 10))
+        func = relay.concatenate([x, y, z], axis=3)
+        func = relay.Function([x, y, z], func)
+        return func
+
+    a = before()
+    a = run_opt_pass(a, transform.Legalize())
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+
+if __name__ == "__main__":
+    test_legalize()
+    test_legalize_none()
+    test_legalize_multi_input()