[Relay][QNN] QNNtoRelay & QNNLegalize Pass utility using Relay Legalize API. (#3838)
authorAnimesh Jain <anijain@umich.edu>
Fri, 30 Aug 2019 16:10:25 +0000 (09:10 -0700)
committerZhi <5145158+zhiics@users.noreply.github.com>
Fri, 30 Aug 2019 16:10:25 +0000 (09:10 -0700)
16 files changed:
include/tvm/relay/transform.h
python/tvm/relay/qnn/__init__.py
python/tvm/relay/qnn/op/__init__.py
python/tvm/relay/qnn/op/op.py [new file with mode: 0644]
python/tvm/relay/qnn/transform.py [new file with mode: 0644]
python/tvm/relay/transform.py
src/relay/pass/legalize.cc
src/relay/qnn/op/dequantize.cc
src/relay/qnn/op/quantize.cc
src/relay/qnn/op/requantize.cc
tests/python/relay/test_pass_legalize.py
tests/python/relay/test_pass_qnn_legalize.py [new file with mode: 0644]
tests/python/relay/test_qnn_concatenate.py
tests/python/relay/test_qnn_dequantize.py
tests/python/relay/test_qnn_quantize.py
tests/python/relay/test_qnn_requantize.py

index 4bd5930..428e5b6 100644 (file)
@@ -522,10 +522,15 @@ TVM_DLL Pass AlterOpLayout();
 
 /*!
  * \brief Legalizes an expr with another expression.
+ * \param legalize_map_attr_name The Op's attr name which corresponds to the legalize rule function.
+ * One can collect and isolate similar type of legalize transformations using this param. For
+ * example, transformations that only apply to Dialects can be isolated into a FTVMDialectLegalize
+ * string. This pass calls only those transformations that have been registered using the supplied
+ * legalize_map_attr_name.
  *
  * \return The pass.
  */
-TVM_DLL Pass Legalize();
+TVM_DLL Pass Legalize(const std::string& legalize_map_attr_name = "FTVMLegalize");
 
 /*!
  * \brief Canonicalize cast expressions to make operator fusion more efficient.
index a472109..fa888d7 100644 (file)
@@ -18,3 +18,4 @@
 """QNN dialect operators and IR passes."""
 from __future__ import absolute_import as _abs
 from . import op
+from . import transform
index e9adfa7..6a230e0 100644 (file)
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=wildcard-import
-"""Neural network related operators."""
+"""QNN dialect related operators."""
 from __future__ import absolute_import as _abs
 from .qnn import *
+from .op import register_qnn_legalize
diff --git a/python/tvm/relay/qnn/op/op.py b/python/tvm/relay/qnn/op/op.py
new file mode 100644 (file)
index 0000000..505f047
--- /dev/null
@@ -0,0 +1,35 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#pylint: disable=unused-argument
+"""The register functions for the QNN dialect."""
+from tvm.relay.op.op import register as register
+
+def register_qnn_legalize(op_name, legal_op=None, level=10):
+    """Register legal transformation function for a QNN op
+
+    Parameters
+    ----------
+    op_name : str
+        The name of the operator
+
+    legal_op: function (attrs: Attrs, inputs: List[Expr]) -> new_expr: Expr
+        The function for transforming an expr to another expr.
+
+    level : int
+        The priority level
+    """
+    return register(op_name, "FTVMQnnLegalize", legal_op, level)
diff --git a/python/tvm/relay/qnn/transform.py b/python/tvm/relay/qnn/transform.py
new file mode 100644 (file)
index 0000000..22e8f7f
--- /dev/null
@@ -0,0 +1,115 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,arguments-differ,no-else-return,unused-argument,missing-docstring
+"""
+QNN pass transformation infrastructure.
+"""
+from tvm import relay
+
+def CanonicalizeOps():
+    """Converts/Lowers an expression containing QNN ops to an expression containing only core
+    (non-Dialect) Relay ops. Each QNN op is lowered to a sequence of exisiting Relay ops. This is a
+    target-independent pass. One can register the lowering/transformation function for this op using
+    FTVMQnnCanonicalize attr_name for FTVMLegalize op attribute.  An example of this transformation
+    is below
+
+    Examples
+    ________
+
+    .. code-block:: python
+
+        # Original expression
+        qnn_expr = relay.qnn.op.requantize(y,
+                                           input_scale=1,
+                                           input_zero_point=0,
+                                           output_scale=1,
+                                           output_zero_point=0,
+                                           out_dtype='int8')
+
+        # We want to utilize all the existing Relay infrastucture. So, instead of supporting this
+        # QNN requantize op, we convert it into a sequence of existing Relay operators.
+        mod = relay.Module.from_expr(qnn_expr)
+        mod = relay.qnn.transform.CanonicalizeOps()(mod)
+        relay_expr = mod['main']
+        print(relay_expr)
+
+        def @main(%quantized_data: Tensor[(200), int32]) -> Tensor[(200), int8] {
+          %0 = cast(%quantized_data, dtype="int64") /* ty=Tensor[(200), int64] */;
+          %1 = multiply(%0, 2 /* ty=int64 */) /* ty=Tensor[(200), int64] */;
+          %2 = multiply(%1, 1073741824 /* ty=int64 */) /* ty=Tensor[(200), int64] */;
+          %3 = add(%2, 1073741824 /* ty=int64 */) /* ty=Tensor[(200), int64] */;
+          %4 = right_shift(%3, 31 /* ty=int64 */) /* ty=Tensor[(200), int64] */;
+          %5 = add(0 /* ty=int64 */, %4) /* ty=Tensor[(200), int64] */;
+          %6 = clip(%5, a_min=-128f, a_max=127f) /* ty=Tensor[(200), int64] */;
+          cast(%6, dtype="int8") /* ty=Tensor[(200), int8] */
+        }
+
+    Returns
+    -------
+    ret : tvm.relay.Pass
+        The registered pass that canonicalizes QNN ops to Relay ops.
+    """
+
+    return relay.transform.Legalize("FTVMQnnCanonicalize")
+
+
+def Legalize():
+    """Legalizes QNN ops. As opposed to Relay Legalize, this one legalizes only QNN ops. One can
+    register a transformation/legalization function for an op by using the FTVMQnnLegalize attr_name
+    for FTVMLegalize op attribute. The isolation of QNN and Relay Legalize gives us separation of
+    concerns, leading to a better software practice. The legalization can be configured to happen
+    per target. An example of this type of legalization is shown below.
+
+    Examples
+    ________
+
+    Suppose the original graph is as follows
+
+            data(u8)  weight(u8)
+                |       |
+                |       |
+               qnn.conv2d (int32)
+                   |
+                   |
+                nn.relu (int32)
+
+    Now, we know that Intel Cascade Lake has VNNI instructions to speedup convolution. However, it
+    only works on u8 x i8 inputs. So, here, we can use QNN Legalize to transform the above graph as
+    follows
+
+            data(u8)  weight(u8)
+               |          |
+               |          |
+               |     requantize(i8)
+               |        |
+               |        |
+               qnn.conv2d (int32)
+                   |
+                   |
+                 nn.relu (int32)
+
+    In this legalization, since we have isolated legalization for QNN ops, it will only trigger the
+    transformation for qnn.conv2d (and not nn.relu). This pass can be followed by CanonicalizeOps to
+    further lower the qnn.requantize and qnn.conv2d into an expr containing only Relay ops.
+
+    Returns
+    -------
+    ret : tvm.relay.Pass
+        The registered pass that legalizes QNN ops.
+    """
+
+    return relay.transform.Legalize("FTVMQnnLegalize")
index ccdf00e..6c0038f 100644 (file)
@@ -414,19 +414,24 @@ def AlterOpLayout():
     return _transform.AlterOpLayout()
 
 
-def Legalize():
+def Legalize(legalize_map_attr_name="FTVMLegalize"):
     """Legalizes an expression with another expression.
     This pass can be used to replace an expr with another expr for target
     dependent optimizations. For example, one expr, though semnatically
     equivalent to the other, can have better performance on a target. This pass
     can be used to legalize the expr in a target-dependent manner.
 
