[QNN] Requantize operator (#3531)
authorAnimesh Jain <anijain@umich.edu>
Thu, 8 Aug 2019 18:41:24 +0000 (11:41 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Thu, 8 Aug 2019 18:41:24 +0000 (11:41 -0700)
* [Relay] [Quantization] WIP - Common files for the qauntization work.

* [Relay] [Quantization] WIP - Prototyping requantize op.

* Requantize operator implementation.

Requantize converts one quantized tensor representation to another quantized
representation. The PR has following implementation features

- Requantize operator defined in qnn namespace - relay.qnn.requantize
- Lowering of the requantize to exisiting Relay operators
- Integer fixed point implementation of requantize
    - Two rounding modes - FE_UPWARDS (round towards infinity) and
    FE_AWAY_FROM_ZERO (std::round behavior)
- Floating point implementation as well, that can act as reference or can be
used for devices when FP32 computation is not used.
- Unit test cases

Relevant Issue - https://github.com/dmlc/tvm/issues/2351

Credit to TFLite and GemmLowp to provide reference implementations.

* Typo and lint fixes.

* Doc fix.

* Uncommenting the lint script (fixing mistake).

* Modifying the unit tests.

* Moving C++ files into src/relay/qnn

* Moving python files to python/tvm/relay/qnn. Some minor fixes.

* Moving the attrs.h inside the include directory.

* Pushing files that I forgot earlier. Changing util location.

* Incorporating comments. API change. Lint fixes.

* Modifying the GetFixedPointMultiplierShift API as per comments.

* Forgot the dialect change.

* Changing rewrite to qnn_lower.

* Renaming Quantize to Qnn for clarity.

* Remove use_int_domain.

* Incorportaing review comments.

* Adding API doc for QNN dialect.

* Move the qnn_lower pass to transform namespace.

* Moving from expr to module. Adding namespace in C++.

* Minor sentence rewrites. Added qnn namespace.

* Added the API doc.

* Chanding default out_dtype to int8. Adding a test with in/out_dtype as uint8.

* Style fixes. Better error messages.

* Adding documentation.

* More documentation fixes.

* Adding out dtype check for requantize.

* Adding corner case for FP32 to fixed point conversion.

* Adding extra line.

* Documentation fix.

* Adding static inline.

* Incorporating jackwish comment. Removed idtype from requantize lowering.

* Removing Quantize/Dequantize code. Restricting Requantize to (u)int8/int32.

* Style fixes.

* Fix the docs.

* Move to Legalize API.

docs/langref/relay_op.rst
include/tvm/relay/qnn/attrs.h [new file with mode: 0644]
python/tvm/relay/__init__.py
python/tvm/relay/qnn/__init__.py [new file with mode: 0644]
python/tvm/relay/qnn/op/__init__.py [new file with mode: 0644]
python/tvm/relay/qnn/op/_make.py [new file with mode: 0644]
python/tvm/relay/qnn/op/qnn.py [new file with mode: 0644]
src/relay/pass/pattern_util.h
src/relay/qnn/op/requantize.cc [new file with mode: 0644]
src/relay/qnn/util.h [new file with mode: 0644]
tests/python/relay/test_qnn_requantize.py [new file with mode: 0644]

index 757fdac..6950ecc 100644 (file)
@@ -202,6 +202,16 @@ This level support backpropagation of broadcast operators. It is temporary.
    tvm.relay.contrib.adaptive_avg_pool2d
 
 
+**Level 11: Dialect Operators**
+
+This level supports dialect operators.
+
+.. autosummary::
+   :nosignatures:
+
+   tvm.relay.qnn.op.requantize
+
+
 Level 1 Definitions
 -------------------
 .. autofunction:: tvm.relay.log
@@ -340,3 +350,8 @@ Level 10 Definitions
 .. autofunction:: tvm.relay.nn.batch_matmul
 .. autofunction:: tvm.relay.contrib.adaptive_max_pool2d
 .. autofunction:: tvm.relay.contrib.adaptive_avg_pool2d
+
+
+Level 11 Definitions
+--------------------
+.. autofunction:: tvm.relay.qnn.op.requantize
diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h
new file mode 100644 (file)
index 0000000..e996028
--- /dev/null
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/qnn/attrs.h
+ * \brief Auxiliary attributes for qnn operators.
+ */
+#ifndef TVM_RELAY_QNN_ATTRS_H_
+#define TVM_RELAY_QNN_ATTRS_H_
+
+#include <tvm/attrs.h>
+#include <string>
+
+namespace tvm {
+namespace relay {
+namespace qnn {
+
+/*! \brief Attribute for requantize operator */
+struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
+  double input_scale;
+  int32_t input_zero_point;
+  double output_scale;
+  int32_t output_zero_point;
+  std::string rounding;
+  DataType out_dtype;
+
+  TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") {
+    TVM_ATTR_FIELD(input_scale)
+        .describe("The scale of the input tensor.");
+    TVM_ATTR_FIELD(input_zero_point)
+        .describe("The zero point of the input tensor.");
+    TVM_ATTR_FIELD(output_scale)
+        .describe("The scale of the output tensor.");
+    TVM_ATTR_FIELD(output_zero_point)
+        .describe("The zero point of the output tensor.");
+    TVM_ATTR_FIELD(rounding).set_default("TONEAREST")
+        .describe("Defines the rounding direction when the value is midway between"
+                  "two representable values. There are two supported modes - UPWARD"
+                  "or TONEAREST. Both modes behave exactly same except at the"
+                  "midpoints between the two representable values. At the midpoint,"
+                  "UPWARD rounds towards positive infinity (for example -1.5 will be"
+                  "rounded to -1). TONEAREST is the standard rounding where the"
+                  "value is rounded away from zero at midpoints (for example, -1.5"
+                  "rounds to -2). More context can be found at following gblic manual"
+                  "https://www.gnu.org/software/libc/manual/html_node/Rounding.html.");
+    TVM_ATTR_FIELD(out_dtype)
+        .set_default(NullValue<DataType>())
+        .describe("Output data type, set to explicit type under mixed precision setting");
+  }
+};
+
+}  // namespace qnn
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_QNN_ATTRS_H_
index da14c80..01baa00 100644 (file)
@@ -53,6 +53,9 @@ from . import frontend
 from . import backend
 from . import quantize
 
+# Dialects
+from . import qnn
+
 from .scope_builder import ScopeBuilder
 
 # Span
diff --git a/python/tvm/relay/qnn/__init__.py b/python/tvm/relay/qnn/__init__.py
new file mode 100644 (file)
index 0000000..a472109
--- /dev/null
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=wildcard-import
+"""QNN dialect operators and IR passes."""
+from __future__ import absolute_import as _abs
+from . import op
diff --git a/python/tvm/relay/qnn/op/__init__.py b/python/tvm/relay/qnn/op/__init__.py
new file mode 100644 (file)
index 0000000..e9adfa7
--- /dev/null
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=wildcard-import
+"""Neural network related operators."""
+from __future__ import absolute_import as _abs
+from .qnn import *
diff --git a/python/tvm/relay/qnn/op/_make.py b/python/tvm/relay/qnn/op/_make.py
new file mode 100644 (file)
index 0000000..07b3dd1
--- /dev/null
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Constructor APIs"""
+from ...._ffi.function import _init_api
+
+_init_api("relay.qnn.op._make", __name__)
diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py
new file mode 100644 (file)
index 0000000..1717bc4
--- /dev/null
@@ -0,0 +1,74 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#pylint: disable=invalid-name
+"""QNN dialect operators."""
+
+from __future__ import absolute_import as _abs
+from . import _make
+
+def requantize(data,
+               input_scale,
+               input_zero_point,
+               output_scale,
+               output_zero_point,
+               rounding="TONEAREST",
+               out_dtype="int8"):
+    r"""Requantized operator.
+
+    The requantize operator converts one quantized tensor representation to
+    another quantized tensor representation. For the output tensor, we are
+    provided with output scale and zero point. The computation is as follows
+
+    Q_output = zp_output +  (scale_input)/(scale_output) * (Q_input - zp_input)
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator.
+
+    input_scale: float
+        The quantization scale for the input tensor.
+
+    input_zero_point: int
+        The zero point of the input tensor.
+
+    output_scale: float
+        The quantization scale for the output tensor.
+
+    output_zero_point: int
+        The zero point of the output tensor.
+
+    rounding : string, optional
+        Defines the rounding direction when the value is midway between two
+        representable values.
+
+    out_dtype : str, optional
+        Specifies the output data type.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+
+    return _make.requantize(data,
+                            input_scale,
+                            input_zero_point,
+                            output_scale,
+                            output_zero_point,
+                            rounding,
+                            out_dtype)
index 7dcfd5c..3ccfff0 100644 (file)
@@ -394,6 +394,26 @@ inline Expr Variance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, b
 }
 
 
+static inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) {
+  static const Op& op = Op::Get("where");
+  return CallNode::make(op, {condition, x, y});
+}
+
+static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) {
+  static const Op& op = Op::Get("greater_equal");
+  return CallNode::make(op, {lhs, rhs}, Attrs(), {});
+}
+
+static inline Expr Full(Expr fill_value,
+                        Array<IndexExpr> shape,
+                        DataType dtype) {
+  auto attrs = make_node<InitOpAttrs>();
+  attrs->shape = std::move(shape);
+  attrs->dtype = std::move(dtype);
+  static const Op& op = Op::Get("full");
+  return CallNode::make(op, {fill_value}, Attrs(attrs), {});
+}
+
 Expr MakeConcatenate(Expr data, int axis);
 
 Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc
