[Frontend][TFLite] Dynamically calculate input_stats of any fake_quant range (#4789)
authorIna Dobreva <55383260+inadob@users.noreply.github.com>
Wed, 5 Feb 2020 19:51:59 +0000 (19:51 +0000)
committerGitHub <noreply@github.com>
Wed, 5 Feb 2020 19:51:59 +0000 (11:51 -0800)
* [TFLite] Dynamically calculate input_stats of any fake_quant range

* pass the input range to the convertor and caclulate (mean, scale) there
* change the range of the second tensor in elemwise operations
  so that we test inputs with different quant params
* change the possible output range for elemwise ops wrt the updated ranges
* update the comments for (m, s) calculations
* add input range dict to reduce_mean op

* Apply requested changes

* add exception handling for zero division in input_stats
* fix range of the input tensor in elemwsie

tests/python/frontend/tflite/test_forward.py

index 9835bfc..aa29cf5 100644 (file)
@@ -123,7 +123,8 @@ def run_tflite_graph(tflite_model_buf, input_data):
 
 
 def compare_tflite_with_tvm(in_data, in_name, input_tensors,
-                            output_tensors, init_global_variables=False, out_names=None, quantized=False):
+                            output_tensors, init_global_variables=False,
+                            out_names=None, quantized=False, input_range=None):
     """Generic function to generate and compare TFLite and TVM output"""
     in_data = convert_to_list(in_data)
     in_name = convert_to_list(in_name)
@@ -143,11 +144,16 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
             converter.inference_type = tf.lite.constants.QUANTIZED_UINT8
             input_arrays = converter.get_input_arrays()
             input_stats = {}
-            # hardcode the mean_values and std_dev_values (m,s) to be the same
-            # if all inputs are in (float_min; float_max) == (-100, 100)
+            # calculate the mean and quantization scale for every input tensor,
+            # with respect to its fp32 input range, defined in fake_quant.
             # s = 255/(fmax-fmin);  m = -fmin*s (the zero point)
             for i in input_arrays:
-                input_stats[i] = (128., 1.275)
+                try:
+                    quant_scale = 255 / (input_range[i][1] - input_range[i][0])
+                except ZeroDivisionError:
+                    raise ZeroDivisionError('Min and max of the input range for tensor ' + i + ' can\'t be equal')
+                mean = - input_range[i][0] * quant_scale
+                input_stats[i] = (mean, quant_scale)
             converter.quantized_input_stats = input_stats
 
         tflite_model_buffer = converter.convert()
@@ -757,13 +763,15 @@ def _test_elemwise(math_op, data, fused_activation_function=None, quantized=Fals
         if quantized:
             # fake_quant will keep the tensors in float32 until the conversion in the session
             inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0], min=-100, max=100, name="inq_0"),
-                        tf.quantization.fake_quant_with_min_max_args(in_data[1], min=-100, max=100, name="inq_1")]
+                        tf.quantization.fake_quant_with_min_max_args(in_data[1], min=-50, max=50, name="inq_1")]
+            input_range = {'inq_0': (-100, 100), 'inq_1': (-50, 50)}
             out = math_op(inq_data[0], inq_data[1])
             out = with_fused_activation_function(out, fused_activation_function)
-            # set the quantized output range with respect to the operation
+            # set the fp32 output range with respect to the operation
             out_min, out_max = _test_elemwise_qnn_out_range(qnn_op)
             out = tf.quantization.fake_quant_with_min_max_args(out, min=out_min, max=out_max, name="out")
-            compare_tflite_with_tvm(data, ['inq_0:0', 'inq_1:0'], inq_data, [out], quantized=True)
+            compare_tflite_with_tvm(data, ['inq_0:0', 'inq_1:0'], inq_data, [out],
+                                    quantized=True, input_range=input_range)
         else:
             out = math_op(in_data[0], in_data[1])
             out = with_fused_activation_function(out, fused_activation_function)