+    Parameters
+    ----------
+    legalize_map_attr_name : str
+        The Op's attr name which corresponds to the legalize rule function.
+
     Returns
     -------
     ret : tvm.relay.Pass
         The registered pass that rewrites an expr.
     """
-    return _transform.Legalize()
+    return _transform.Legalize(legalize_map_attr_name)
 
 
 def RewriteAnnotatedOps(fallback_device):
index 0079dab..07b1d81 100644 (file)
@@ -25,6 +25,7 @@
  */
 
 #include <tvm/operation.h>
+#include <tvm/relay/expr_functor.h>
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/transform.h>
 
@@ -35,48 +36,64 @@ namespace legalize {
 
 // Call registered FTVMLegalize of an op
 // Returns the legalized expression
-Expr Legalizer(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
-  static auto fop_legalize = Op::GetAttr<FTVMLegalize>("FTVMLegalize");
-  Op op = Downcast<Op>(ref_call->op);
-
-  Expr new_e;
-  bool modified = false;
-  if (fop_legalize.count(op)) {
-    // Collect input and output dtypes to pass on to Legalize API.
-    tvm::Array<tvm::relay::Type> types;
-    for (auto& expr : ref_call->args) {
-      types.push_back(expr->checked_type());
+class Legalizer : public ExprMutator {
+ public:
+  explicit Legalizer(const std::string& legalize_map_attr_name)
+      : legalize_map_attr_name_{legalize_map_attr_name} {}
+
+  Expr VisitExpr_(const CallNode* call_node) {
+    // Get the new_call node without any changes to current call node.
+    Expr new_e = ExprMutator::VisitExpr_(call_node);
+    Call new_call = Downcast<Call>(new_e);
+
+    // Collect the registered legalize function.
+    auto fop_legalize = Op::GetAttr<FTVMLegalize>(legalize_map_attr_name_);
+    Op op = Downcast<Op>(call_node->op);
+
+    if (fop_legalize.count(op)) {
+      // Collect the new_args.
+      tvm::Array<Expr> call_args = new_call->args;
+
+      // Collect input and output dtypes to pass on to Legalize API.
+      tvm::Array<tvm::relay::Type> types;
+      for (auto arg : call_node->args) {
+        types.push_back(arg->checked_type());
+      }
+      types.push_back(call_node->checked_type());
+
+      // Transform the op by calling the registered legalize function.
+      Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types);
+
+      // Reassign new_e if the transformation succeeded.
+      if (legalized_value.defined()) {
+        // Check that the returned Expr from legalize is CallNode.
+        const CallNode* legalized_call_node = legalized_value.as<CallNode>();
+        CHECK(legalized_call_node)
+            << "Can only replace the original operator with another call node";
+
+        new_e = legalized_value;
+      }
     }
-    types.push_back(ref_call->checked_type());
 
-    // Transform the op by calling the registered legalize function.
-    Expr legalized_value = fop_legalize[op](ref_call->attrs, new_args, types);
-
-    // Check if the transformation succeeded. If not, revert back to the original ref_call->op.
-    if (legalized_value.defined()) {
-      new_e = legalized_value;
-      modified = true;
-    }
-  }
-  if (!modified) {
-    new_e = CallNode::make(ref_call->op, new_args, ref_call->attrs);
+    return new_e;
   }
 
-  const CallNode* new_call = new_e.as<CallNode>();
-  CHECK(new_call) << "Can only replace the original operator with another call node";
-  return GetRef<Call>(new_call);
-}
+ private:
+  std::string legalize_map_attr_name_;
+};
 
-Expr Legalize(const Expr& expr) { return ForwardRewrite(expr, Legalizer, nullptr); }
+Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name) {
+  return Legalizer(legalize_map_attr_name).Mutate(expr);
+}
 
 }  // namespace legalize
 
 namespace transform {
 
-Pass Legalize() {
+Pass Legalize(const std::string& legalize_map_attr_name) {
   runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
       [=](Function f, Module m, PassContext pc) {
-        return Downcast<Function>(relay::legalize::Legalize(f));
+        return Downcast<Function>(relay::legalize::Legalize(f, legalize_map_attr_name));
       };
   return CreateFunctionPass(pass_func, 3, "Legalize", {ir::StringImm::make("InferType")});
 }
index e42be2a..ff37e2d 100644 (file)
@@ -72,9 +72,9 @@ Expr DequantizeLower(const Expr& input_tensor,
   return scaled_output;
 }
 
-Expr DequantizeLegalize(const Attrs& attrs,
-                        const Array<Expr>& new_args,
-                        const Array<tvm::relay::Type>& types) {
+Expr DequantizeQnnCanonicalize(const Attrs& attrs,
+                               const Array<Expr>& new_args,
+                               const Array<tvm::relay::Type>& types) {
   CHECK_EQ(new_args.size(), 1);
   auto& data = new_args[0];
   const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();
@@ -93,7 +93,7 @@ The input is always quantized (int8, uint8) and will be converted to float32 giv
 .add_argument("data", "Tensor", "The tensor to dequantize.")
 .set_support_level(11)
 .add_type_rel("Dequantize", DequantizeRel)
-.set_attr<FTVMLegalize>("FTVMLegalize", DequantizeLegalize);
+.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", DequantizeQnnCanonicalize);
 
 TVM_REGISTER_API("relay.qnn.op._make.dequantize")
 .set_body_typed(MakeDequantize);
index 675cd4c..97f1a22 100644 (file)
@@ -83,9 +83,9 @@ Expr QuantizeLower(const Expr& input_tensor,
   return clamp_out_dtype;
 }
 
-Expr QuantizeLegalize(const Attrs& attrs,
-                      const Array<Expr>& new_args,
-                      const Array<tvm::relay::Type>& types) {
+Expr QuantizeQnnCanonicalize(const Attrs& attrs,
+                             const Array<Expr>& new_args,
+                             const Array<tvm::relay::Type>& types) {
   CHECK_EQ(new_args.size(), 1);
   auto& data = new_args[0];
   const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
@@ -111,7 +111,7 @@ scale and zero point.
 .add_argument("data", "Tensor", "The tensor to quantize.")
 .set_support_level(11)
 .add_type_rel("Quantize", QuantizeRel)
-.set_attr<FTVMLegalize>("FTVMLegalize", QuantizeLegalize);
+.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QuantizeQnnCanonicalize);
 
 TVM_REGISTER_API("relay.qnn.op._make.quantize")
 .set_body_typed(MakeQuantize);
index ebc537e..448395a 100644 (file)
@@ -192,8 +192,8 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
  *
  * Q_output = zp_output +  (scale_input)/(scale_ouptut) * (Q_input - zp_input)
  */
-Expr RequantizeLegalize(const Attrs& attrs, const Array<Expr>& new_args,
-                        const Array<tvm::relay::Type>& types) {
+Expr RequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
+                               const Array<tvm::relay::Type>& types) {
   CHECK_EQ(new_args.size(), 1);
   auto& quantized_data = new_args[0];
   const auto* param = attrs.as<RequantizeAttrs>();
@@ -276,7 +276,7 @@ Q_output = zp_output +  (scale_input)/(scale_output) * (Q_input - zp_input)
 .add_argument("data", "Tensor", "The quantized input tensor.")
 .set_support_level(11)
 .add_type_rel("Requantize", RequantizeRel)
-.set_attr<FTVMLegalize>("FTVMLegalize", RequantizeLegalize);
+.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", RequantizeQnnCanonicalize);
 
 TVM_REGISTER_API("relay.qnn.op._make.requantize")
 .set_body_typed(MakeRequantize);
index 393c862..c5303ef 100644 (file)
@@ -92,6 +92,51 @@ def test_legalize_none():
     assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
     assert(called[0])
 
+def test_legalize_multiple_ops():
+    """Test directly replacing an operator with a new one"""
+    def before():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight = relay.var('weight', shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(x, weight,
+                            channels=64,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        y = relay.nn.relu(y)
+        y = relay.Function([x, weight], y)
+        return y
+
+    @register_legalize("nn.conv2d", level=102)
+    def legalize_conv2d(attrs, inputs, types):
+        data, weight = inputs
+        weight = relay.multiply(weight, relay.const(2.0, "float32"))
+        return relay.nn.conv2d(data, weight, **attrs)
+
+    @register_legalize("nn.relu", level=103)
+    def legalize_conv2d(attrs, inputs, types):
+        data = inputs[0]
+        add = relay.add(tvm.relay.const(0, "float32"), data)
+        return relay.nn.relu(add)
+
+
+    def expected():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight = relay.var('weight', shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(x, relay.multiply(weight, relay.const(2.0, "float32")),
+                            channels=64,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        y = relay.add(tvm.relay.const(0, "float32"), y)
+        y = relay.nn.relu(y)
+        y = relay.Function([x, weight], y)
+        return y
+
+    a = before()
+    a = run_opt_pass(a, transform.Legalize())
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+
 def test_legalize_multi_input():
     """Test directly replacing an operator with a new one"""
     def before():
@@ -102,7 +147,7 @@ def test_legalize_multi_input():
         func = relay.Function([x, y, z], func)
         return func
 
-    @register_legalize("concatenate", level=100)
+    @register_legalize("concatenate", level=104)
     def legalize_concatenate(attrs, inputs, types):
         # Check that the correct multi-input case is handled.
         assert len(inputs) == 1
@@ -153,7 +198,7 @@ def test_legalize_arm_layout_functional():
         func = relay.Function([data, kernel], y)
         return func
 
-    @register_legalize("nn.conv2d", level=101)
+    @register_legalize("nn.conv2d", level=105)
     def legalize_conv2d(attrs, inputs, types):
         from topi.arm_cpu.conv2d import _conv2d_legalize
         return _conv2d_legalize(attrs, inputs, types)
@@ -173,5 +218,6 @@ def test_legalize_arm_layout_functional():
 if __name__ == "__main__":
     test_legalize()
     test_legalize_none()
+    test_legalize_multiple_ops()
     test_legalize_multi_input()
     test_legalize_arm_layout_functional()
diff --git a/tests/python/relay/test_pass_qnn_legalize.py b/tests/python/relay/test_pass_qnn_legalize.py
new file mode 100644 (file)
index 0000000..86769f8
--- /dev/null
@@ -0,0 +1,86 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test legalize pass"""
+import numpy as np
+import tvm
+
+from tvm import relay
+from tvm.contrib import graph_runtime
+from tvm.relay.qnn.op import register_qnn_legalize
+from tvm.relay import transform, analysis
+
+
+def run_opt_pass(expr, passes):
+    passes = passes if isinstance(passes, list) else [passes]
+    mod = relay.Module.from_expr(expr)
+    seq = transform.Sequential(passes)
+    with transform.PassContext(opt_level=3):
+        mod = seq(mod)
+    entry = mod["main"]
+    return entry if isinstance(expr, relay.Function) else entry.body
+
+def test_qnn_legalize():
+    """Test directly replacing an operator with a new one"""
+    def before():
+        x = relay.var("x", shape=(1, 64, 56, 56), dtype='int8')
+        y = relay.qnn.op.requantize(x,
+                                    input_scale=1,
+                                    input_zero_point=0,
+                                    output_scale=1,
+                                    output_zero_point=0,
+                                    out_dtype='int8')
+        y = relay.Function([x], y)
+        return y
+
+    @register_qnn_legalize("qnn.requantize", level=100)
+    def legalize_qnn_requantize(attrs, inputs, types):
+        data = inputs[0]
+        data = relay.add(relay.const(0, 'int8'), data)
+        y = relay.qnn.op.requantize(data,
+                                    input_scale=1,
+                                    input_zero_point=0,
+                                    output_scale=1,
+                                    output_zero_point=0,
+                                    out_dtype='int8')
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 64, 56, 56), dtype='int8')
+        y = relay.add(relay.const(0, 'int8'), x)
+        z = relay.qnn.op.requantize(y,
+                                    input_scale=1,
+                                    input_zero_point=0,
+                                    output_scale=1,
+                                    output_zero_point=0,
+                                    out_dtype='int8')
+        z = relay.Function([x], z)
+        return z
+
+    a = before()
+
+    # Check that Relay Legalize does not change the graph.
+    a = run_opt_pass(a, relay.transform.Legalize())
+    b = run_opt_pass(before(), transform.InferType())
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+    # Check that QNN Legalize modifies the graph.
+    a = run_opt_pass(a, relay.qnn.transform.Legalize())
+    b = run_opt_pass(expected(), transform.InferType())
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+if __name__ == "__main__":
+    test_qnn_legalize()
index b0745cf..0a5f909 100644 (file)
@@ -41,7 +41,7 @@ def test_same_io_qnn_params():
     func = relay.Function([x, y], z)
     assert func.astext().count('requantize') == 0
     mod = relay.Module.from_expr(func)
