[TFLite] Using real image for QNN testing. (#4816)
authorAnimesh Jain <anijain@umich.edu>
Tue, 11 Feb 2020 02:56:37 +0000 (18:56 -0800)
committerGitHub <noreply@github.com>
Tue, 11 Feb 2020 02:56:37 +0000 (10:56 +0800)
* [TFLite] Using real image for QNN testing.

* Setting seed for SSD mobilenet for fixed input.

* Support quantized Pad op.

* Remove unnnecessary line.

* Ina comments.

python/tvm/relay/frontend/tflite.py
tests/python/frontend/tflite/test_forward.py

index 5d7d967..ab63047 100644 (file)
@@ -1179,10 +1179,14 @@ class OperatorConverter(object):
             pad_left, pad_right = get_pad_value(input_w, dilated_kernel_w, stride_w)
             do_pad = not (pad_top == 0 and pad_bottom == 0 and pad_left == 0 and pad_right == 0)
             if do_pad:
+                pad_value = 0
+                if input_tensor.qnn_params:
+                    pad_value = get_scalar_from_constant(input_tensor.qnn_params['zero_point'])
                 in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0),
                                                               (pad_top, pad_bottom),
                                                               (pad_left, pad_right),
-                                                              (0, 0)))
+                                                              (0, 0)), pad_value=float(pad_value))
+
         else:
             raise tvm.error.OpAttributeUnImplemented(
                 'Padding format {} is not supported for operator Conv.'.format(padding))
@@ -1476,8 +1480,19 @@ class OperatorConverter(object):
         # convert list of lists to tuple of tuples
         paddings = tuple(tuple(l) for l in pad_list)
 
-        # Use default pad_value 0 because TFLite PAD does not support constant_values parameter
-        out = _op.nn.pad(in_expr, paddings)
+        # Set the pad value
+        pad_value = 0
+        if input_tensor.qnn_params:
+            # Check that input and output tensor have same qnn params.
+            output_tensors = self.get_output_tensors(op)
+            output_tensor = output_tensors[0]
+            assert self.has_same_qnn_params(input_tensor, output_tensor), \
+                    "TFLite reshape requires input and output scale and zero points to be equal"
+
+            # The pad value for quantized pad is the input zero point.
+            pad_value = float(input_tensor.qnn_params['zero_point'].data.asnumpy())
+
+        out = _op.nn.pad(in_expr, pad_width=paddings, pad_value=pad_value)
         return out
 
     def convert_mirror_pad(self, op):
index ad1abc2..ccb8b87 100644 (file)
@@ -42,6 +42,9 @@ from tvm.contrib.download import download_testdata
 import tvm.relay.testing.tf as tf_testing
 from packaging import version as package_version
 
+from PIL import Image
+import os
+
 #######################################################################
 # Generic run functions for TVM & TFLite
 # --------------------------------------
@@ -50,6 +53,20 @@ def convert_to_list(x):
         x = [x]
     return x
 
+
+#######################################################################
+# Get a real image for e2e testing.
+# --------------------------------------
+def get_real_image(im_height, im_width):
+    repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
+    img_name = 'elephant-299.jpg'
+    image_url = os.path.join(repo_base, img_name)
+    img_path = download_testdata(image_url, img_name, module='data')
+    image = Image.open(img_path).resize((im_height, im_width))
+    x = np.array(image).astype('uint8')
+    data = np.reshape(x, (1, im_height, im_width, 3))
+    return data
+
 def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm',
                   out_names=None):
     """ Generic function to compile on relay and execute on tvm """
@@ -1139,16 +1156,28 @@ def test_forward_squeeze():
 # Pad
 # ---
 
-def _test_pad(data, mode="CONSTANT"):
+def _test_pad(data, mode="CONSTANT", quantized=False):
     """ One iteration of PAD """
 
     assert len(data) == 2
 
     # Test with tensor and constant
     with tf.Graph().as_default():
