[Legalize][QNN] Pass out_types to Legalize. Update QNN requantize to read from out_ty...
authorAnimesh Jain <anijain@umich.edu>
Fri, 23 Aug 2019 04:50:00 +0000 (21:50 -0700)
committerYizhi Liu <liuyizhi@apache.org>
Fri, 23 Aug 2019 04:50:00 +0000 (12:50 +0800)
python/tvm/relay/op/nn/_nn.py
src/relay/pass/legalize.cc
src/relay/qnn/op/dequantize.cc
src/relay/qnn/op/quantize.cc
src/relay/qnn/op/requantize.cc
tests/python/relay/test_pass_legalize.py
topi/python/topi/arm_cpu/conv2d.py
topi/python/topi/nn/conv2d.py

index 9b4caa1..7ceb272 100644 (file)
@@ -206,10 +206,24 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos):
     return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)
 
 @reg.register_legalize("nn.conv2d")
-def legalize_conv2d(attrs, inputs, arg_dtypes):
-    """Legalize conv2d"""
-    from ... import op
-    return topi.nn.conv2d_legalize(attrs, inputs, arg_dtypes, op)
+def legalize_conv2d(attrs, inputs, types):
+    """Legalize conv2d op.
+
+    Parameters
+    ----------
+    attrs : tvm.attrs.Attrs
+        Attributes of current convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+    return topi.nn.conv2d_legalize(attrs, inputs, types)
 
 reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
index c041cb9..0079dab 100644 (file)
@@ -42,11 +42,17 @@ Expr Legalizer(const Call& ref_call, const Array<Expr>& new_args, const NodeRef&
   Expr new_e;
   bool modified = false;
   if (fop_legalize.count(op)) {
-    tvm::Array<tvm::relay::Type> arg_types;
+    // Collect input and output dtypes to pass on to Legalize API.
+    tvm::Array<tvm::relay::Type> types;
     for (auto& expr : ref_call->args) {
-      arg_types.push_back(expr->checked_type());
+      types.push_back(expr->checked_type());
     }
-    Expr legalized_value = fop_legalize[op](ref_call->attrs, new_args, arg_types);
+    types.push_back(ref_call->checked_type());
+
+    // Transform the op by calling the registered legalize function.
+    Expr legalized_value = fop_legalize[op](ref_call->attrs, new_args, types);
+
+    // Check if the transformation succeeded. If not, revert back to the original ref_call->op.
     if (legalized_value.defined()) {
       new_e = legalized_value;
       modified = true;
index 1e59440..e42be2a 100644 (file)
@@ -74,12 +74,12 @@ Expr DequantizeLower(const Expr& input_tensor,
 
 Expr DequantizeLegalize(const Attrs& attrs,
                         const Array<Expr>& new_args,
-                        const Array<tvm::relay::Type>& arg_types) {
+                        const Array<tvm::relay::Type>& types) {
   CHECK_EQ(new_args.size(), 1);
   auto& data = new_args[0];
   const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();
   CHECK(dequantize_attrs != nullptr);
-  CHECK_EQ(arg_types.size(), 1);
+  CHECK_EQ(types.size(), 2);
   return DequantizeLower(data, dequantize_attrs);
 }
 
index 2f49400..675cd4c 100644 (file)
@@ -85,13 +85,13 @@ Expr QuantizeLower(const Expr& input_tensor,
 
 Expr QuantizeLegalize(const Attrs& attrs,
                       const Array<Expr>& new_args,
-                      const Array<tvm::relay::Type>& arg_types) {
+                      const Array<tvm::relay::Type>& types) {
   CHECK_EQ(new_args.size(), 1);
   auto& data = new_args[0];
   const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
   CHECK(quantize_attrs != nullptr);
 
-  CHECK_EQ(arg_types.size(), 1);
+  CHECK_EQ(types.size(), 2);
   return QuantizeLower(data, quantize_attrs);
 }
 
index e3052b7..ebc537e 100644 (file)
@@ -109,7 +109,7 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplie
  *       7) Cast to the out_dtype.
  */
 Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
-                     const Array<IndexExpr>& input_shape) {
+                     const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
   double double_multiplier = param->input_scale / param->output_scale;
 
   // Choose high precision datatype to be int64. This is for avoiding overflow
@@ -173,10 +173,10 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
   auto shifted_int64_t = Add(output_zp, scaled_int64_t);
 
   // 7) Clip to the out_dtype min/max.
-  auto q_min = GetQmin(param->out_dtype);
-  auto q_max = GetQmax(param->out_dtype);
+  auto q_min = GetQmin(out_dtype);
+  auto q_max = GetQmax(out_dtype);
   auto clipped_t = Clip(shifted_int64_t, q_min, q_max);
-  return Cast(clipped_t, param->out_dtype);
+  return Cast(clipped_t, out_dtype);
 }
 
 /*
@@ -193,25 +193,32 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
  * Q_output = zp_output +  (scale_input)/(scale_ouptut) * (Q_input - zp_input)
  */
 Expr RequantizeLegalize(const Attrs& attrs, const Array<Expr>& new_args,
-                        const Array<tvm::relay::Type>& arg_types) {
+                        const Array<tvm::relay::Type>& types) {
   CHECK_EQ(new_args.size(), 1);
   auto& quantized_data = new_args[0];
   const auto* param = attrs.as<RequantizeAttrs>();
   CHECK(param != nullptr);
 
   // Find input shape.
-  CHECK_EQ(arg_types.size(), 1);
-  auto input_dtype = arg_types[0];
-  auto input_tensor_type = input_dtype.as<TensorTypeNode>();
-  CHECK(input_tensor_type != nullptr) << "Type information missing."
-                                      << " Please run infer_type pass.";
-  Array<IndexExpr> input_shape = input_tensor_type->shape;
+  CHECK_EQ(types.size(), 2);
+  auto in_type = types[0];
+  auto in_tensor_type = in_type.as<TensorTypeNode>();
+  CHECK(in_tensor_type != nullptr) << "Type information missing."
+                                   << " Please run infer_type pass.";
+  Array<IndexExpr> input_shape = in_tensor_type->shape;
+
+  // Find the output dtype.
+  auto out_type = types[1];
+  auto out_tensor_type = out_type.as<TensorTypeNode>();
+  CHECK(out_tensor_type != nullptr) << "Type information missing."
+                                    << " Please run infer_type pass.";
+  auto out_dtype = out_tensor_type->dtype;
 
   // Check rounding validity.
   CHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST")
       << "QNN requantize supports two rounding modes - UPWARD and "
       << "TONEAREST";
-  return RequantizeLower(quantized_data, param, input_shape);
+  return RequantizeLower(quantized_data, param, input_shape, out_dtype);
 }
 
 /*
@@ -261,7 +268,7 @@ The requantize operator converts one quantized tensor to another quantized
 tensor. For the output tensor, we are provided with output scale and zero
 point. The computation looks like this
 
-Q_output = zp_output +  (scale_input)/(scale_ouptut) * (Q_input - zp_input)
+Q_output = zp_output +  (scale_input)/(scale_output) * (Q_input - zp_input)
 
 )code" TVM_ADD_FILELINE)
 .set_attrs_type_key("relay.attrs.RequantizeAttrs")
index 52deeb5..393c862 100644 (file)
@@ -47,7 +47,7 @@ def test_legalize():
         return y
 
     @register_legalize("nn.conv2d", level=100)
-    def legalize_conv2d(attrs, inputs, arg_types):
+    def legalize_conv2d(attrs, inputs, types):
         data, weight = inputs
         weight = relay.multiply(weight, relay.const(2.0, "float32"))
         return relay.nn.conv2d(data, weight, **attrs)
@@ -80,7 +80,7 @@ def test_legalize_none():
     called = [False]
 
     @register_legalize("nn.global_max_pool2d", level=101)
-    def legalize_conv2d(attrs, inputs, arg_types):
+    def legalize_conv2d(attrs, inputs, types):
         called[0] = True
         return None
 
@@ -103,12 +103,13 @@ def test_legalize_multi_input():
         return func
 
     @register_legalize("concatenate", level=100)
-    def legalize_concatenate(attrs, inputs, arg_types):
+    def legalize_concatenate(attrs, inputs, types):
         # Check that the correct multi-input case is handled.
         assert len(inputs) == 1
         assert isinstance(inputs[0], tvm.relay.expr.Tuple)
-        assert len(arg_types) == 1
-        assert isinstance(arg_types[0], tvm.relay.ty.TupleType)
+        assert len(types) == 2
+        assert isinstance(types[0], tvm.relay.ty.TupleType)
+        assert isinstance(types[1], tvm.relay.ty.TensorType)
         return None
 
     def expected():
@@ -153,9 +154,9 @@ def test_legalize_arm_layout_functional():
         return func
 
     @register_legalize("nn.conv2d", level=101)
-    def legalize_conv2d(attrs, inputs, arg_types):
+    def legalize_conv2d(attrs, inputs, types):
         from topi.arm_cpu.conv2d import _conv2d_legalize
-        return _conv2d_legalize(attrs, inputs, arg_types, tvm.relay.op)
+        return _conv2d_legalize(attrs, inputs, types)
 
     a = before()
     b = run_opt_pass(a, transform.Legalize())
index 95342b6..77b37ed 100644 (file)
 """Conv2D schedule for ARM CPU"""
 from __future__ import absolute_import as _abs
 
-import warnings
+import logging
 
 import tvm
 from tvm import autotvm
+from tvm import relay
 import tvm.contrib.nnpack
 
 from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform, \
@@ -35,6 +36,8 @@ from ..nn import conv2d_legalize
 from ..nn.util import get_const_int, get_pad_tuple
 from ..nn.winograd_util import winograd_transform_matrices
 
+logger = logging.getLogger('topi')
+
 @autotvm.register_topi_compute(conv2d, 'arm_cpu', ['direct'])
 def conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
     """TOPI compute callback for conv2d
