[TFLITE]Quantize & Dequantize op (#5394)
author Samuel <siju.samuel@huawei.com>
Thu, 28 May 2020 03:10:58 +0000 (08:40 +0530)
committer GitHub <noreply@github.com>
Thu, 28 May 2020 03:10:58 +0000 (11:10 +0800)
* [TFLITE]Quantize & Dequantize op

* Testcases added

* Review comment fixed

python/tvm/relay/frontend/tflite.py
tests/python/frontend/tflite/test_forward.py

index cb10ce5..9414314 100644 (file)
@@ -75,6 +75,7 @@ class OperatorConverter(object):
             'COS': self.convert_cos,
             'DEPTH_TO_SPACE': self.convert_depth_to_space,
             'DEPTHWISE_CONV_2D': self.convert_depthwise_conv2d,
+            'DEQUANTIZE': self.convert_dequantize,
             'DETECTION_POSTPROCESS': self.convert_detection_postprocess,
             'DIV': self.convert_div,
             'ELU': self.convert_elu,
@@ -112,6 +113,7 @@ class OperatorConverter(object):
             'PAD': self.convert_pad,
             'POW': self.convert_pow,
             'PRELU': self.convert_prelu,
+            'QUANTIZE': self.convert_quantize,
             'REDUCE_ANY': self.convert_reduce_any,
             'REDUCE_MAX': self.convert_reduce_max,
             'REDUCE_MIN': self.convert_reduce_min,
@@ -277,6 +279,8 @@ class OperatorConverter(object):
         except ImportError:
             raise ImportError("The tflite package must be installed")
 
+        if tensor_type == TensorType.INT8:
+            return "int8"
         if tensor_type == TensorType.UINT8:
             return "uint8"
         if tensor_type == TensorType.FLOAT32:
@@ -2355,6 +2359,40 @@ class OperatorConverter(object):
 
         return out
 
+    def convert_quantize(self, op):
+        """Convert TFLite QUANTIZE: quantize the single input via the output tensor."""
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 1, "input tensors length should be 1"
+        input_tensor = input_tensors[0]
+        in_expr = self.get_expr(input_tensor.tensor_idx)
+
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 1, "output tensors length should be 1"
+        output_tensor = output_tensors[0]
+
+        # The output tensor must carry quantization parameters (qnn_params).
+        assert output_tensor.qnn_params
+        # Quantize the input expression; the output tensor is handed to self.quantize.
+        out = self.quantize(in_expr, output_tensor)
+
+        return out
+
+    def convert_dequantize(self, op):
+        """Convert TFLite DEQUANTIZE: dequantize the single quantized input tensor."""
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 1, "input tensors length should be 1"
+        input_tensor = input_tensors[0]
+        in_expr = self.get_expr(input_tensor.tensor_idx)
+
+        # The input tensor must carry quantization parameters (qnn_params).
+        assert input_tensor.qnn_params
+        # Dequantize the input expression; the input tensor is handed to self.dequantize.
+        out = self.dequantize(in_expr, input_tensor)
+
+        return out
+
     def convert_detection_postprocess(self, op):
         """Convert TFLite_Detection_PostProcess"""
         flexbuffer = op.CustomOptionsAsNumpy().tobytes()
index 7a8437a..a68fd90 100644 (file)
@@ -1565,6 +1565,48 @@ def test_forward_squeeze():
 
 
 #######################################################################
+# Quantize/DeQuantize
+# -------------------
+
+def _test_quantize_dequantize(data):
+    """ Compare TVM and TFLite outputs on a quantized pass-through Keras model fed with data """
+
+    # Define a dummy Keras model: an input followed by a 'linear' activation
+    data_in = tf.keras.layers.Input(shape=data.shape[1:])
+    act_func =  tf.keras.layers.Activation('linear')
+    keras_model = tf.keras.models.Model(data_in, act_func(data_in))
+
+    # Create a TFLite converter from the Keras model
+    converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model)
+
+    # Quantizing with the dynamic range of activations needs a representative dataset
+    def representative_data_gen():
+        for i in range(100):
+            yield [data]
+
+    converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
+    converter.representative_dataset = representative_data_gen
+    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+    converter.inference_input_type = tf.uint8
+    converter.inference_output_type = tf.uint8
+
+    # Convert the model to TensorFlow Lite format
+    tflite_model_quant = converter.convert()
+
+    tflite_output = run_tflite_graph(tflite_model_quant, data)
+    tvm_output = run_tvm_graph(tflite_model_quant, data, 'input_1')
+    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
+                                rtol=1e-5, atol=1e-5)
+
+
+def test_forward_quantize_dequantize():
+    """ Run the quantize/dequantize check on random float32 data when TF >= 2.0.0 """
+    data = np.random.uniform(0, 1, (1, 4, 4, 3)).astype("float32")
+    if package_version.parse(tf.VERSION) >= package_version.parse('2.0.0'):
+        _test_quantize_dequantize(data)
+
+
+#######################################################################
 # Pad
 # ---
 
@@ -2264,6 +2306,7 @@ if __name__ == '__main__':
     test_forward_depthtospace()
     test_forward_spacetodepth()
     test_forward_select()
+    test_forward_quantize_dequantize()
 
     # NN
     test_forward_convolution()