-    mod = relay.transform.Legalize()(mod)
+    mod = relay.qnn.transform.CanonicalizeOps()(mod)
     func = mod["main"]
 
     golden_output = np.concatenate((x_data, y_data), axis=axis)
@@ -70,7 +70,7 @@ def test_different_io_qnn_params():
     func = relay.Function([x, y], z)
     assert func.astext().count('requantize') == 2
     mod = relay.Module.from_expr(func)
-    mod = relay.transform.Legalize()(mod)
+    mod = relay.qnn.transform.CanonicalizeOps()(mod)
     func = mod["main"]
 
     golden_output = np.concatenate((x_data - 2, y_data - 3), axis=axis)
@@ -99,7 +99,7 @@ def test_few_same_io_qnn_params():
     func = relay.Function([x, y], z)
     assert func.astext().count('requantize') == 1
     mod = relay.Module.from_expr(func)
-    mod = relay.transform.Legalize()(mod)
+    mod = relay.qnn.transform.CanonicalizeOps()(mod)
     func = mod["main"]
 
     golden_output = np.concatenate((x_data + 1, y_data), axis=axis)
@@ -128,7 +128,7 @@ def test_same_i_qnn_params():
     func = relay.Function([x, y], z)
     assert func.astext().count('requantize') == 1
     mod = relay.Module.from_expr(func)