@@ -775,13 +783,14 @@ def _test_elemwise(math_op, data, fused_activation_function=None, quantized=Fals
 
         if quantized:
             inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0], min=-100, max=100, name="inq_0")]
-            inq_const = tf.quantization.fake_quant_with_min_max_args(data[1], min=-100, max=100, name="const_tensor")
+            inq_const = tf.quantization.fake_quant_with_min_max_args(data[1], min=-50, max=50, name="const_tensor")
+            input_range = {'inq_0': (-100, 100)}
             # the 2nd tensor is treated as constant and directly added as part of the operation
             out = math_op(inq_data, ops.convert_to_tensor(inq_const, dtype='float32', name='inq_const'))
             out = with_fused_activation_function(out, fused_activation_function)
             out_min, out_max = _test_elemwise_qnn_out_range(qnn_op)
             out = tf.quantization.fake_quant_with_min_max_args(out, min=out_min, max=out_max, name="out")
-            compare_tflite_with_tvm(data[0], ['inq_0:0'], inq_data, [out], quantized=True)
+            compare_tflite_with_tvm(data[0], ['inq_0:0'], inq_data, [out], quantized=True, input_range=input_range)
         else:
             out = math_op(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype))
             out = with_fused_activation_function(out, fused_activation_function)
@@ -793,13 +802,14 @@ def _test_elemwise(math_op, data, fused_activation_function=None, quantized=Fals
 
         if quantized:
             inq_const = tf.quantization.fake_quant_with_min_max_args(data[0], min=-100, max=100, name="const_tensor")
-            inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0], min=-100, max=100, name="inq_1")]
+            inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0], min=-50, max=50, name="inq_1")]
+            input_range = {'inq_1': (-50, 50)}
             # the 1st tensor is treated as constant and directly added as part of the operation
             out = math_op(ops.convert_to_tensor(inq_const, dtype='float32', name='inq_const'), inq_data)
             out = with_fused_activation_function(out, fused_activation_function)
             out_min, out_max = _test_elemwise_qnn_out_range(qnn_op)
             out = tf.quantization.fake_quant_with_min_max_args(out, min=out_min, max=out_max, name="out")
-            compare_tflite_with_tvm(data[1], ['inq_1:0'], inq_data, [out], quantized=True)
+            compare_tflite_with_tvm(data[1], ['inq_1:0'], inq_data, [out], quantized=True, input_range=input_range)
         else:
             out = math_op(ops.convert_to_tensor(data[0], dtype=data[0].dtype), in_data[0])
             out = with_fused_activation_function(out, fused_activation_function)
@@ -920,11 +930,11 @@ def _test_forward_elemwise_quantized(testop):
             np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8)], quantized=True, qnn_op=testop)
 
 def _test_elemwise_qnn_out_range(qnn_op):
-    # set the fake_quant output range if input tensors are in [-100, 100] float32
+    # set the fake_quant output range with respect to the input tensors float32 range
     qnn_out_range = {
-        _test_add: (-200, 200),
-        _test_sub: (-200, 200),
-        _test_mul: (-1e+4, 1e+4),
+        _test_add: (-150, 150),
+        _test_sub: (-150, 150),
+        _test_mul: (-5e+3, 5e+3),
     }
 
     return qnn_out_range[qnn_op]
@@ -994,9 +1004,10 @@ def _test_reduce_quantize(math_op, data, keep_dims=None):
     with tf.Graph().as_default():
         in_data = [array_ops.placeholder(shape=data[0].shape, dtype="float32", name='in')]
         inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0], min=-100, max=100, name="inq_0")]
+        input_range = {'inq_0': (-100, 100)}
         out = math_op(inq_data, data[1], keep_dims)
         out = tf.quantization.fake_quant_with_min_max_args(out, min=-200, max=200, name="out")
-        compare_tflite_with_tvm([data[0]], ['inq_0:0'], [inq_data[0]], [out], quantized=True)
+        compare_tflite_with_tvm([data[0]], ['inq_0:0'], [inq_data[0]], [out], quantized=True, input_range=input_range)
 
 
 #######################################################################