From 162a29e0b82aef9c3013e360a43a4b14c3c652d9 Mon Sep 17 00:00:00 2001 From: Samuel Date: Thu, 28 May 2020 08:40:58 +0530 Subject: [PATCH] [TFLITE]Quantize & Dequantize op (#5394) * [TFLITE]Quantize & Dequantize op * Testcases added * Review comment fixed --- python/tvm/relay/frontend/tflite.py | 38 ++++++++++++++++++++++++ tests/python/frontend/tflite/test_forward.py | 43 ++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index cb10ce5..9414314 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -75,6 +75,7 @@ class OperatorConverter(object): 'COS': self.convert_cos, 'DEPTH_TO_SPACE': self.convert_depth_to_space, 'DEPTHWISE_CONV_2D': self.convert_depthwise_conv2d, + 'DEQUANTIZE': self.convert_dequantize, 'DETECTION_POSTPROCESS': self.convert_detection_postprocess, 'DIV': self.convert_div, 'ELU': self.convert_elu, @@ -112,6 +113,7 @@ class OperatorConverter(object): 'PAD': self.convert_pad, 'POW': self.convert_pow, 'PRELU': self.convert_prelu, + 'QUANTIZE': self.convert_quantize, 'REDUCE_ANY': self.convert_reduce_any, 'REDUCE_MAX': self.convert_reduce_max, 'REDUCE_MIN': self.convert_reduce_min, @@ -277,6 +279,8 @@ class OperatorConverter(object): except ImportError: raise ImportError("The tflite package must be installed") + if tensor_type == TensorType.INT8: + return "int8" if tensor_type == TensorType.UINT8: return "uint8" if tensor_type == TensorType.FLOAT32: @@ -2355,6 +2359,40 @@ class OperatorConverter(object): return out + def convert_quantize(self, op): + """Convert TFLite Quantize""" + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + input_tensor = input_tensors[0] + in_expr = self.get_expr(input_tensor.tensor_idx) + + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) == 1, "output tensors length should be 1" + output_tensor = output_tensors[0] + + # The output must be quantized + assert output_tensor.qnn_params + # Quantize the input + out = self.quantize(in_expr, output_tensor) + + return out + + def convert_dequantize(self, op): + """Convert TFLite Dequantize""" + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + input_tensor = input_tensors[0] + in_expr = self.get_expr(input_tensor.tensor_idx) + + # The input must be quantized + assert input_tensor.qnn_params + # Dequantize the input. + out = self.dequantize(in_expr, input_tensor) + + return out + def convert_detection_postprocess(self, op): """Convert TFLite_Detection_PostProcess""" flexbuffer = op.CustomOptionsAsNumpy().tobytes() diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 7a8437a..a68fd90 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1565,6 +1565,48 @@ def test_forward_squeeze(): ####################################################################### +# Quantize/DeQuantize +# ------------------- + +def _test_quantize_dequantize(data): + """ One iteration of quantize and dequantize """ + + # Define a dummy model + data_in = tf.keras.layers.Input(shape=data.shape[1:]) + act_func = tf.keras.layers.Activation('linear') + keras_model = tf.keras.models.Model(data_in, act_func(data_in)) + + # Load the model + converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model) + + # To create quantized values with dynamic range of activations, needs representative dataset + def representative_data_gen(): + for i in range(100): + yield [data] + + converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] + converter.representative_dataset = representative_data_gen + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.uint8 + converter.inference_output_type = tf.uint8 + + # Convert the model to TensorFlow Lite format + tflite_model_quant = converter.convert() + + tflite_output = run_tflite_graph(tflite_model_quant, data) + tvm_output = run_tvm_graph(tflite_model_quant, data, 'input_1') + tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), + rtol=1e-5, atol=1e-5) + + +def test_forward_quantize_dequantize(): + """ Quantize Dequantize """ + data = np.random.uniform(0, 1, (1, 4, 4, 3)).astype("float32") + if package_version.parse(tf.VERSION) >= package_version.parse('2.0.0'): + _test_quantize_dequantize(data) + + +####################################################################### # Pad # --- @@ -2264,6 +2306,7 @@ if __name__ == '__main__': test_forward_depthtospace() test_forward_spacetodepth() test_forward_select() + test_forward_quantize_dequantize() # NN test_forward_convolution() -- 2.7.4