@@ -671,7 +674,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
     if layout != 'NCHW':
         return None
     if dilation != (1, 1):
-        warnings.warn("Does not support weight pre-transform for dilated convolution.")
+        logger.warning("Does not support weight pre-transform for dilated convolution.")
         return None
 
     data, kernel = tinfos[0:2]
@@ -786,31 +789,46 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
             return None
 
 @conv2d_legalize.register("arm_cpu")
-def _conv2d_legalize(attrs, inputs, arg_types, F):
-    if F.__name__ != 'tvm.relay.op':
-        return None
+def _conv2d_legalize(attrs, inputs, arg_types):
+    """Legalizes Conv2D op.
+
+    Parameters
+    ----------
+    attrs : tvm.attrs.Attrs
+        Attributes of current convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+
     if attrs['data_layout'] == 'NHWC':
         data, kernel = inputs
         if attrs['kernel_layout'] == 'HWIO':
             # Handle HWIO layout. This is common in TF graph.
-            kernel = F.transpose(kernel, axes=(3, 2, 0, 1))
+            kernel = relay.transpose(kernel, axes=(3, 2, 0, 1))
         elif attrs['kernel_layout'] == 'HWOI':
             # Handle HWOI layout. This is common in TF depthwise conv2d graph.
-            kernel = F.transpose(kernel, axes=(2, 3, 0, 1))
+            kernel = relay.transpose(kernel, axes=(2, 3, 0, 1))
         elif attrs['kernel_layout'] != 'OIHW':
             return None
 
-        warnings.warn("Legalize arm_cpu - NHWC schedule absent. Inserting layout transforms to "
-                      + "fallback to NCHW. This can result in performance degradation.")
+        logger.warning("Legalize arm_cpu - NHWC schedule absent. Inserting layout transforms to "
+                       + "fallback to NCHW. This can result in performance degradation.")
         # Set new attrs for the tranposed conv.
         new_attrs = {k: attrs[k] for k in attrs.keys()}
         new_attrs['data_layout'] = 'NCHW'
         new_attrs['kernel_layout'] = 'OIHW'
 
         # Convert from NHWC to NCHW.
-        data = F.transpose(data, axes=(0, 3, 1, 2))
-        conv = F.nn.conv2d(data, kernel, **new_attrs)
+        data = relay.transpose(data, axes=(0, 3, 1, 2))
+        conv = relay.nn.conv2d(data, kernel, **new_attrs)
         # Convert back to original NHWC layout.
-        out = F.transpose(conv, axes=(0, 2, 3, 1))
+        out = relay.transpose(conv, axes=(0, 2, 3, 1))
         return out
     return None
index e7ab7ba..05580c8 100644 (file)
@@ -72,22 +72,22 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N
 
 
 @tvm.target.generic_func
-def conv2d_legalize(attrs, inputs, arg_dtypes, F):
+def conv2d_legalize(attrs, inputs, types):
     """Legalizes Conv2D op.
+
     Parameters
     ----------
-    attrs : nnvm.top.AttrDict or tvm.attrs.Attrs
+    attrs : tvm.attrs.Attrs
         Attributes of current convolution
     inputs : list of tvm.relay.Expr
-        The args of the Relay expr to be legalized.
-    arg_dtypes : list of types
-        List of types of input arguments
-    F: symbol
-        The context, can be either nnvm.sym or relay.op
-    Note
-    ----
-    Unlike other TOPI functions, this function operates on both graph level and operator level,
-    so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay.
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
     """
     # not to change by default
     return None