* QNN quantize and dequantize operators.
* Addressing review comments.
* Addressing review comments.
* Adding new line at the end of the file.
* Adhering to styling guidelines.
* Adding name to contributors.
* Fixing lint issue.
* Fixing file name.
* Removing unnecessary code.
- [Haolong Zhang](https://github.com/haolongzhangm)
- [Cody Hao Yu](https://github.com/comaniac)
- [Chris Nuernberger](https://github.com/cnuernber)
+- [Shoubhik Bhattacharya](https://github.com/shoubhik)
}
};
+/*! \brief Attribute for quantize operator */
+struct QuantizeAttrs : public tvm::AttrsNode<QuantizeAttrs> {
+ int32_t output_zero_point;
+ double output_scale;
+ DataType out_dtype;
+
+ TVM_DECLARE_ATTRS(QuantizeAttrs, "relay.attrs.QuantizeAttrs") {
+ TVM_ATTR_FIELD(out_dtype)
+ .describe("Output data type, can be one of [int8 or uint8].");
+
+ TVM_ATTR_FIELD(output_zero_point)
+ .describe("The zero_point for the activation of this op.");
+
+ TVM_ATTR_FIELD(output_scale)
+ .describe("The scale for the activation of this op.");
+ }
+};
+
+/*! \brief Attribute for dequantize operator */
+struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
+ int32_t input_zero_point;
+ double input_scale;
+
+ TVM_DECLARE_ATTRS(DequantizeAttrs, "relay.attrs.DequantizeAttrs") {
+ TVM_ATTR_FIELD(input_zero_point)
+ .describe("The zero_point for the input tensor of this op.");
+
+ TVM_ATTR_FIELD(input_scale)
+ .describe("The scale for the input tensor of this op.");
+ }
+};
+
} // namespace qnn
} // namespace relay
} // namespace tvm
rounding,
out_dtype)
+
+def quantize(data,
+ output_scale,
+ output_zero_point,
+ out_dtype='int8'):
+ r""" Quantize op
+ This operator takes float32 as input and produces quantized int8 or uint8 as output.
+ The input tensor can be of any shape. The output shape is the same as input shape.
+
+ Q_output = clamp((round(input_tensor/output_scale) + output_zero_point),
+ out_dtype::min,
+ out_dtype::max)
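+
+ For example, with output_scale=0.5, output_zero_point=127 and out_dtype='uint8',
+ an input value of -63.5 becomes clamp(round(-63.5/0.5) + 127, 0, 255) = 0.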
+
+ Parameters
+ ----------
+ data : tvm.relay.Expr
+ The input tensor to be quantized. Can be of type float32.
+ output_scale : float
+ The output scale.
+ output_zero_point : int
+ The output zero_point.
+ out_dtype : str, optional
+ The data type of the output tensor. Can be [int8, uint8].
+ Returns
+ -------
+ result : tvm.relay.Expr
+ The computed result.
+ """
+
+ return _make.quantize(data,
+ output_scale,
+ output_zero_point,
+ out_dtype)
+
+
+def dequantize(data,
+ input_scale,
+ input_zero_point):
+ r""" Dequantize op
+ This operator takes quantized int8 or uint8 as input and produces
+ dequantized float32 as output. The output shape is the same as input shape. The input
+ tensor can be of any shape.
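+
+ fp_output = input_scale * (input_tensor - input_zero_point)
+
+ For example, with input_scale=0.5 and input_zero_point=127, a quantized input of 0
+ maps to 0.5 * (0 - 127) = -63.5.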
+
+ Parameters
+ ----------
+ data : tvm.relay.Expr
+ The input tensor to be dequantized. Can be of type [int8, uint8].
+ input_scale : float
+ The input scale.
+ input_zero_point : int
+ The input zero_point.
+ Returns
+ -------
+ result : tvm.relay.Expr
+ The computed result.
+ """
+
+ return _make.dequantize(data,
+ input_scale,
+ input_zero_point)
def concatenate(data,
input_scales,
input_zero_points,
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file src/relay/qnn/op/dequantize.cc
+ * \brief QNN dequantize operator. Dequantize operator converts from quantized
+ * domain to unquantized domain.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/qnn/attrs.h>
+#include "../../pass/pattern_util.h"
+#include "../util.h"
+
+namespace tvm {
+namespace relay {
+namespace qnn {
+
+TVM_REGISTER_NODE_TYPE(DequantizeAttrs);
+
+bool DequantizeRel(const Array<Type>& types,
+ int num_inputs,
+ const Attrs& attrs,
+ const TypeReporter& reporter) {
+ CHECK_EQ(types.size(), 2);
+ const auto* data = types[0].as<TensorTypeNode>();
+ const auto input_dtype = data->dtype;
+ CHECK(input_dtype == Int(8) || input_dtype == UInt(8))
+ << "Input type should be one of the quantized types [unit8, int8] but was " << input_dtype;
+ const Array<tvm::Expr> oshape = data->shape;
+ // assign output type, the output will always be float32.
+ reporter->Assign(types[1], TensorTypeNode::make(oshape, Float(32)));
+ return true;
+}
+
+Expr MakeDequantize(Expr data,
+ double input_scale,
+ int32_t input_zero_point) {
+ auto attrs = make_node<DequantizeAttrs>();
+ attrs->input_scale = input_scale;
+ attrs->input_zero_point = input_zero_point;
+ // real_value = scale * (quantized_value - zero_point)
+ // A more detailed explanation can be found here - https://github.com/google/gemmlowp/blob/master/doc/quantization.md
+ static const Op& op = Op::Get("qnn.dequantize");
+ return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
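+/*
+ * \brief Lowers qnn.dequantize into existing Relay ops (cast, subtract, multiply).
+ * Called from DequantizeLegalize when the Legalize pass runs.
+ */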
+Expr DequantizeLower(const Expr& input_tensor,
+ const DequantizeAttrs* attrs) {
+ const auto input_zero_point = MakeConstantScalar(Int(32), attrs->input_zero_point);
+ const auto input_scale = MakeConstantScalar(Float(32), attrs->input_scale);
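+ // real_value = input_scale * (quantized_value - input_zero_point)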
+ auto shift = Subtract(Cast(input_tensor, Int(32)), input_zero_point);
+ auto scaled_output = Multiply(Cast(shift, Float(32)), input_scale);
+ return scaled_output;
+}
+
+Expr DequantizeLegalize(const Attrs& attrs,
+ const Array<Expr>& new_args,
+ const Array<tvm::relay::Type>& arg_types) {
+ CHECK_EQ(new_args.size(), 1);
+ auto& data = new_args[0];
+ const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();
+ CHECK(dequantize_attrs != nullptr);
+ CHECK_EQ(arg_types.size(), 1);
+ return DequantizeLower(data, dequantize_attrs);
+}
+
+RELAY_REGISTER_OP("qnn.dequantize")
+.describe(R"code(Dequantizes the input and produces float32 output.
+The input is always quantized (int8, uint8) and will be converted to float32 given input scale and zero_point.
+- **data**: Quantized tensor of any shape to dequantize. The input data must be of a quantized type (int8, uint8).
+)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.DequantizeAttrs")
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The tensor to dequantize.")
+.set_support_level(11)
+.add_type_rel("Dequantize", DequantizeRel)
+.set_attr<FTVMLegalize>("FTVMLegalize", DequantizeLegalize);
+
+TVM_REGISTER_API("relay.qnn.op._make.dequantize")
+.set_body_typed(MakeDequantize);
+
+} // namespace qnn
+} // namespace relay
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file src/relay/qnn/op/quantize.cc
+ * \brief QNN quantize operator. Quantize operator converts from unquantized
+ * domain to quantized domain.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/qnn/attrs.h>
+#include "../../pass/pattern_util.h"
+#include "../util.h"
+
+namespace tvm {
+namespace relay {
+namespace qnn {
+
+TVM_REGISTER_NODE_TYPE(QuantizeAttrs);
+
+bool QuantizeRel(const Array<Type>& types,
+ int num_inputs,
+ const Attrs& attrs,
+ const TypeReporter& reporter) {
+ CHECK_EQ(types.size(), 2);
+ const auto* data = types[0].as<TensorTypeNode>();
+ const auto input_dtype = data->dtype;
+ CHECK(input_dtype == Float(32))
+ << "Input type should be one of float32 but was " << input_dtype;
+ const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
+ const Array<tvm::Expr> oshape = data->shape;
+ const DataType out_dtype = quantize_attrs->out_dtype;
+ CHECK(out_dtype == Int(8) || out_dtype == UInt(8))
+ << "Output type should be one of [int8, unit8 ] but was " << out_dtype;
+ // assign output type
+ reporter->Assign(types[1], TensorTypeNode::make(oshape, out_dtype));
+ return true;
+}
+
+Expr MakeQuantize(Expr data,
+ double output_scale,
+ int32_t output_zero_point,
+ DataType out_dtype) {
+ auto attrs = make_node<QuantizeAttrs>();
+ attrs->output_scale = output_scale;
+ attrs->output_zero_point = output_zero_point;
+ attrs->out_dtype = std::move(out_dtype);
+ // result_quantized_value = result_zero_point + result_real_value / result_scale.
+ // A more detailed explanation can be found here - https://github.com/google/gemmlowp/blob/master/doc/quantization.md
+ static const Op& op = Op::Get("qnn.quantize");
+ return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
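+/*
+ * \brief Lowers qnn.quantize into existing Relay ops (divide, round, cast, add, clip).
+ * Called from QuantizeLegalize when the Legalize pass runs.
+ */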
+Expr QuantizeLower(const Expr& input_tensor,
+ const QuantizeAttrs* attrs) {
+ const auto out_dtype = attrs->out_dtype;
+ const auto output_zero_point = MakeConstantScalar(Int(32), attrs->output_zero_point);
+ const auto scale = MakeConstantScalar(Float(32), attrs->output_scale);
+ const int32_t min_val = GetQmin(out_dtype);
+ const int32_t max_val = GetQmax(out_dtype);
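+ // quantized_value = clamp(round(real_value / output_scale) + output_zero_point,
+ // qmin, qmax), then cast to out_dtype.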
+ auto scale_data = Cast(Round(Divide(input_tensor, scale)), Int(32));
+ auto add_zero_point = Add(scale_data, output_zero_point);
+ auto clamped_output = Clip(add_zero_point, min_val, max_val);
+ auto clamp_out_dtype = Cast(clamped_output, out_dtype);
+ return clamp_out_dtype;
+}
+
+Expr QuantizeLegalize(const Attrs& attrs,
+ const Array<Expr>& new_args,
+ const Array<tvm::relay::Type>& arg_types) {
+ CHECK_EQ(new_args.size(), 1);
+ auto& data = new_args[0];
+ const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
+ CHECK(quantize_attrs != nullptr);
+
+ CHECK_EQ(arg_types.size(), 1);
+ return QuantizeLower(data, quantize_attrs);
+}
+
+RELAY_REGISTER_OP("qnn.quantize")
+.describe(R"code(Quantizes the input and produces quantized output.
+The input is a float32 tensor and is quantized to int8 or uint8 output using
+the given output scale and zero point. Converting from one quantized type to
+another is handled by the requantize operator.
+- **data**: Tensor of any shape to quantize. The input data is of floating point type.
+)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.QuantizeAttrs")
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The tensor to quantize.")
+.set_support_level(11)
+.add_type_rel("Quantize", QuantizeRel)
+.set_attr<FTVMLegalize>("FTVMLegalize", QuantizeLegalize);
+
+TVM_REGISTER_API("relay.qnn.op._make.quantize")
+.set_body_typed(MakeQuantize);
+
+} // namespace qnn
+} // namespace relay
+} // namespace tvm
/*!
* Copyright (c) 2019 by Contributors
- * \file requantize.cc
+ * \file src/relay/qnn/op/requantize.cc
* \brief QNN requantize operator.
*/
const auto* data = types[0].as<TensorTypeNode>();
const auto in_dtype = data->dtype;
CHECK(in_dtype == Int(8) || in_dtype == UInt(8) || in_dtype == Int(32))
- << "Input type should be an integer but was " << in_dtype;
+ << "Input type should be one of [int8, uint8, int32] but was " << in_dtype;
const Array<tvm::Expr> oshape = data->shape;
// assign output type
const RequantizeAttrs* param = attrs.as<RequantizeAttrs>();
auto out_dtype = param->out_dtype;
CHECK(out_dtype == Int(8) || out_dtype == UInt(8) || out_dtype == Int(32))
- << "Output type should be an integer but was " << out_dtype;
+ << "Output type should be one of [int8, uint8, int32] but was " << out_dtype;
reporter->Assign(types[1], TensorTypeNode::make(oshape, out_dtype));
return true;
}
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import numpy as np
+from tvm import relay
+from tvm.contrib import graph_runtime
+
+def test_dequantize_op():
+
+ def dequantize_test_driver(in_dtype, quant_args, in_data, verify_output_data):
+ shape = in_data.shape
+ input_data = relay.var("input_data", shape=shape, dtype=in_dtype)
+ input_zero_point = quant_args['in_zero_point']
+ input_scale = quant_args['in_scale']
+ dequantized_output = relay.qnn.op.dequantize(input_data, input_scale=input_scale,
+ input_zero_point=input_zero_point)
+ mod = relay.Function(relay.analysis.free_vars(dequantized_output), dequantized_output)
+ mod = relay.Module.from_expr(mod)
+ mod = relay.transform.Legalize()(mod)
+ with relay.build_config(opt_level=3):
+ graph, lib, params = relay.build(mod, "llvm", params=None)
+ rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
+ rt_mod.set_input(input_data=in_data)
+ rt_mod.set_input(**params)
+ rt_mod.run()
+ res = rt_mod.get_output(0).asnumpy()
+ np.testing.assert_equal(res, verify_output_data)
+ assert res.dtype == np.float32
+
+ def test_uint8_to_float32():
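+ # With in_scale=0.5 and in_zero_point=127, quantized 0 dequantizes to
+ # 0.5 * (0 - 127) = -63.5 and 255 dequantizes to 0.5 * (255 - 127) = 64.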
+ data = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \
+ .astype('uint8') \
+ .reshape((2,5))
+ output = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
+ .astype('float32') \
+ .reshape((2,5))
+ quant_args = {"in_zero_point":127, "in_scale":0.5}
+ dequantize_test_driver(in_dtype='uint8', quant_args=quant_args, in_data=data,
+ verify_output_data=output)
+
+ def test_int8_to_float32():
+ data = np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127]) \
+ .astype('int8') \
+ .reshape((2,5))
+ output = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
+ .astype('float32') \
+ .reshape((2,5))
+ quant_args = {"in_zero_point":-1, "in_scale":0.5}
+ dequantize_test_driver(in_dtype='int8', quant_args=quant_args, in_data=data,
+ verify_output_data=output)
+
+ test_uint8_to_float32()
+ test_int8_to_float32()
+
+
+if __name__ == "__main__":
+ test_dequantize_op()
+
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import numpy as np
+from tvm import relay
+from tvm.contrib import graph_runtime
+
+def test_quantize_op():
+
+ def quantize_test_driver(in_dtype, quant_args, out_dtype, in_data, verify_output_data):
+ shape = in_data.shape
+ input_data = relay.var("input_data", shape=shape, dtype=in_dtype)
+ output_zero_point = quant_args['out_zero_point']
+ output_scale = quant_args['out_scale']
+ quantized_output = relay.qnn.op.quantize(input_data, output_scale=output_scale,
+ output_zero_point=output_zero_point, out_dtype=out_dtype)
+ mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
+ mod = relay.Module.from_expr(mod)
+ mod = relay.transform.Legalize()(mod)
+ with relay.build_config(opt_level=3):
+ graph, lib, params = relay.build(mod, "llvm", params=None)
+ rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
+ rt_mod.set_input(input_data=in_data)
+ rt_mod.set_input(**params)
+ rt_mod.run()
+ res = rt_mod.get_output(0).asnumpy()
+ np.testing.assert_equal(res, verify_output_data)
+ assert res.dtype == out_dtype
+
+ def test_float32_to_uint8():
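+ # With out_scale=0.5 and out_zero_point=127, -63.5 quantizes to
+ # round(-63.5/0.5) + 127 = 0 and 64 quantizes to clamp(round(64/0.5) + 127, 0, 255) = 255.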
+ data = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
+ .astype('float32') \
+ .reshape((2,5))
+ output = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \
+ .astype('uint8') \
+ .reshape((2,5))
+ quant_args = {"out_zero_point":127, "out_scale":0.5}
+ quantize_test_driver(in_dtype='float32', quant_args=quant_args, out_dtype='uint8', in_data=data,
+ verify_output_data=output)
+
+ def test_float32_to_int8():
+ data = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
+ .astype('float32') \
+ .reshape((2,5))
+ output = np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127]) \
+ .astype('int8') \
+ .reshape((2,5))
+ quant_args = {"out_zero_point":-1, "out_scale":0.5}
+ quantize_test_driver(in_dtype='float32', quant_args=quant_args, out_dtype='int8', in_data=data,
+ verify_output_data=output)
+
+ test_float32_to_uint8()
+ test_float32_to_int8()
+
+if __name__ == "__main__":
+ test_quantize_op()
import tvm
import numpy as np
from tvm import relay
-from tvm.relay.testing import create_workload
from tvm.contrib import graph_runtime
roundings = ["UPWARD", "TONEAREST"]
-def run_infer_type(expr):
- mod = relay.Module.from_expr(expr)
- mod = relay.transform.InferType()(mod)
- entry = mod["main"]
- return entry if isinstance(expr, relay.Function) else entry.body
-
-
def test_requantize():
def verify(mod, goldens):
with relay.build_config(opt_level=3):