new file mode 100644 (file)
index 0000000..04f7e80
--- /dev/null
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file requantize.cc
+ * \brief QNN requantize operator.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/qnn/attrs.h>
+#include "../../pass/pattern_util.h"
+#include "../util.h"
+
+namespace tvm {
+namespace relay {
+namespace qnn {
+
+TVM_REGISTER_NODE_TYPE(RequantizeAttrs);
+
+// Lowering of qnn.requantize op
+
+/*
+ * \brief Convert FP32 representation into fixed point representation.
+ * \param double_multplier The input FP32 number.
+ * \return The pair of multiplier and shift for fixed point representation.
+ * \note Converts a floating point number so that it can be represented by
+ *       integers. The representation is
+ *             float_number = (significand) * 2^(exponent)
+ *
+ *       The significand is a number between 0.5 and 1. This is represented by
+ *       an integer number. For example, if it is int32, then the decimal point
+ *       exists between bit 31 and 30 from LSB (or between first and second bit
+ *       from the left).
+ *
+ *       Some examples are
+ *           0.25 = (0.5) * 2^(-1)
+ *           0.125 = (0.5) * 2^(-2)
+ *
+ *       Credit to TFLite reference implementation.
+ */
+std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplier) {
+  int32_t significand, exponent;
+  if (double_multiplier == 0.) {
+    significand = 0;
+    exponent = 0;
+    return std::make_pair(significand, exponent);
+  }
+
+  // Get the significand and exponent.
+  double significand_d = std::frexp(double_multiplier, &exponent);
+
+  // Convert the double significand to int significand, i.e., convert into a
+  // integer where the decimal point is between bit 31 and 30. This is done by
+  // multiplying the double value with 2^31 and then casting to int.
+  significand_d = std::round(significand_d * (1ll << 31));
+  auto significand_int64 = static_cast<int64_t>(significand_d);
+  CHECK_LE(significand_int64, (1ll << 31));
+  if (significand_int64 == (1ll << 31)) {
+    significand_int64 /= 2;
+    ++exponent;
+  }
+  CHECK_LE(significand_int64, std::numeric_limits<int32_t>::max());
+  significand = static_cast<int32_t>(significand_int64);
+  return std::make_pair(significand, exponent);
+}
+
+/*
+ * \brief Lower requantize to a sequence of ops.
+ * \param input_tensor The input tensor to requantize op.
+ * \param param The requantize op attrs.
+ * \param input_shape The input tensor shape of the requantize op.
+ * \return The sequence of existing Relay ops.
+ * \note Requantization using only integer computation. Here, the computation is
+ *       converted to a fixed point computation by computing output multiplier
+ *       and shift. This is useful, if the target device does not support/have
+ *       very expensive floating point computations.
+ *
+ *       Original compuation is scale_fp32 * quantized_tensor.  To convert into
+ *       integer computation, the multiplication with fp32 scalar can be
+ *       replaced by multiplication with an int value and then right shifting
+ *       the result. This approximates the floating point computation with a
+ *       fixed point computation.
+ *
+ *       The whole computation this can be broken down into following steps
+ *       1) Calculate the integer multiplier and integer shift.
+ *       2) Subtract the input integer zero point.
+ *       3) Multiply the fixed point multiplier with quantized tensor.
+ *       4) Round the result.
+ *       5) Right shift the result.
+ *       6) Add the output zero point.
+ *       7) Cast to the out_dtype.
+ */
+Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
+                     const Array<IndexExpr>& input_shape) {
+  double double_multiplier = param->input_scale / param->output_scale;
+
+  // Choose high precision datatype to be int64. This is for avoiding overflow
+  // in multiplication of two int32 values.
+  DataType hp_dtype = Int(64);
+
+  // 1) Calculating the integer multiplier and integer shift
+  int32_t fixed_point_multiplier, shift;
+  std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier);
+  int left_shift = shift > 0 ? shift : 0;
+  int right_shift = shift > 0 ? 0 : -shift;
+
+  // 2) Subtract the input_zero_point
+  auto tensor = Cast(input_tensor, hp_dtype);
+  if (param->input_zero_point != 0) {
+    auto input_zp = MakeConstantScalar(hp_dtype, param->input_zero_point);
+    tensor = Subtract(tensor, input_zp);
+  }
+
+  // 3) Multiply the integer multiplier
+  if (left_shift != 0) {
+    tensor = Multiply(tensor, MakeConstantScalar(hp_dtype, 1 << left_shift));
+  }
+  // Perform the multiplication in higher precision.
+  // The scalar is a fixed point value of int32 where the decimal point is
+  // between bits 31 and 30. After multiplying with input_tensor, the result is
+  // in int64 where the decimal point is sitting between bits 31 and 30 (from
+  // the right, rightmost bit is bit 0). The computation is performed in higher
+  // precision to avoid overflow in multiplying two int32 values.
+  Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier);
+  auto multiplied_t = Multiply(tensor, scalar);
+
+  // 4) Find the rounding scalar. This depends on where the final decimal point
+  // sits. As we will be right shifting the multiplied_t, we need to first
+  // calculate the total_right_shift.
+  int total_right_shift = right_shift + 31;
+  int64_t pos_rounding_value = (1ll << (total_right_shift - 1));
+
+  tensor = multiplied_t;
+  Expr round_scalar;
+  if (param->rounding == "UPWARD") {
+    round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value);
+  } else if (param->rounding == "TONEAREST") {
+    auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
+    auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
+    auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
+    auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);
+
+    auto zero = MakeConstantScalar(hp_dtype, 0);
+    auto zero_t = Full(zero, input_shape, hp_dtype);
+    round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
+  }
+  // Add the rounding scalar.
+  tensor = Add(tensor, round_scalar);
+
+  // 5) Simply right shift the result to get the final output.
+  auto scaled_int64_t = RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
+
+  // 6) Add the output zero point.
+  auto output_zp = MakeConstantScalar(hp_dtype, param->output_zero_point);
+  auto shifted_int64_t = Add(output_zp, scaled_int64_t);
+
+  // 7) Clip to the out_dtype min/max.
+  auto q_min = GetQmin(param->out_dtype);
+  auto q_max = GetQmax(param->out_dtype);
+  auto clipped_t = Clip(shifted_int64_t, q_min, q_max);
+  return Cast(clipped_t, param->out_dtype);
+}
+
+/*
+ * \brief Forward rewrite the requantize op.
+ * \param ref_call The original call that will be lowered.
+ * \param new_args The new mutated args to the call node.
+ * \param ctx The node context.
+ * \return The sequence of Relay ops for requantize op.
+ * \note Lowering of the requantize operation. The requantize operator converts
+ *       one quantized tensor to another quantized tensor. For the output
+ *       tensor, we are provided with output scale and zero point. The
+ *       computation looks like this
+ *
+ * Q_output = zp_output +  (scale_input)/(scale_ouptut) * (Q_input - zp_input)
+ */
+Expr RequantizeLegalize(const Attrs& attrs, const Array<Expr>& new_args,
+                        const Array<tvm::relay::Type>& arg_types) {
+  CHECK_EQ(new_args.size(), 1);
+  auto& quantized_data = new_args[0];
+  const auto* param = attrs.as<RequantizeAttrs>();
+  CHECK(param != nullptr);
+
+  // Find input shape.
+  CHECK_EQ(arg_types.size(), 1);
+  auto input_dtype = arg_types[0];
+  auto input_tensor_type = input_dtype.as<TensorTypeNode>();
+  CHECK(input_tensor_type != nullptr) << "Type information missing."
+                                      << " Please run infer_type pass.";
+  Array<IndexExpr> input_shape = input_tensor_type->shape;
+
+  // Check rounding validity.
+  CHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST")
+      << "QNN requantize supports two rounding modes - UPWARD and "
+      << "TONEAREST";
+  return RequantizeLower(quantized_data, param, input_shape);
+}
+
+/*
+ * \brief Infer shape function of Requantize op.
+ * \param types The types of input args.
+ * \param num_inputs The number of inputs.
+ * \param attrs The op attributes.
+ * \param reporter The type reporter that sets the dtype and shapes.
+ * \return True if the infer shape succeeded.
+ */
+bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                   const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto in_dtype = data->dtype;
+  CHECK(in_dtype == Int(8) || in_dtype == UInt(8) || in_dtype == Int(32))
+      << "Input type should be an integer but was " << in_dtype;
+
+  const Array<tvm::Expr> oshape = data->shape;
+  // assign output type
+  const RequantizeAttrs* param = attrs.as<RequantizeAttrs>();
+  auto out_dtype = param->out_dtype;
+  CHECK(out_dtype == Int(8) || out_dtype == UInt(8) || out_dtype == Int(32))
+      << "Output type should be an integer but was " << out_dtype;
+  reporter->Assign(types[1], TensorTypeNode::make(oshape, out_dtype));
+  return true;
+}
+
+// Positional relay function to create qnn requantize operator
+// used by frontend FFI.
+Expr MakeRequantize(Expr data, double input_scale, int32_t input_zero_point, double output_scale,
+                    int32_t output_zero_point, std::string rounding, DataType out_dtype) {
+  auto attrs = make_node<RequantizeAttrs>();
+  attrs->input_scale = std::move(input_scale);
+  attrs->input_zero_point = std::move(input_zero_point);
+  attrs->output_scale = std::move(output_scale);
+  attrs->output_zero_point = std::move(output_zero_point);
+  attrs->rounding = std::move(rounding);
+  attrs->out_dtype = std::move(out_dtype);
+  static const Op& op = Op::Get("qnn.requantize");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+RELAY_REGISTER_OP("qnn.requantize")
+.describe(R"code(Requantize operator.
+The requantize operator converts one quantized tensor to another quantized
+tensor. For the output tensor, we are provided with output scale and zero
+point. The computation looks like this
+
+Q_output = zp_output +  (scale_input)/(scale_ouptut) * (Q_input - zp_input)
+
+)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.RequantizeAttrs")
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The quantized input tensor.")
+.set_support_level(11)
+.add_type_rel("Requantize", RequantizeRel)
+.set_attr<FTVMLegalize>("FTVMLegalize", RequantizeLegalize);
+
+TVM_REGISTER_API("relay.qnn.op._make.requantize")
+.set_body_typed(MakeRequantize);
+
+}  // namespace qnn
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h
new file mode 100644 (file)
index 0000000..1ada7ec
--- /dev/null
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/qnn/util.h
+ * \brief Utility methods needs for quantized ops that can be shared
+ */
+
+#ifndef TVM_RELAY_QNN_UTIL_H_
+#define TVM_RELAY_QNN_UTIL_H_
+
+#include <tvm/expr.h>
+#include <tvm/relay/expr.h>
+#include <limits>
+
+namespace tvm {
+namespace relay {
+namespace qnn {
+
+static inline const int32_t GetQmin(const DataType& dtype) {
+  CHECK_LE(dtype.bits(), 32)
+      << "QNN ops support int32 or lower precision";
+  if (dtype.is_int()) {
+    auto* min_value = as_const_int(dtype.min());
+    CHECK(min_value != nullptr);
+    return static_cast<int32_t>(min_value[0]);
+  } else if (dtype.is_uint()) {
+    auto* min_value = as_const_uint(dtype.min());
+    CHECK(min_value != nullptr);
+    return static_cast<int32_t>(min_value[0]);
+  } else {
+    LOG(FATAL) << "Type not supported " << dtype;
+    return -1;  // To hide the warning
+  }
+}
+
+static inline const int32_t GetQmax(const DataType& dtype) {
+  CHECK_LE(dtype.bits(), 32)
+      << "QNN ops support int32 or lower precision";
+  if (dtype.is_int()) {
+    auto* max_value = as_const_int(dtype.max());
+    CHECK(max_value != nullptr);
+    return static_cast<int32_t>(max_value[0]);
+  } else if (dtype.is_uint()) {
+    auto* max_value = as_const_uint(dtype.max());
+    CHECK(max_value != nullptr);
+    return static_cast<int32_t>(max_value[0]);
+  } else {
+    LOG(FATAL) << "Type not supported " << dtype;
+    return -1;  // To hide the warning
+  }
+}
+
+}  // namespace qnn
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_QNN_UTIL_H_
diff --git a/tests/python/relay/test_qnn_requantize.py b/tests/python/relay/test_qnn_requantize.py
new file mode 100644 (file)
index 0000000..cd478fb
--- /dev/null
@@ -0,0 +1,259 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import numpy as np
+from tvm import relay
+from tvm.relay.testing import create_workload
+from tvm.contrib import graph_runtime
+
+roundings = ["UPWARD", "TONEAREST"]
+
+def run_infer_type(expr):
+    mod = relay.Module.from_expr(expr)
+    mod = relay.transform.InferType()(mod)
+    entry = mod["main"]
+    return entry if isinstance(expr, relay.Function) else entry.body
+
+
+def test_requantize():
+    def verify(mod, goldens):
+        with relay.build_config(opt_level=3):
+            graph, lib, params = relay.build(mod, "llvm", params=None)
+            golden_data, golden_output = goldens
+            rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
+            rt_mod.set_input("quantized_data",golden_data)
+            rt_mod.set_input(**params)
+            rt_mod.run()
+            res = rt_mod.get_output(0).asnumpy()
+            np.testing.assert_equal(res, golden_output)
+
+    def get_mod(data_shape, data_dtype, out_dtype, input_scale, output_scale,
+            input_zero_point=0, output_zero_point=0, rounding="TONEAREST"):
+        quantized_data = relay.var("quantized_data", shape=data_shape,
+                dtype=data_dtype)
+        mod = relay.qnn.op.requantize(
+                quantized_data,
+                input_scale=input_scale,
+                input_zero_point=input_zero_point,
+                output_scale=output_scale,
+                output_zero_point=output_zero_point,
+                rounding=rounding,
+                out_dtype=out_dtype)
+
+        mod = relay.Function(relay.analysis.free_vars(mod), mod)
+        mod = relay.Module.from_expr(mod)
+        mod = relay.transform.Legalize()(mod)
+        return mod
+
+    def same_scale_test():
+        # Have same scales, everything within range
+        golden_data = np.arange(-100, 100, 1).astype('int32')
+        golden_output = golden_data
+
+        for rounding in roundings:
+            mod = get_mod(data_shape=(200, ),
+                          data_dtype='int32',
+                          out_dtype="int8",
+                          input_scale=0.5,
+                          output_scale=0.5,
+                          rounding=rounding)
+            verify(mod, (golden_data, golden_output))
+
+    def downscale_test():
+        for rounding in roundings:
+            mod = get_mod(data_shape=(32, ),
+                          data_dtype='int32',
+                          out_dtype='int8',
+                          input_scale=1,
+                          output_scale=16,
+                          rounding=rounding)
+
+            # Try positive values
+            # 8 corresponds to 0.5, resulting in 1
+            golden_data = np.arange(0, 32, 1).astype('int32')
+            golden_output = np.repeat([0, 1, 2], [8, 16, 8])
+            verify(mod, (golden_data, golden_output))
+
+            # Try negative values
+            # -8 corresponds to -0.5. For UPWARD, this is 0
+            golden_data = np.arange(0, -32, -1).astype('int32')
+            if rounding == "UPWARD":
+                golden_output = np.repeat([0, -1, -2], [9, 16, 7])
+            else:
+                golden_output = np.repeat([0, -1, -2], [8, 16, 8])
+            verify(mod, (golden_data, golden_output))
+
+            # Try a different scale
+            mod = get_mod(data_shape=(32, ),
+                          data_dtype='int32',
+                          out_dtype="int8",
+                          input_scale=1,
+                          output_scale=4,
+                          rounding=rounding)
+
+            # Try positive values
+            # 2I corresponds to 0.5, resulting in 1
+            golden_data = np.arange(0, 32, 1).astype('int32')
+            golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8],
+                                      [2, 4, 4, 4, 4, 4, 4, 4, 2])
+            verify(mod, (golden_data, golden_output))
+
+            # Try negative values
+            # -8 corresponds to -0.5. For UPWARD, this is 0
+            golden_data = np.arange(0, -32, -1).astype('int32')
+            if rounding == "UPWARD":
+                golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8],
+                                          [3, 4, 4, 4, 4, 4, 4, 4, 1])
+            else:
+                golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8],
+                                          [2, 4, 4, 4, 4, 4, 4, 4, 2])
+            verify(mod, (golden_data, golden_output))
+
+            # Try uint8 out_dtype
+            mod = get_mod(data_shape=(32, ),
+                          data_dtype='int32',
+                          out_dtype='uint8',
+                          input_scale=1,
+                          output_scale=16,
+                          rounding=rounding)
+
+            # Try positive values
+            # 8 corresponds to 0.5, resulting in 1
+            golden_data = np.arange(0, 32, 1).astype('int32')
+            golden_output = np.repeat([0, 1, 2], [8, 16, 8])
+            verify(mod, (golden_data, golden_output))
+
+            # Try uint8 in_dtyope and uint8 out_dtype
+            mod = get_mod(data_shape=(32, ),
+                          data_dtype='uint8',
+                          out_dtype='uint8',
+                          input_scale=1,
+                          output_scale=16,
+                          rounding=rounding)
+
+            # Try positive values
+            # 8 corresponds to 0.5, resulting in 1
+            golden_data = np.arange(0, 32, 1).astype('int32')
+            golden_output = np.repeat([0, 1, 2], [8, 16, 8])
+            verify(mod, (golden_data, golden_output))
+
+    def upscale_test():
+        for rounding in roundings:
+            mod = get_mod(data_shape=(32, ),
+                          data_dtype='int32',
+                          out_dtype="int8",
+                          input_scale=2,
+                          output_scale=1,
+                          rounding=rounding)
+
+            # Try positive values
+            # 8 corresponds to 0.5, resulting in 1
+            golden_data = np.arange(0, 32, 1).astype('int32')
+            golden_output = np.multiply(2, golden_data)
+            verify(mod, (golden_data, golden_output))
+
+            # Try negative values
+            # -8 corresponds to -0.5. For UPWARD, this is 0
+            golden_data = np.arange(0, -32, -1).astype('int32')
+            golden_output = np.multiply(2, golden_data)
+            verify(mod, (golden_data, golden_output))
+
+    def saturation_test():
+        for rounding in roundings:
+            mod = get_mod(data_shape=(16, ),
+                          data_dtype='int32',
+                          out_dtype="int8",
+                          input_scale=0.5,
+                          output_scale=0.5,
+                          rounding=rounding)
+            golden_data = np.arange(0, 16, 1).astype('int32')
+            golden_data = np.add(120, golden_data)
+            output = np.array([120, 121, 122, 123, 124, 125, 126, 127,
+                               127, 127, 127, 127, 127, 127, 127, 127])
+            golden_output = output
+            verify(mod, (golden_data, golden_output))
+
+            # Try negative numbers
+            golden_data = np.arange(0, -16, -1).astype('int32')
+            golden_data = np.add(-120, golden_data)
+            output = np.array([-120, -121, -122, -123, -124, -125, -126, -127,
+                               -128, -128, -128, -128, -128, -128, -128, -128])
+            golden_output = output
+            verify(mod, (golden_data, golden_output))
+
+    def zero_point_test():
+        # Output zero point
+        for rounding in roundings:
+            mod = get_mod(data_shape=(32, ),
+                          data_dtype='int32',
+                          out_dtype='int8',
+                          input_scale=1,
+                          output_scale=16,
+                          output_zero_point=1,
+                          rounding=rounding)
+
+            # Try positive values
+            # 8 corresponds to 0.5, resulting in 1
+            golden_data = np.arange(0, 32, 1).astype('int32')
+            golden_output = np.repeat([0, 1, 2], [8, 16, 8])
+            golden_output = np.add(1, golden_output)
+            verify(mod, (golden_data, golden_output))
+
+            # Try negative values
+            # -8 corresponds to -0.5. For UPWARD, this is 0
+            golden_data = np.arange(-32, -64, -1).astype('int32')
+            if rounding == "UPWARD":
+                golden_output = np.repeat([-2, -3, -4], [9, 16, 7])
+            else:
+                golden_output = np.repeat([-2, -3, -4], [8, 16, 8])
+            golden_output = np.add(1, golden_output)
+            verify(mod, (golden_data, golden_output))
+
+        # Input zero point
+        for rounding in roundings:
+            mod = get_mod(data_shape=(32, ),
+                          data_dtype='int32',
+                          out_dtype='int8',
+                          input_scale=1,
+                          output_scale=16,
+                          input_zero_point=16,
+                          rounding=rounding)
+
+            # Try positive values
+            golden_data = np.arange(32, 64, 1).astype('int32')
+            golden_output = np.repeat([2, 3, 4], [8, 16, 8])
+            golden_output = np.subtract(golden_output, 1)
+            verify(mod, (golden_data, golden_output))
+
+            # Try negative values
+            golden_data = np.arange(-32, -64, -1).astype('int32')
+            if rounding == "UPWARD":
+                golden_output = np.repeat([-2, -3, -4], [9, 16, 7])
+            else:
+                golden_output = np.repeat([-2, -3, -4], [8, 16, 8])
+            golden_output = np.subtract(golden_output, 1)
+            verify(mod, (golden_data, golden_output))
+
+    same_scale_test()
+    downscale_test()
+    upscale_test()
+    saturation_test()
+    zero_point_test()
+
+if __name__ == "__main__":
+    test_requantize()