'COS': self.convert_cos,
'DEPTH_TO_SPACE': self.convert_depth_to_space,
'DEPTHWISE_CONV_2D': self.convert_depthwise_conv2d,
+ 'DEQUANTIZE': self.convert_dequantize,
'DETECTION_POSTPROCESS': self.convert_detection_postprocess,
'DIV': self.convert_div,
'ELU': self.convert_elu,
'PAD': self.convert_pad,
'POW': self.convert_pow,
'PRELU': self.convert_prelu,
+ 'QUANTIZE': self.convert_quantize,
'REDUCE_ANY': self.convert_reduce_any,
'REDUCE_MAX': self.convert_reduce_max,
'REDUCE_MIN': self.convert_reduce_min,
except ImportError:
raise ImportError("The tflite package must be installed")
+ if tensor_type == TensorType.INT8:
+ return "int8"
if tensor_type == TensorType.UINT8:
return "uint8"
if tensor_type == TensorType.FLOAT32:
return out
+ def convert_quantize(self, op):
+ """Convert TFLite Quantize"""
+
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 1, "input tensors length should be 1"
+ input_tensor = input_tensors[0]
+ in_expr = self.get_expr(input_tensor.tensor_idx)
+
+ output_tensors = self.get_output_tensors(op)
+ assert len(output_tensors) == 1, "output tensors length should be 1"
+ output_tensor = output_tensors[0]
+
+ # The output must be quantized
+ assert output_tensor.qnn_params
+ # Quantize the input
+ out = self.quantize(in_expr, output_tensor)
+
+ return out
+
+ def convert_dequantize(self, op):
+ """Convert TFLite Dequantize"""
+
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 1, "input tensors length should be 1"
+ input_tensor = input_tensors[0]
+ in_expr = self.get_expr(input_tensor.tensor_idx)
+
+ # The input must be quantized
+ assert input_tensor.qnn_params
+ # Dequantize the input.
+ out = self.dequantize(in_expr, input_tensor)
+
+ return out
+
def convert_detection_postprocess(self, op):
"""Convert TFLite_Detection_PostProcess"""
flexbuffer = op.CustomOptionsAsNumpy().tobytes()
#######################################################################
+# Quantize/DeQuantize
+# -------------------
+
+def _test_quantize_dequantize(data):
+ """ One iteration of quantize and dequantize """
+
+ # Define a dummy model
+ data_in = tf.keras.layers.Input(shape=data.shape[1:])
+ act_func = tf.keras.layers.Activation('linear')
+ keras_model = tf.keras.models.Model(data_in, act_func(data_in))
+
+ # Load the model
+ converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model)
+
+ # To create quantized values with dynamic range of activations, needs representative dataset
+ def representative_data_gen():
+ for i in range(100):
+ yield [data]
+
+ converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
+ converter.representative_dataset = representative_data_gen
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+ converter.inference_input_type = tf.uint8
+ converter.inference_output_type = tf.uint8
+
+ # Convert the model to TensorFlow Lite format
+ tflite_model_quant = converter.convert()
+
+ tflite_output = run_tflite_graph(tflite_model_quant, data)
+ tvm_output = run_tvm_graph(tflite_model_quant, data, 'input_1')
+ tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
+ rtol=1e-5, atol=1e-5)
+
+
+def test_forward_quantize_dequantize():
+ """ Quantize Dequantize """
+ data = np.random.uniform(0, 1, (1, 4, 4, 3)).astype("float32")
+ if package_version.parse(tf.VERSION) >= package_version.parse('2.0.0'):
+ _test_quantize_dequantize(data)
+
+
+#######################################################################
# Pad
# ---
test_forward_depthtospace()
test_forward_spacetodepth()
test_forward_select()
+ test_forward_quantize_dequantize()
# NN
test_forward_convolution()