-    mod = relay.transform.Legalize()(mod)
+    mod = relay.qnn.transform.CanonicalizeOps()(mod)
     func = mod["main"]
 
     golden_output = np.concatenate((x_data + 1, y_data + 1), axis=axis)
@@ -137,7 +137,6 @@ def test_same_i_qnn_params():
     op_res = intrp.evaluate(func)(x_data, y_data)
     np.testing.assert_equal(op_res.asnumpy(), golden_output)
 
-
 if __name__ == '__main__':
     test_same_io_qnn_params()
     test_different_io_qnn_params()
index a942980..76b61ae 100644 (file)
@@ -31,7 +31,7 @@ def test_dequantize_op():
                                                    input_zero_point=input_zero_point)
         mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
         mod = relay.Module.from_expr(mod)
-        mod = relay.transform.Legalize()(mod)
+        mod = relay.qnn.transform.CanonicalizeOps()(mod)
         with relay.build_config(opt_level=3):
             graph, lib, params = relay.build(mod, "llvm", params=None)
             rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
index 47808cf..15319a7 100644 (file)
@@ -31,7 +31,7 @@ def test_quantize_op():
                                                  output_zero_point=output_zero_point,out_dtype=out_dtype)
         mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
         mod = relay.Module.from_expr(mod)
-        mod = relay.transform.Legalize()(mod)
+        mod = relay.qnn.transform.CanonicalizeOps()(mod)
         with relay.build_config(opt_level=3):
             graph, lib, params = relay.build(mod, "llvm", params=None)
             rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
index 61f4289..2afa7d9 100644 (file)
@@ -49,7 +49,7 @@ def test_requantize():
 
         mod = relay.Function(relay.analysis.free_vars(mod), mod)
         mod = relay.Module.from_expr(mod)
-        mod = relay.transform.Legalize()(mod)
+        mod = relay.qnn.transform.CanonicalizeOps()(mod)
         return mod
 
     def same_scale_test():