-        in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')]
-        out = array_ops.pad(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype), mode=mode)
-        compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out])
+        in_data = [array_ops.placeholder(shape=data[0].shape, dtype='float32', name='in')]
+
+        if quantized:
+            # fake_quant will keep the tensors in float32 until the conversion in the session
+            input_range = {'inq_0': (-100, 100)}
+            inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0],
+                                                                     min=-100,
+                                                                     max=100,
+                                                                     name="inq_0")]
+            out = array_ops.pad(inq_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype), mode=mode)
+            compare_tflite_with_tvm([data[0]], ['inq_0:0'], inq_data, [out], quantized=True,
+                                    input_range=input_range)
+        else:
+            out = array_ops.pad(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype), mode=mode)
+            compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out])
 
 
 def test_forward_pad():
@@ -1165,6 +1194,8 @@ def test_forward_pad():
                np.array([[1, 1], [2, 2]], dtype=np.int32)], mode="REFLECT")
     _test_pad([np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)),
                np.array([[1, 1], [2, 2]], dtype=np.int32)], mode="SYMMETRIC")
+    _test_pad([np.arange(0, 256, dtype=np.uint8).reshape((1, 256)),
+               np.array([[1, 1], [2, 2]], dtype=np.int32)], quantized=True)
 
 
 #######################################################################
@@ -1425,10 +1456,12 @@ def test_forward_qnn_inception_v1_net():
         "inception_v1_224_quant.tflite")
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
-    # Checking the labels because the requantize implementation is different between TFLite and
-    # Relay. This cause final output numbers to mismatch. So, testing accuracy via labels.
-    np.random.seed(0)
-    data = np.random.random_integers(low=0, high=128, size=(1, 224, 224, 3)).astype('uint8')
+
+    # Test image. Checking the labels because the requantize implementation is different between
+    # TFLite and Relay. This cause final output numbers to mismatch. So, testing accuracy via
+    # labels. Also, giving a real image, instead of random inputs.
+    data = get_real_image(224, 224)
+
     tflite_output = run_tflite_graph(tflite_model_buf, data)
     tflite_predictions = np.squeeze(tflite_output)
     tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
@@ -1445,10 +1478,12 @@ def test_forward_qnn_mobilenet_v1_net():
         "mobilenet_v1_1.0_224_quant.tflite")
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
-    # Checking the labels because the requantize implementation is different between TFLite and
-    # Relay. This cause final output numbers to mismatch. So, testing accuracy via labels.
-    np.random.seed(0)
-    data = np.random.random_integers(low=0, high=128, size=(1, 224, 224, 3)).astype('uint8')
+
+    # Test image. Checking the labels because the requantize implementation is different between
+    # TFLite and Relay. This cause final output numbers to mismatch. So, testing accuracy via
+    # labels. Also, giving a real image, instead of random inputs.
+    data = get_real_image(224, 224)
+
     tflite_output = run_tflite_graph(tflite_model_buf, data)
     tflite_predictions = np.squeeze(tflite_output)
     tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
@@ -1465,10 +1500,12 @@ def test_forward_qnn_mobilenet_v2_net():
         "mobilenet_v2_1.0_224_quant.tflite")
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
-    # Checking the labels because the requantize implementation is different between TFLite and
-    # Relay. This cause final output numbers to mismatch. So, testing accuracy via labels.
-    np.random.seed(0)
-    data = np.random.random_integers(low=0, high=128, size=(1, 224, 224, 3)).astype('uint8')
+
+    # Test image. Checking the labels because the requantize implementation is different between
+    # TFLite and Relay. This cause final output numbers to mismatch. So, testing accuracy via
+    # labels. Also, giving a real image, instead of random inputs.
+    data = get_real_image(224, 224)
+
     tflite_output = run_tflite_graph(tflite_model_buf, data)
     tflite_predictions = np.squeeze(tflite_output)
     tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
@@ -1489,6 +1526,7 @@ def test_forward_ssd_mobilenet_v1():
         "ssd_mobilenet_v1_coco_2018_01_28_nopp.tflite")
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
+    np.random.seed(0)
     data = np.random.uniform(size=(1, 300, 300, 3)).astype('float32')
     tflite_output = run_tflite_graph(tflite_model_buf, data)
     tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor', num_output=2)