* [Relay] [Quantization] WIP - Common files for the quantization work.
* [Relay] [Quantization] WIP - Prototyping requantize op.
* Requantize operator implementation.
Requantize converts one quantized tensor representation to another quantized
representation. The PR has the following implementation features:
- Requantize operator defined in qnn namespace - relay.qnn.requantize
- Lowering of requantize to existing Relay operators
- Integer fixed point implementation of requantize
- Two rounding modes - FE_UPWARD (round towards positive infinity) and
FE_AWAY_FROM_ZERO (std::round behavior)
- A floating point implementation as well, which can act as a reference (a
minimal sketch follows below) or can be used on devices where FP32
computation is acceptable
- Unit test cases
Relevant Issue - https://github.com/dmlc/tvm/issues/2351
Credit to TFLite and GemmLowp for providing reference implementations.
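
For reference, a minimal NumPy sketch of the floating point requantize
computation (illustrative only; names do not match the PR's API):

    import numpy as np

    def requantize_fp32(q_input, input_scale, input_zero_point,
                        output_scale, output_zero_point, out_dtype=np.int8):
        # Q_output = zp_output + (scale_input / scale_output) * (Q_input - zp_input)
        real_value = input_scale * (q_input.astype(np.float64) - input_zero_point)
        q_output = np.round(real_value / output_scale) + output_zero_point
        # Note: np.round ties to even, which differs from the PR's UPWARD and
        # TONEAREST modes at exact midpoints.
        info = np.iinfo(out_dtype)
        return np.clip(q_output, info.min, info.max).astype(out_dtype)

    # Same scales and zero points: identity for values within the int8 range.
    print(requantize_fp32(np.arange(-100, 100, dtype=np.int32), 0.5, 0, 0.5, 0))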
* Typo and lint fixes.
* Doc fix.
* Uncommenting the lint script (fixing mistake).
* Modifying the unit tests.
* Moving C++ files into src/relay/qnn
* Moving python files to python/tvm/relay/qnn. Some minor fixes.
* Moving the attrs.h inside the include directory.
* Pushing files that I forgot earlier. Changing util location.
* Incorporating comments. API change. Lint fixes.
* Modifying the GetFixedPointMultiplierShift API as per comments.
* Forgot the dialect change.
* Changing rewrite to qnn_lower.
* Renaming Quantize to Qnn for clarity.
* Remove use_int_domain.
* Incorporating review comments.
* Adding API doc for QNN dialect.
* Move the qnn_lower pass to transform namespace.
* Moving from expr to module. Adding namespace in C++.
* Minor sentence rewrites. Added qnn namespace.
* Added the API doc.
* Changing default out_dtype to int8. Adding a test with in/out_dtype as uint8.
* Style fixes. Better error messages.
* Adding documentation.
* More documentation fixes.
* Adding out dtype check for requantize.
* Adding corner case for FP32 to fixed point conversion.
* Adding extra line.
* Documentation fix.
* Adding static inline.
* Incorporating jackwish comment. Removed idtype from requantize lowering.
* Removing Quantize/Dequantize code. Restricting Requantize to (u)int8/int32.
* Style fixes.
* Fix the docs.
* Move to Legalize API.
tvm.relay.contrib.adaptive_avg_pool2d
+**Level 11: Dialect Operators**
+
+This level supports dialect operators, which are lowered to existing Relay
+operators rather than implemented directly.
+
+.. autosummary::
+ :nosignatures:
+
+ tvm.relay.qnn.op.requantize
+
+
Level 1 Definitions
-------------------
.. autofunction:: tvm.relay.log
.. autofunction:: tvm.relay.nn.batch_matmul
.. autofunction:: tvm.relay.contrib.adaptive_max_pool2d
.. autofunction:: tvm.relay.contrib.adaptive_avg_pool2d
+
+
+Level 11 Definitions
+--------------------
+.. autofunction:: tvm.relay.qnn.op.requantize
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/qnn/attrs.h
+ * \brief Auxiliary attributes for qnn operators.
+ */
+#ifndef TVM_RELAY_QNN_ATTRS_H_
+#define TVM_RELAY_QNN_ATTRS_H_
+
+#include <tvm/attrs.h>
+#include <string>
+
+namespace tvm {
+namespace relay {
+namespace qnn {
+
+/*! \brief Attribute for requantize operator */
+struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
+ double input_scale;
+ int32_t input_zero_point;
+ double output_scale;
+ int32_t output_zero_point;
+ std::string rounding;
+ DataType out_dtype;
+
+ TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") {
+ TVM_ATTR_FIELD(input_scale)
+ .describe("The scale of the input tensor.");
+ TVM_ATTR_FIELD(input_zero_point)
+ .describe("The zero point of the input tensor.");
+ TVM_ATTR_FIELD(output_scale)
+ .describe("The scale of the output tensor.");
+ TVM_ATTR_FIELD(output_zero_point)
+ .describe("The zero point of the output tensor.");
+    TVM_ATTR_FIELD(rounding).set_default("TONEAREST")
+        .describe("Defines the rounding direction when the value is midway between "
+                  "two representable values. There are two supported modes - UPWARD "
+                  "or TONEAREST. Both modes behave exactly the same except at the "
+                  "midpoints between two representable values. At the midpoint, "
+                  "UPWARD rounds towards positive infinity (for example, -1.5 is "
+                  "rounded to -1). TONEAREST is the standard rounding where the "
+                  "value is rounded away from zero at midpoints (for example, -1.5 "
+                  "rounds to -2). More context can be found in the glibc manual: "
+                  "https://www.gnu.org/software/libc/manual/html_node/Rounding.html");
+ TVM_ATTR_FIELD(out_dtype)
+ .set_default(NullValue<DataType>())
+ .describe("Output data type, set to explicit type under mixed precision setting");
+ }
+};
+
+} // namespace qnn
+} // namespace relay
+} // namespace tvm
+#endif // TVM_RELAY_QNN_ATTRS_H_
from . import backend
from . import quantize
+# Dialects
+from . import qnn
+
from .scope_builder import ScopeBuilder
# Span
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=wildcard-import
+"""QNN dialect operators and IR passes."""
+from __future__ import absolute_import as _abs
+from . import op
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=wildcard-import
+"""Neural network related operators."""
+from __future__ import absolute_import as _abs
+from .qnn import *
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Constructor APIs"""
+from ...._ffi.function import _init_api
+
+_init_api("relay.qnn.op._make", __name__)
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#pylint: disable=invalid-name
+"""QNN dialect operators."""
+
+from __future__ import absolute_import as _abs
+from . import _make
+
+def requantize(data,
+ input_scale,
+ input_zero_point,
+ output_scale,
+ output_zero_point,
+ rounding="TONEAREST",
+ out_dtype="int8"):
+ r"""Requantized operator.
+
+ The requantize operator converts one quantized tensor representation to
+ another quantized tensor representation. For the output tensor, we are
+ provided with output scale and zero point. The computation is as follows
+
+ Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)
+
+ Parameters
+ ----------
+ data : tvm.relay.Expr
+ The input data to the operator.
+
+ input_scale: float
+ The quantization scale for the input tensor.
+
+ input_zero_point: int
+ The zero point of the input tensor.
+
+ output_scale: float
+ The quantization scale for the output tensor.
+
+ output_zero_point: int
+ The zero point of the output tensor.
+
+    rounding : string, optional
+        Defines the rounding direction when the value is midway between two
+        representable values. Supported modes are "UPWARD" and "TONEAREST".
+
+ out_dtype : str, optional
+ Specifies the output data type.
+
+ Returns
+ -------
+ result : tvm.relay.Expr
+ The computed result.
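+
+    Examples
+    --------
+    A minimal usage sketch (shapes and scales here are illustrative):
+
+    .. code-block:: python
+
+        data = relay.var("data", shape=(32,), dtype="int32")
+        out = relay.qnn.op.requantize(data,
+                                      input_scale=1.0,
+                                      input_zero_point=0,
+                                      output_scale=16.0,
+                                      output_zero_point=0,
+                                      out_dtype="int8")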
+ """
+
+ return _make.requantize(data,
+ input_scale,
+ input_zero_point,
+ output_scale,
+ output_zero_point,
+ rounding,
+ out_dtype)
}
+static inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) {
+ static const Op& op = Op::Get("where");
+ return CallNode::make(op, {condition, x, y});
+}
+
+static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) {
+ static const Op& op = Op::Get("greater_equal");
+ return CallNode::make(op, {lhs, rhs}, Attrs(), {});
+}
+
+static inline Expr Full(Expr fill_value,
+ Array<IndexExpr> shape,
+ DataType dtype) {
+ auto attrs = make_node<InitOpAttrs>();
+ attrs->shape = std::move(shape);
+ attrs->dtype = std::move(dtype);
+ static const Op& op = Op::Get("full");
+ return CallNode::make(op, {fill_value}, Attrs(attrs), {});
+}
+
Expr MakeConcatenate(Expr data, int axis);
Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file requantize.cc
+ * \brief QNN requantize operator.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/qnn/attrs.h>
+#include "../../pass/pattern_util.h"
+#include "../util.h"
+
+namespace tvm {
+namespace relay {
+namespace qnn {
+
+TVM_REGISTER_NODE_TYPE(RequantizeAttrs);
+
+// Lowering of qnn.requantize op
+
+/*
+ * \brief Convert FP32 representation into fixed point representation.
+ * \param double_multiplier The input FP32 number.
+ * \return The pair of multiplier and shift for fixed point representation.
+ * \note Converts a floating point number so that it can be represented by
+ * integers. The representation is
+ * float_number = (significand) * 2^(exponent)
+ *
+ * The significand is a number between 0.5 and 1. This is represented by
+ * an integer number. For example, if it is int32, then the decimal point
+ * exists between bit 31 and 30 from LSB (or between first and second bit
+ * from the left).
+ *
+ * Some examples are
+ * 0.25 = (0.5) * 2^(-1)
+ * 0.125 = (0.5) * 2^(-2)
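+ * 0.0625 = (0.5) * 2^(-3), stored as the int32 significand
+ * round(0.5 * 2^31) = 2^30 with shift -3 (illustrative)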
+ *
+ * Credit to TFLite reference implementation.
+ */
+std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplier) {
+ int32_t significand, exponent;
+ if (double_multiplier == 0.) {
+ significand = 0;
+ exponent = 0;
+ return std::make_pair(significand, exponent);
+ }
+
+ // Get the significand and exponent.
+ double significand_d = std::frexp(double_multiplier, &exponent);
+
+ // Convert the double significand to int significand, i.e., convert into a
+ // integer where the decimal point is between bit 31 and 30. This is done by
+ // multiplying the double value with 2^31 and then casting to int.
+ significand_d = std::round(significand_d * (1ll << 31));
+ auto significand_int64 = static_cast<int64_t>(significand_d);
+ CHECK_LE(significand_int64, (1ll << 31));
+ if (significand_int64 == (1ll << 31)) {
+ significand_int64 /= 2;
+ ++exponent;
+ }
+ CHECK_LE(significand_int64, std::numeric_limits<int32_t>::max());
+ significand = static_cast<int32_t>(significand_int64);
+ return std::make_pair(significand, exponent);
+}
+
+/*
+ * \brief Lower requantize to a sequence of ops.
+ * \param input_tensor The input tensor to requantize op.
+ * \param param The requantize op attrs.
+ * \param input_shape The input tensor shape of the requantize op.
+ * \return The sequence of existing Relay ops.
+ * \note Requantization using only integer computation. Here, the computation is
+ * converted to a fixed point computation by computing an output multiplier
+ * and shift. This is useful if the target device does not support floating
+ * point, or if floating point computation is prohibitively expensive.
+ *
+ * Original computation is scale_fp32 * quantized_tensor. To convert into
+ * integer computation, the multiplication with fp32 scalar can be
+ * replaced by multiplication with an int value and then right shifting
+ * the result. This approximates the floating point computation with a
+ * fixed point computation.
+ *
+ * The whole computation can be broken down into the following steps
+ * 1) Calculate the integer multiplier and integer shift.
+ * 2) Subtract the input integer zero point.
+ * 3) Multiply the fixed point multiplier with quantized tensor.
+ * 4) Round the result.
+ * 5) Right shift the result.
+ * 6) Add the output zero point.
+ * 7) Cast to the out_dtype.
+ */
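+// Worked example (illustrative): with input_scale = 1 and output_scale = 16,
+// double_multiplier = 0.0625 = 0.5 * 2^(-3), so fixed_point_multiplier = 2^30
+// and shift = -3 (thus right_shift = 3 and total_right_shift = 34). An input
+// value of 8 becomes (8 * 2^30 + 2^33) >> 34 = 1, i.e., 8 * 0.0625 = 0.5
+// rounded to 1 (UPWARD and TONEAREST agree for positive midpoints).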
+Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
+ const Array<IndexExpr>& input_shape) {
+ double double_multiplier = param->input_scale / param->output_scale;
+
+ // Choose high precision datatype to be int64. This is for avoiding overflow
+ // in multiplication of two int32 values.
+ DataType hp_dtype = Int(64);
+
+ // 1) Calculating the integer multiplier and integer shift
+ int32_t fixed_point_multiplier, shift;
+ std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier);
+ int left_shift = shift > 0 ? shift : 0;
+ int right_shift = shift > 0 ? 0 : -shift;
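+  // A positive shift scales the value up and is folded into the fixed point
+  // multiply as a multiplication by 2^left_shift; a negative shift is folded
+  // into the final right shift instead.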
+
+ // 2) Subtract the input_zero_point
+ auto tensor = Cast(input_tensor, hp_dtype);
+ if (param->input_zero_point != 0) {
+ auto input_zp = MakeConstantScalar(hp_dtype, param->input_zero_point);
+ tensor = Subtract(tensor, input_zp);
+ }
+
+ // 3) Multiply the integer multiplier
+ if (left_shift != 0) {
+ tensor = Multiply(tensor, MakeConstantScalar(hp_dtype, 1 << left_shift));
+ }
+ // Perform the multiplication in higher precision.
+ // The scalar is a fixed point value of int32 where the decimal point is
+ // between bits 31 and 30. After multiplying with input_tensor, the result is
+ // in int64 where the decimal point is sitting between bits 31 and 30 (from
+ // the right, rightmost bit is bit 0). The computation is performed in higher
+ // precision to avoid overflow in multiplying two int32 values.
+ Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier);
+ auto multiplied_t = Multiply(tensor, scalar);
+
+ // 4) Find the rounding scalar. This depends on where the final decimal point
+ // sits. As we will be right shifting the multiplied_t, we need to first
+ // calculate the total_right_shift.
+ int total_right_shift = right_shift + 31;
+ int64_t pos_rounding_value = (1ll << (total_right_shift - 1));
+
+ tensor = multiplied_t;
+ Expr round_scalar;
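+  // UPWARD always adds 2^(total_right_shift - 1), so after the arithmetic
+  // right shift midpoints round towards +inf. TONEAREST uses one less for
+  // negative values, so their midpoints round away from zero instead.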
+ if (param->rounding == "UPWARD") {
+ round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value);
+ } else if (param->rounding == "TONEAREST") {
+ auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
+ auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
+ auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
+ auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);
+
+ auto zero = MakeConstantScalar(hp_dtype, 0);
+ auto zero_t = Full(zero, input_shape, hp_dtype);
+ round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
+ }
+ // Add the rounding scalar.
+ tensor = Add(tensor, round_scalar);
+
+ // 5) Simply right shift the result to get the final output.
+ auto scaled_int64_t = RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
+
+ // 6) Add the output zero point.
+ auto output_zp = MakeConstantScalar(hp_dtype, param->output_zero_point);
+ auto shifted_int64_t = Add(output_zp, scaled_int64_t);
+
+ // 7) Clip to the out_dtype min/max.
+ auto q_min = GetQmin(param->out_dtype);
+ auto q_max = GetQmax(param->out_dtype);
+ auto clipped_t = Clip(shifted_int64_t, q_min, q_max);
+ return Cast(clipped_t, param->out_dtype);
+}
+
+/*
+ * \brief Forward rewrite the requantize op.
+ * \param ref_call The original call that will be lowered.
+ * \param new_args The new mutated args to the call node.
+ * \param ctx The node context.
+ * \return The sequence of Relay ops for requantize op.
+ * \note Lowering of the requantize operation. The requantize operator converts
+ * one quantized tensor to another quantized tensor. For the output
+ * tensor, we are provided with output scale and zero point. The
+ * computation looks like this
+ *
+ * Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)
+ */
+Expr RequantizeLegalize(const Attrs& attrs, const Array<Expr>& new_args,
+ const Array<tvm::relay::Type>& arg_types) {
+ CHECK_EQ(new_args.size(), 1);
+ auto& quantized_data = new_args[0];
+ const auto* param = attrs.as<RequantizeAttrs>();
+ CHECK(param != nullptr);
+
+ // Find input shape.
+ CHECK_EQ(arg_types.size(), 1);
+ auto input_dtype = arg_types[0];
+ auto input_tensor_type = input_dtype.as<TensorTypeNode>();
+ CHECK(input_tensor_type != nullptr) << "Type information missing."
+ << " Please run infer_type pass.";
+ Array<IndexExpr> input_shape = input_tensor_type->shape;
+
+ // Check rounding validity.
+ CHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST")
+ << "QNN requantize supports two rounding modes - UPWARD and "
+ << "TONEAREST";
+ return RequantizeLower(quantized_data, param, input_shape);
+}
+
+/*
+ * \brief Type relation for the Requantize op, inferring output shape and dtype.
+ * \param types The types of input args.
+ * \param num_inputs The number of inputs.
+ * \param attrs The op attributes.
+ * \param reporter The type reporter that sets the dtype and shapes.
+ * \return True if the infer shape succeeded.
+ */
+bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+ const TypeReporter& reporter) {
+ CHECK_EQ(types.size(), 2);
+ const auto* data = types[0].as<TensorTypeNode>();
+ const auto in_dtype = data->dtype;
+ CHECK(in_dtype == Int(8) || in_dtype == UInt(8) || in_dtype == Int(32))
+ << "Input type should be an integer but was " << in_dtype;
+
+ const Array<tvm::Expr> oshape = data->shape;
+ // assign output type
+ const RequantizeAttrs* param = attrs.as<RequantizeAttrs>();
+ auto out_dtype = param->out_dtype;
+ CHECK(out_dtype == Int(8) || out_dtype == UInt(8) || out_dtype == Int(32))
+ << "Output type should be an integer but was " << out_dtype;
+ reporter->Assign(types[1], TensorTypeNode::make(oshape, out_dtype));
+ return true;
+}
+
+// Positional relay function to create qnn requantize operator
+// used by frontend FFI.
+Expr MakeRequantize(Expr data, double input_scale, int32_t input_zero_point, double output_scale,
+ int32_t output_zero_point, std::string rounding, DataType out_dtype) {
+ auto attrs = make_node<RequantizeAttrs>();
+ attrs->input_scale = std::move(input_scale);
+ attrs->input_zero_point = std::move(input_zero_point);
+ attrs->output_scale = std::move(output_scale);
+ attrs->output_zero_point = std::move(output_zero_point);
+ attrs->rounding = std::move(rounding);
+ attrs->out_dtype = std::move(out_dtype);
+ static const Op& op = Op::Get("qnn.requantize");
+ return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+RELAY_REGISTER_OP("qnn.requantize")
+.describe(R"code(Requantize operator.
+The requantize operator converts one quantized tensor to another quantized
+tensor. For the output tensor, we are provided with output scale and zero
+point. The computation looks like this
+
+Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)
+
+)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.RequantizeAttrs")
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The quantized input tensor.")
+.set_support_level(11)
+.add_type_rel("Requantize", RequantizeRel)
+.set_attr<FTVMLegalize>("FTVMLegalize", RequantizeLegalize);
+
+TVM_REGISTER_API("relay.qnn.op._make.requantize")
+.set_body_typed(MakeRequantize);
+
+} // namespace qnn
+} // namespace relay
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/qnn/util.h
+ * \brief Shared utility methods needed by quantized operators.
+ */
+
+#ifndef TVM_RELAY_QNN_UTIL_H_
+#define TVM_RELAY_QNN_UTIL_H_
+
+#include <tvm/expr.h>
+#include <tvm/relay/expr.h>
+#include <limits>
+
+namespace tvm {
+namespace relay {
+namespace qnn {
+
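+/*! \brief Returns the minimum representable value of the given integer dtype
+ * (for example, -128 for int8), used as the lower saturation bound. */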
+static inline int32_t GetQmin(const DataType& dtype) {
+ CHECK_LE(dtype.bits(), 32)
+ << "QNN ops support int32 or lower precision";
+ if (dtype.is_int()) {
+ auto* min_value = as_const_int(dtype.min());
+ CHECK(min_value != nullptr);
+ return static_cast<int32_t>(min_value[0]);
+ } else if (dtype.is_uint()) {
+ auto* min_value = as_const_uint(dtype.min());
+ CHECK(min_value != nullptr);
+ return static_cast<int32_t>(min_value[0]);
+ } else {
+ LOG(FATAL) << "Type not supported " << dtype;
+ return -1; // To hide the warning
+ }
+}
+
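+/*! \brief Returns the maximum representable value of the given integer dtype
+ * (for example, 127 for int8), used as the upper saturation bound. */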
+static inline int32_t GetQmax(const DataType& dtype) {
+ CHECK_LE(dtype.bits(), 32)
+ << "QNN ops support int32 or lower precision";
+ if (dtype.is_int()) {
+ auto* max_value = as_const_int(dtype.max());
+ CHECK(max_value != nullptr);
+ return static_cast<int32_t>(max_value[0]);
+ } else if (dtype.is_uint()) {
+ auto* max_value = as_const_uint(dtype.max());
+ CHECK(max_value != nullptr);
+ return static_cast<int32_t>(max_value[0]);
+ } else {
+ LOG(FATAL) << "Type not supported " << dtype;
+ return -1; // To hide the warning
+ }
+}
+
+} // namespace qnn
+} // namespace relay
+} // namespace tvm
+#endif // TVM_RELAY_QNN_UTIL_H_
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
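+"""Unit tests for the qnn.requantize operator."""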
+import tvm
+import numpy as np
+from tvm import relay
+from tvm.relay.testing import create_workload
+from tvm.contrib import graph_runtime
+
+roundings = ["UPWARD", "TONEAREST"]
+
+def run_infer_type(expr):
+ mod = relay.Module.from_expr(expr)
+ mod = relay.transform.InferType()(mod)
+ entry = mod["main"]
+ return entry if isinstance(expr, relay.Function) else entry.body
+
+
+def test_requantize():
+ def verify(mod, goldens):
+ with relay.build_config(opt_level=3):
+ graph, lib, params = relay.build(mod, "llvm", params=None)
+ golden_data, golden_output = goldens
+ rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
+ rt_mod.set_input("quantized_data",golden_data)
+ rt_mod.set_input(**params)
+ rt_mod.run()
+ res = rt_mod.get_output(0).asnumpy()
+ np.testing.assert_equal(res, golden_output)
+
+ def get_mod(data_shape, data_dtype, out_dtype, input_scale, output_scale,
+ input_zero_point=0, output_zero_point=0, rounding="TONEAREST"):
+ quantized_data = relay.var("quantized_data", shape=data_shape,
+ dtype=data_dtype)
+ mod = relay.qnn.op.requantize(
+ quantized_data,
+ input_scale=input_scale,
+ input_zero_point=input_zero_point,
+ output_scale=output_scale,
+ output_zero_point=output_zero_point,
+ rounding=rounding,
+ out_dtype=out_dtype)
+
+ mod = relay.Function(relay.analysis.free_vars(mod), mod)
+ mod = relay.Module.from_expr(mod)
+ mod = relay.transform.Legalize()(mod)
+ return mod
+
+ def same_scale_test():
+ # Have same scales, everything within range
+ golden_data = np.arange(-100, 100, 1).astype('int32')
+ golden_output = golden_data
+
+ for rounding in roundings:
+ mod = get_mod(data_shape=(200, ),
+ data_dtype='int32',
+ out_dtype="int8",
+ input_scale=0.5,
+ output_scale=0.5,
+ rounding=rounding)
+ verify(mod, (golden_data, golden_output))
+
+ def downscale_test():
+ for rounding in roundings:
+ mod = get_mod(data_shape=(32, ),
+ data_dtype='int32',
+ out_dtype='int8',
+ input_scale=1,
+ output_scale=16,
+ rounding=rounding)
+
+ # Try positive values
+ # 8 corresponds to 0.5, resulting in 1
+ golden_data = np.arange(0, 32, 1).astype('int32')
+ golden_output = np.repeat([0, 1, 2], [8, 16, 8])
+ verify(mod, (golden_data, golden_output))
+
+ # Try negative values
+ # -8 corresponds to -0.5. For UPWARD, this is 0
+ golden_data = np.arange(0, -32, -1).astype('int32')
+ if rounding == "UPWARD":
+ golden_output = np.repeat([0, -1, -2], [9, 16, 7])
+ else:
+ golden_output = np.repeat([0, -1, -2], [8, 16, 8])
+ verify(mod, (golden_data, golden_output))
+
+ # Try a different scale
+ mod = get_mod(data_shape=(32, ),
+ data_dtype='int32',
+ out_dtype="int8",
+ input_scale=1,
+ output_scale=4,
+ rounding=rounding)
+
+ # Try positive values
+            # 2 corresponds to 0.5, resulting in 1
+ golden_data = np.arange(0, 32, 1).astype('int32')
+ golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8],
+ [2, 4, 4, 4, 4, 4, 4, 4, 2])
+ verify(mod, (golden_data, golden_output))
+
+ # Try negative values
+ # -8 corresponds to -0.5. For UPWARD, this is 0
+ golden_data = np.arange(0, -32, -1).astype('int32')
+ if rounding == "UPWARD":
+ golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8],
+ [3, 4, 4, 4, 4, 4, 4, 4, 1])
+ else:
+ golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8],
+ [2, 4, 4, 4, 4, 4, 4, 4, 2])
+ verify(mod, (golden_data, golden_output))
+
+ # Try uint8 out_dtype
+ mod = get_mod(data_shape=(32, ),
+ data_dtype='int32',
+ out_dtype='uint8',
+ input_scale=1,
+ output_scale=16,
+ rounding=rounding)
+
+ # Try positive values
+ # 8 corresponds to 0.5, resulting in 1
+ golden_data = np.arange(0, 32, 1).astype('int32')
+ golden_output = np.repeat([0, 1, 2], [8, 16, 8])
+ verify(mod, (golden_data, golden_output))
+
+            # Try uint8 in_dtype and uint8 out_dtype
+ mod = get_mod(data_shape=(32, ),
+ data_dtype='uint8',
+ out_dtype='uint8',
+ input_scale=1,
+ output_scale=16,
+ rounding=rounding)
+
+ # Try positive values
+ # 8 corresponds to 0.5, resulting in 1
+ golden_data = np.arange(0, 32, 1).astype('int32')
+ golden_output = np.repeat([0, 1, 2], [8, 16, 8])
+ verify(mod, (golden_data, golden_output))
+
+ def upscale_test():
+ for rounding in roundings:
+ mod = get_mod(data_shape=(32, ),
+ data_dtype='int32',
+ out_dtype="int8",
+ input_scale=2,
+ output_scale=1,
+ rounding=rounding)
+
+ # Try positive values
+ # 8 corresponds to 0.5, resulting in 1
+ golden_data = np.arange(0, 32, 1).astype('int32')
+ golden_output = np.multiply(2, golden_data)
+ verify(mod, (golden_data, golden_output))
+
+ # Try negative values
+ # -8 corresponds to -0.5. For UPWARD, this is 0
+ golden_data = np.arange(0, -32, -1).astype('int32')
+ golden_output = np.multiply(2, golden_data)
+ verify(mod, (golden_data, golden_output))
+
+ def saturation_test():
+ for rounding in roundings:
+ mod = get_mod(data_shape=(16, ),
+ data_dtype='int32',
+ out_dtype="int8",
+ input_scale=0.5,
+ output_scale=0.5,
+ rounding=rounding)
+ golden_data = np.arange(0, 16, 1).astype('int32')
+ golden_data = np.add(120, golden_data)
+ output = np.array([120, 121, 122, 123, 124, 125, 126, 127,
+ 127, 127, 127, 127, 127, 127, 127, 127])
+ golden_output = output
+ verify(mod, (golden_data, golden_output))
+
+ # Try negative numbers
+ golden_data = np.arange(0, -16, -1).astype('int32')
+ golden_data = np.add(-120, golden_data)
+ output = np.array([-120, -121, -122, -123, -124, -125, -126, -127,
+ -128, -128, -128, -128, -128, -128, -128, -128])
+ golden_output = output
+ verify(mod, (golden_data, golden_output))
+
+ def zero_point_test():
+ # Output zero point
+ for rounding in roundings:
+ mod = get_mod(data_shape=(32, ),
+ data_dtype='int32',
+ out_dtype='int8',
+ input_scale=1,
+ output_scale=16,
+ output_zero_point=1,
+ rounding=rounding)
+
+ # Try positive values
+ # 8 corresponds to 0.5, resulting in 1
+ golden_data = np.arange(0, 32, 1).astype('int32')
+ golden_output = np.repeat([0, 1, 2], [8, 16, 8])
+ golden_output = np.add(1, golden_output)
+ verify(mod, (golden_data, golden_output))
+
+ # Try negative values
+ # -8 corresponds to -0.5. For UPWARD, this is 0
+ golden_data = np.arange(-32, -64, -1).astype('int32')
+ if rounding == "UPWARD":
+ golden_output = np.repeat([-2, -3, -4], [9, 16, 7])
+ else:
+ golden_output = np.repeat([-2, -3, -4], [8, 16, 8])
+ golden_output = np.add(1, golden_output)
+ verify(mod, (golden_data, golden_output))
+
+ # Input zero point
+ for rounding in roundings:
+ mod = get_mod(data_shape=(32, ),
+ data_dtype='int32',
+ out_dtype='int8',
+ input_scale=1,
+ output_scale=16,
+ input_zero_point=16,
+ rounding=rounding)
+
+ # Try positive values
+ golden_data = np.arange(32, 64, 1).astype('int32')
+ golden_output = np.repeat([2, 3, 4], [8, 16, 8])
+ golden_output = np.subtract(golden_output, 1)
+ verify(mod, (golden_data, golden_output))
+
+ # Try negative values
+ golden_data = np.arange(-32, -64, -1).astype('int32')
+ if rounding == "UPWARD":
+ golden_output = np.repeat([-2, -3, -4], [9, 16, 7])
+ else:
+ golden_output = np.repeat([-2, -3, -4], [8, 16, 8])
+ golden_output = np.subtract(golden_output, 1)
+ verify(mod, (golden_data, golden_output))
+
+ same_scale_test()
+ downscale_test()
+ upscale_test()
+ saturation_test()
+ zero_point_test()
+
+if __name__ == "__main__":
+ test_requantize()