- Adding support for Mxnet flavored dequantization for both default and using MKLDNN...
authorshoubhik <shoubhikbhatti@gmail.com>
Thu, 10 Oct 2019 19:52:49 +0000 (12:52 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Thu, 10 Oct 2019 19:52:49 +0000 (12:52 -0700)
- Added tests for new methods added.

python/tvm/relay/frontend/__init__.py
python/tvm/relay/frontend/mxnet_qnn_op_utils.py [new file with mode: 0644]
tests/python/frontend/mxnet/test_qnn_ops_utils.py [new file with mode: 0644]

index 76761fd..e623341 100644 (file)
@@ -24,6 +24,7 @@ for Relay.
 from __future__ import absolute_import
 
 from .mxnet import from_mxnet
+from .mxnet_qnn_op_utils import dequantize_mxnet_min_max
 from .keras import from_keras
 from .onnx import from_onnx
 from .tflite import from_tflite
diff --git a/python/tvm/relay/frontend/mxnet_qnn_op_utils.py b/python/tvm/relay/frontend/mxnet_qnn_op_utils.py
new file mode 100644 (file)
index 0000000..e2aaa79
--- /dev/null
@@ -0,0 +1,246 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, import-self, len-as-condition, no-else-return
+"""MXNet qnn dialect helper methods for MXNet specific implementations of more
+   generic qnn supported ops.
+"""
+
+import numpy as np
+from tvm.relay.qnn.op.qnn import dequantize
+
+zero_centered_uint8_quantized_range = np.float32(255)
+zero_centered_int8_quantized_range = np.float32(127)
+
+
+def _dequantize_zero_centered(data,
+                              data_min,
+                              data_max,
+                              quantized_range):
+    r"""Dequantizes the given data tensor by calculating the scale
+    using the MKLDNN formula `max(abs(data_min, data_max))/quantized_range`.
+    Where quantized_range is 255 for uint8 and 127 for int8. The `data_min`
+    and `data_max` are the min and max to use for the `data` tensor elements.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input tensor to be quantized. Can be of type {int8 or uint8}.
+    data_min : float
+        The minimum to use data elements.
+    data_max : float
+        The maximum to use for data elements.
+    quantized_range : float
+        255 for uint8 and 127 for int8. This is the data type range.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+
+    real_range = np.max([np.abs(np.float32(data_min)),
+                         np.abs(np.float32(data_max))])
+    scale = np.divide(real_range, quantized_range)
+    zero_point = 0
+    return dequantize(data, scale, zero_point)
+
+
+def _dequantize_mkldnn_min_max_int8(data,
+                                    imin_range,
+                                    imax_range):
+    r"""Dequantizes the given `data` in {int8 or uint8} and the given
+    min and max ranges and the output data type is `float32`.
+    The method of dequantizing is described here - https://tinyurl.com/y5k6fz5w.
+    We use our default quantize implementation from src/relay/qnn/op/dequantize.cc:67
+    but compute the `scale` and `zero_point` to fit our equation.
+    Unlike in TFLite where we get the scale and zero_point from the model, MKLDNN
+    stores the min and max from which we calculate the scale and zero_point.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input tensor to be quantized. Can be of type float32.
+    imin_range : float
+        The minimum to use data elements.
+    imax_range : float
+        The maximum to use for data elements.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+
+    return _dequantize_zero_centered(data,
+                                     data_min=imin_range,
+                                     data_max=imax_range,
+                                     quantized_range=zero_centered_int8_quantized_range)
+
+
+def _dequantize_mkldnn_min_max_uint8(data,
+                                     imin_range,
+                                     imax_range):
+    r"""Dequantizes the given `data` in {int8 or uint8} and the given
+    min and max ranges and the output data type is `float32`.
+    The method of dequantize is described here - https://tinyurl.com/y5k6fz5w.
+    We use our default quantize implementation from src/relay/qnn/op/dequantize.cc:67
+    but compute the `scale` and `zero_point` to fit our equation.
+    Unlike in TFLite where we get the scale and zero_point from the model, MKLDNN
+    stores the min and max from which we calculate the scale and zero_point.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input tensor to be quantized. Can be of type float32.
+    imin_range : float
+        The minimum to use data elements.
+    imax_range : float
+        The maximum to use for data elements.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+
+    return _dequantize_zero_centered(data,
+                                     data_min=imin_range,
+                                     data_max=imax_range,
+                                     quantized_range=zero_centered_uint8_quantized_range)
+
+
+def _dequantize_mxnet_min_max_int8(data,
+                                   imin_range,
+                                   imax_range):
+    r"""Deuantizes the given `data` in {int8 or uint8} and the given
+    min and max ranges and the output data type is `float32`.
+    The method of dequantization is described here - https://tinyurl.com/y4d7hrzf.
+    We use our default dequantize implementation from src/relay/qnn/op/dequantize.cc:67
+    but compute the `scale` and `zero_point` to fit our equation.
+    Unlike in TFLite where we get the scale and zero_point from the model, Mxnet
+    stores the min and max from which we calculate the scale and zero_point.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input tensor to be quantized. Can be of type float32.
+    imin_range : float
+        The minimum to use data elements.
+    imax_range : float
+        The maximum to use for data elements.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+
+    return _dequantize_zero_centered(data,
+                                     data_min=imin_range,
+                                     data_max=imax_range,
+                                     quantized_range=zero_centered_int8_quantized_range)
+
+
+def _dequantize_mxnet_min_max_uint8(data,
+                                    imin_range,
+                                    imax_range):
+    r"""Dequantizes the given `data` in {int8 or uint8} and the given
+    min and max ranges and the output data type is `float32`.
+    The method of dequantizing is described here - https://tinyurl.com/y4d7hrzf.
+    We use our default quantize implementation from src/relay/qnn/op/dequantize.cc:67
+    but compute the `scale` and `zero_point` to fit our equation.
+    Unlike in TFLite where we get the scale and zero_point from the model, Mxnet
+    stores the min and max from which we calculate the scale and zero_point.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input tensor to be quantized. Can be of type float32.
+    imin_range : float
+        The minimum to use data elements.
+    imax_range : float
+        The maximum to use for data elements.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+
+    iinfo = np.iinfo(np.uint8)
+    min_limit = np.float64(iinfo.min)
+    max_limit = np.float64(iinfo.max)
+    imin_range = np.float64(imin_range)
+    imax_range = np.float64(imax_range)
+    scale = np.divide((imax_range - imin_range),
+                      (max_limit - min_limit))
+    zero_point = np.int(-1 * np.divide(imin_range, scale))
+    return dequantize(data, scale, zero_point)
+
+
+def dequantize_mxnet_min_max(data,
+                             min_range,
+                             max_range,
+                             in_dtype='int8',
+                             use_mkldnn=False):
+    r"""Dequantizes the given `data` in {int8 or uint8} and the given
+    min and max ranges. The output data type is float32.
+    Only `float32` is supported as output data types.
+    The input data type is expected to be {int8 or uint8}.
+    Mxnet has two different flavors for dequantization 1) Default 2)MKLDNN.
+    To get the second one Mxnet must be built with MKLDNN during compile time.
+    Users can choose either of the implementation for TVM runtime.
+    The main difference between the two implementation is that MKLDNN is centered
+    around 0 and the default implementation for uint8 is not.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input tensor to be quantized. Can be of type float32.
+    min_range : float
+        The minimum to use data elements for the output.
+    max_range : float
+        The maximum to use for data elements for the output.
+    in_dtype: str, optional
+        The input data type, can be 'int8' or 'uint8'
+    use_mkldnn: bool, optional
+        If True then uses MKLDNN quantization implementation otherwise
+        will use default implementation.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+
+    if in_dtype == 'uint8':
+        if use_mkldnn:
+            return _dequantize_mkldnn_min_max_uint8(data,
+                                                    min_range,
+                                                    max_range)
+        else:
+            return _dequantize_mxnet_min_max_uint8(data,
+                                                   min_range,
+                                                   max_range)
+    elif in_dtype == 'int8':
+        if use_mkldnn:
+            return _dequantize_mkldnn_min_max_int8(data, min_range, max_range)
+        else:
+            return _dequantize_mxnet_min_max_int8(data, min_range, max_range)
+    else:
+        raise ValueError(
+            "Expected out_dtype to be int8 or uint8 but was  %s" % in_dtype)
diff --git a/tests/python/frontend/mxnet/test_qnn_ops_utils.py b/tests/python/frontend/mxnet/test_qnn_ops_utils.py
new file mode 100644 (file)
index 0000000..78c9692
--- /dev/null
@@ -0,0 +1,142 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import numpy as np
+from tvm import relay
+from tvm.contrib import graph_runtime
+
+
+def test_mxnet_dequantize_op():
+
+    def quantize_test_driver(in_dtype, quant_args, in_data, verify_output_data):
+        shape = in_data.shape
+        input_data = relay.var("input_data", shape=shape, dtype=in_dtype)
+        min_range = quant_args['min_range']
+        max_range = quant_args['max_range']
+        quantized_output = \
+            relay.frontend.dequantize_mxnet_min_max(input_data,
+                                                    min_range=min_range,
+                                                    max_range=max_range,
+                                                    in_dtype=in_dtype)
+        mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
+        mod = relay.Module.from_expr(mod)
+        mod = relay.qnn.transform.CanonicalizeOps()(mod)
+        with relay.build_config(opt_level=3):
+            graph, lib, params = relay.build(mod, "llvm", params=None)
+            rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
+            rt_mod.set_input(input_data=in_data)
+            rt_mod.set_input(**params)
+            rt_mod.run()
+            res = rt_mod.get_output(0).asnumpy()
+            assert np.allclose(res, verify_output_data, )
+            assert res.dtype == np.float32
+
+    def test_uint8_to_float32():
+        data = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \
+            .astype('uint8') \
+            .reshape((2, 5))
+        output = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
+            .astype('float32') \
+            .reshape((2, 5))
+        quant_args = {"min_range": -63.5, "max_range": 64}
+        quantize_test_driver(in_dtype='uint8',
+                             quant_args=quant_args,
+                             in_data=data,
+                             verify_output_data=output)
+
+    def test_int8_to_float32():
+        data = np.array([-126, -125, -124, -123, -122, 123, 124, 125, 126, 127]) \
+            .astype('int8') \
+            .reshape((2, 5))
+        output = np.array([-63.496063, -62.992126, -62.48819, -61.984253, -61.480316,
+                           61.984253, 62.48819, 62.992126, 63.496063, 64.]) \
+            .astype('float32') \
+            .reshape((2, 5))
+        quant_args = {"min_range": -63.5, "max_range": 64}
+        quantize_test_driver(in_dtype='int8',
+                             quant_args=quant_args,
+                             in_data=data,
+                             verify_output_data=output)
+
+    test_uint8_to_float32()
+    test_int8_to_float32()
+
+
+def test_mkldnn_dequantize_op():
+
+    def quantize_test_driver(in_dtype, quant_args, in_data, verify_output_data):
+        shape = in_data.shape
+        input_data = relay.var("input_data", shape=shape, dtype=in_dtype)
+        min_range = quant_args['min_range']
+        max_range = quant_args['max_range']
+        quantized_output = \
+            relay.frontend.dequantize_mxnet_min_max(input_data,
+                                                    min_range=min_range,
+                                                    max_range=max_range,
+                                                    in_dtype=in_dtype,
+                                                    use_mkldnn=True)
+        mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
+        mod = relay.Module.from_expr(mod)
+        mod = relay.qnn.transform.CanonicalizeOps()(mod)
+        with relay.build_config(opt_level=3):
+            graph, lib, params = relay.build(mod, "llvm", params=None)
+            rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
+            rt_mod.set_input(input_data=in_data)
+            rt_mod.set_input(**params)
+            rt_mod.run()
+            res = rt_mod.get_output(0).asnumpy()
+            # print(res)
+            # np.testing.assert_equal(res, verify_output_data)
+            assert np.allclose(res, verify_output_data, )
+            assert res.dtype == np.float32
+
+    def test_uint8_to_float32():
+        data = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \
+            .astype('uint8') \
+            .reshape((2, 5))
+        output = np.array([0., 0.2509804, 0.5019608, 0.75294125, 1.0039216,
+                           62.996082, 63.247063, 63.498043, 63.749023, 64.]) \
+            .astype('float32') \
+            .reshape((2, 5))
+        quant_args = {"min_range": -63.5, "max_range": 64}
+        quantize_test_driver(in_dtype='uint8',
+                             quant_args=quant_args,
+                             in_data=data,
+                             verify_output_data=output)
+
+    def test_int8_to_float32():
+        data = np.array([-126, -125, -124, -123, -122, 123, 124, 125, 126, 127]) \
+            .astype('int8') \
+            .reshape((2, 5))
+        output = np.array([-63.496063, -62.992126, -62.48819, -61.984253, -61.480316,
+                           61.984253, 62.48819, 62.992126, 63.496063, 64.]) \
+            .astype('float32') \
+            .reshape((2, 5))
+        quant_args = {"min_range": -63.5, "max_range": 64}
+        quantize_test_driver(in_dtype='int8',
+                             quant_args=quant_args,
+                             in_data=data,
+                             verify_output_data=output)
+
+    test_uint8_to_float32()
+    test_int8_to_float32()
+
+
+if __name__ == "__main__":
+    test_mxnet_dequantize_op()
+    test_mkldnn_dequantize_op()