quanitze operation expanded to take const argument (#6127)
authorDmitriy Smirnov <dmitriy.smirnov@arm.com>
Fri, 28 Aug 2020 18:46:13 +0000 (19:46 +0100)
committerGitHub <noreply@github.com>
Fri, 28 Aug 2020 18:46:13 +0000 (11:46 -0700)
* quanitze operation expanded to take const argument

* amendments

used get_tensor_expr, added _test_forward_quantize_dequantize_const test

python/tvm/relay/frontend/tflite.py
tests/python/frontend/tflite/test_forward.py

index 31ff871..1981b5d 100644 (file)
@@ -2764,7 +2764,7 @@ class OperatorConverter(object):
         assert len(input_tensors) == 1, "input tensors length should be 1"
         input_tensor = input_tensors[0]
         input_tensor_type_str = self.get_tensor_type_str(input_tensor.tensor.Type())
-        in_expr = self.get_expr(input_tensor.tensor_idx)
+        in_expr = self.get_tensor_expr(input_tensor)
 
         output_tensors = self.get_output_tensors(op)
         assert len(output_tensors) == 1, "output tensors length should be 1"
index 70a629d..5496dfe 100644 (file)
@@ -1907,11 +1907,38 @@ def _test_quantize_dequantize(data):
                                 rtol=1e-5, atol=1e-2)
 
 
+def _test_quantize_dequantize_const(data):
+    """ One iteration of quantize and dequantize """
+
+    # Keras model to force TFLite converter to insert 2 TFLite quantize ops.
+    # First TFLite quantize op converts float32 tensor to int8 tensor - Qnn quantize.
+    # Second TFLite quantize op converts int8 tensor to int8 tensor - Qnn requantize.
+    data_in = tf.keras.layers.Input(shape=data.shape[1:])
+    relu = tf.keras.layers.ReLU()(data_in)
+    add = tf.keras.layers.Add()([data, relu])
+    concat = tf.keras.layers.Concatenate(axis=0)([relu, add])
+    keras_model = tf.keras.models.Model(inputs=data_in, outputs=concat)
+    input_name = data_in.name.split(":")[0]
+
+    # To create quantized values with dynamic range of activations, needs representative dataset
+    def representative_data_gen():
+        for i in range(1):
+            yield [data]
+
+    tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen)
+
+    tflite_output = run_tflite_graph(tflite_model_quant, data)
+    tvm_output = run_tvm_graph(tflite_model_quant, data, input_name)
+    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
+                                rtol=1e-5, atol=1e-2)
+
+
 def test_forward_quantize_dequantize():
     """ Quantize Dequantize """
     data = np.random.uniform(0, 1, (1, 4, 4, 3)).astype("float32")
     if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'):
         _test_quantize_dequantize(data)
+        _test_quantize_dequantize_const(data)
 
 
 #######################################################################