qnn_params = None
tflite_qnn_params = tensor.Quantization()
if tflite_qnn_params is not None:
- scale = float(tflite_qnn_params.ScaleAsNumpy())
- zero_point = int(tflite_qnn_params.ZeroPointAsNumpy())
+ # TFLite supports both per-tensor and per-axis (aka channel) quantization. For
+ # per-tensor quantization, scale and zero points are scalar values. For per-axis
+ # quantization, scale and zero points for the weights are tensors (activations are
+ # per-tensor quantized). However, the TFLite quantization spec puts restrictions on
+ # zero points for per-axis quantization. Specifically, the zero point is a tensor
+ # but all values are 0. More information can be found here -
+ # https://www.tensorflow.org/lite/performance/quantization_spec
+
+ tflite_scale = tflite_qnn_params.ScaleAsNumpy()
+ tflite_zero_point = tflite_qnn_params.ZeroPointAsNumpy()
+ is_qnn_params_valid = True
+
+ # Handle Per-axis and per-tensor cases
+ if isinstance(tflite_scale, np.ndarray):
+ assert isinstance(tflite_zero_point, np.ndarray)
+
+ # Tensor - Per-axis quantization
+ if tflite_scale.size != 1 and tflite_zero_point.size != 1:
+ scale = tflite_scale
+ # Ensure that all zero points are zeros
+ zero_point = tflite_zero_point
+ if not np.all(zero_point == 0):
+ raise tvm.error.OpAttributeInvalid(\
+ "TFLite per-axis quantization restricts all zero points to be"
+ + " 0, but a non-zero value is observed")
+ zero_point = int(zero_point[0])
+
+ # Scalar - Per-tensor quantization
+ elif tflite_scale.size == 1 and tflite_zero_point.size == 1:
+ scale = float(tflite_scale[0])
+ zero_point = int(tflite_zero_point[0])
+
+ else:
+ raise NotImplementedError(\
+ "Quantized type {} (scale) and {} (zero point) not supported"
+ .format(type(tflite_scale), type(tflite_zero_point)))
+ elif tflite_scale == 0 and tflite_zero_point == 0:
+ # Handle corner case for ops like quantized reshape whose second operand (shape)
+ # has zero scale and zero zero point. This is not used.
+ is_qnn_params_valid = False
+ else:
+ raise NotImplementedError("Quantized type {} not supported"
+ .format(type(tflite_scale)))
+
# Check that the scale and zero points are valid.
- if scale != 0 or zero_point != 0:
+ if is_qnn_params_valid:
qnn_params = dict()
qnn_params['scale'] = relay.const(scale, 'float32')
qnn_params['zero_point'] = relay.const(zero_point, 'int32')
return_list.append(TensorWrapper(tensor_idx, tensor, buffer, qnn_params))
return return_list
- def get_tensor_value(self, tensor_wrapper):
- """Get tensor buffer value from given tensor wrapper"""
+
+ def get_tensor_type_as_numpy(self, tensor_wrapper):
+ """Map the wrapped tensor's TFLite TensorType to a NumPy scalar type.
+
+ The lookup table covers UINT8, INT8, FLOAT32, INT32, INT64 and BOOL.
+ Raises ImportError if the tflite package cannot be imported, and
+ NotImplementedError for any tensor type missing from the table.
+ """
assert isinstance(tensor_wrapper, TensorWrapper)
try:
from tflite.TensorType import TensorType
+ return {TensorType.UINT8: np.uint8,
+ TensorType.INT8: np.int8,
+ TensorType.FLOAT32: np.float32,
+ TensorType.INT32: np.int32,
+ TensorType.INT64: np.int64,
+ TensorType.BOOL: np.bool_}[tensor_wrapper.tensor.Type()]
except ImportError:
raise ImportError("The tflite package must be installed")
+ # A KeyError here comes from the table lookup: the tensor type has no NumPy mapping.
+ except KeyError:
+ raise NotImplementedError("Tensor type '{}' currently not supported"
+ .format(tensor_wrapper.tensor.Type()))
+
+
+    def get_tensor_value(self, tensor_wrapper):
+        """Return the wrapped tensor's buffer contents as a shaped NumPy array.
+
+        The element type comes from get_tensor_type_as_numpy; the raw buffer is
+        reinterpreted with that type and reshaped to the tensor's shape.
+        """
+        assert isinstance(tensor_wrapper, TensorWrapper)
+
+        np_dtype = self.get_tensor_type_as_numpy(tensor_wrapper)
+        raw = tensor_wrapper.buffer.DataAsNumpy()
+
+        # When the tensor reports a shape length of zero, reshape to [] instead
+        # of asking for a shape vector.
+        tflite_tensor = tensor_wrapper.tensor
+        dims = tflite_tensor.ShapeAsNumpy() if tflite_tensor.ShapeLength() != 0 else []
+
+        return np.frombuffer(raw, dtype=np_dtype).reshape(dims)
- if tensor_wrapper.tensor.Type() == TensorType.UINT8:
- return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.uint8).reshape(
- tensor_wrapper.tensor.ShapeAsNumpy())
- if tensor_wrapper.tensor.Type() == TensorType.FLOAT32:
- return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.float32).reshape(
- tensor_wrapper.tensor.ShapeAsNumpy())
- if tensor_wrapper.tensor.Type() == TensorType.INT32:
- return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int32).reshape(
- tensor_wrapper.tensor.ShapeAsNumpy())
- if tensor_wrapper.tensor.Type() == TensorType.INT64:
- return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int64).reshape(
- tensor_wrapper.tensor.ShapeAsNumpy())
- if tensor_wrapper.tensor.Type() == TensorType.BOOL:
- return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.bool_).reshape(
- tensor_wrapper.tensor.ShapeAsNumpy())
- raise NotImplementedError("Tensor type {} is currently not supported"
- .format(str(tensor_wrapper.tensor.Type())))
def get_tensor_type_str(self, tensor_type):
"""Get tensor type string representation when given TFLite tensor type"""
def convert_relu(self, op):
"""Convert TFLite ReLU"""
+ try:
+ from tflite.ActivationFunctionType import ActivationFunctionType
+ except ImportError:
+ raise ImportError("The tflite package must be installed")
+
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
-
input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)
- out = _op.nn.relu(in_expr)
+
+ output_tensors = self.get_output_tensors(op)
+ assert len(output_tensors) == 1, "output tensors length should be 1"
+ output_tensor = output_tensors[0]
+
+ if input_tensor.qnn_params:
+ # Quantized input: read the scalar scale and zero point out of the Relay
+ # constants and apply ReLU through the quantized fused-activation helper.
+ scale_val = get_scalar_from_constant(input_tensor.qnn_params['scale'])
+ zero_point_val = get_scalar_from_constant(input_tensor.qnn_params['zero_point'])
+
+ output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
+ out = self.convert_qnn_fused_activation_function(\
+ expr=in_expr,
+ fused_activation_fn=ActivationFunctionType.RELU,
+ scale=scale_val,
+ zero_point=zero_point_val,
+ dtype=output_tensor_type_str)
+ else:
+ out = _op.nn.relu(in_expr)
+
+ # Quantized output: requantize from the input tensor's scale / zero point
+ # to the output tensor's scale / zero point.
+ if output_tensor.qnn_params:
+ output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
+ out = _qnn.op.requantize(out,
+ input_scale=input_tensor.qnn_params['scale'],
+ input_zero_point=input_tensor.qnn_params['zero_point'],
+ output_scale=output_tensor.qnn_params['scale'],
+ output_zero_point=output_tensor.qnn_params['zero_point'],
+ out_dtype=output_tensor_type_str)
return out
def convert_relu6(self, op):
"""Convert TFLite ReLU6"""
+ try:
+ from tflite.ActivationFunctionType import ActivationFunctionType
+ except ImportError:
+ raise ImportError("The tflite package must be installed")
+
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0]
# Quantize a float value to an quantized integer value
scale_val = get_scalar_from_constant(input_tensor.qnn_params['scale'])
zero_point_val = get_scalar_from_constant(input_tensor.qnn_params['zero_point'])
- quantize = lambda x: float(int(round(x / scale_val)) + zero_point_val)
- # Get min/max of the input dtype. This will be used to ensure that
- # clip a_min/a_max are not beyond the dtype range.
- input_tensor_type_str = self.get_tensor_type_str(input_tensor.tensor.Type())
- qmin = float(tvm.tir.op.min_value(input_tensor_type_str).value)
- qmax = float(tvm.tir.op.max_value(input_tensor_type_str).value)
-
- out = _op.clip(in_expr,
- a_min=max(qmin, quantize(0)),
- a_max=min(qmax, quantize(6.0)))
+ output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
+ out = self.convert_qnn_fused_activation_function(\
+ expr=in_expr,
+ fused_activation_fn=ActivationFunctionType.RELU6,
+ scale=scale_val,
+ zero_point=zero_point_val,
+ dtype=output_tensor_type_str)
else:
out = _op.clip(in_expr, a_min=0, a_max=6)
fully_connected_options.Init(op_options.Bytes, op_options.Pos)
fused_activation_fn = fully_connected_options.FusedActivationFunction()
- # weight tensor type should be UINT8 (quantization) or FLOAT32
+ # weight tensor type should be INT8/UINT8 (quantization) or FLOAT32
weight_tensor_type = weight_tensor.tensor.Type()
- assert weight_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
+ assert weight_tensor_type in (TensorType.INT8, TensorType.UINT8, TensorType.FLOAT32)
weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type)
if self.has_expr(weight_tensor.tensor_idx):
params['channels'] = int(output_channels)
params['kernel_layout'] = 'HWIO'
- # weight tensor type should be UINT8 (quantization) or FLOAT32
+ # weight tensor type should be INT8/UINT8 (quantization) or FLOAT32
weight_tensor_type = weight_tensor.tensor.Type()
- assert weight_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
+ assert weight_tensor_type in (TensorType.INT8, TensorType.UINT8, TensorType.FLOAT32)
weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type)
in_expr = self.get_expr(input_tensor_idx)
if output_tensor.qnn_params:
# Calculate the intermediate scale and zero point of the int32 output.
data_scale = input_tensor.qnn_params['scale']
- weight_scale = weight_tensor.qnn_params['scale']
data_scale_val = get_scalar_from_constant(data_scale)
- weight_scale_val = get_scalar_from_constant(weight_scale)
+
+ weight_scale = weight_tensor.qnn_params['scale']
+ # If weight scale is scalar, it is per-tensor quantization
+ if isinstance(weight_scale, float):
+ weight_scale_val = get_scalar_from_constant(weight_scale)
+ else:
+ weight_scale_val = get_tensor_from_constant(weight_scale)
+
new_input_scale_val = data_scale_val * weight_scale_val
new_input_scale = relay.const(new_input_scale_val, 'float32')
new_input_zero_point = relay.const(0, 'int32')
input_zero_point=new_input_zero_point,
output_scale=output_tensor.qnn_params['scale'],
output_zero_point=output_tensor.qnn_params['zero_point'],
- out_dtype=output_tensor_type_str)
+ out_dtype=output_tensor_type_str,
+ axis=3)
# Call activation function
output_scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale'])
dtype=output_tensor_type_str)
else:
out = self.convert_fused_activation_function(out, fused_activation_fn)
-
return out
def convert_split(self, op):
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0]
+ input_tensor_type_str = self.get_tensor_type_str(input_tensor.tensor.Type())
in_expr = self.get_expr(input_tensor.tensor_idx)
output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"
output_tensor = output_tensors[0]
+ output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
# The output must be quantized
assert output_tensor.qnn_params
- # Quantize the input
- out = self.quantize(in_expr, output_tensor)
+ # TFLite Quantize op can also act as Requantize op
+ if input_tensor_type_str == "float32":
+ out = self.quantize(in_expr, output_tensor)
+ else:
+ out = _qnn.op.requantize(in_expr,
+ input_scale=input_tensor.qnn_params['scale'],
+ input_zero_point=input_tensor.qnn_params['zero_point'],
+ output_scale=output_tensor.qnn_params['scale'],
+ output_zero_point=output_tensor.qnn_params['zero_point'],
+ out_dtype=output_tensor_type_str)
return out
def convert_dequantize(self, op):
else:
type_str = self.get_tensor_type_str(tensor.tensor.Type())
expr = self.exp_tab.new_const(self.get_tensor_value(tensor), dtype=type_str)
-
return expr
"value must be float32/int32"
return np.asscalar(value)
+def get_tensor_from_constant(expr):
+    """Return the values held by a Relay constant node as a NumPy array.
+
+    The node must be an _expr.Constant whose data is int32 or float32;
+    both conditions are checked with assertions.
+    """
+    assert isinstance(expr, _expr.Constant)
+    tensor = expr.data.asnumpy()
+    assert tensor.dtype in (np.dtype(np.int32), np.dtype(np.float32)), \
+        "value must be float32/int32"
+    return tensor
def build_str_map(obj):
"""Build string map of TFLite enum int value
data = np.reshape(x, (1, im_height, im_width, 3))
return data
+
+def pre_processed_image(height, width):
+ """Download the elephant test image and turn it into a single-image batch.
+
+ The JPEG is fetched through download_testdata, decoded as 3 channels,
+ converted to float32 when it is not already, centre-cropped with
+ central_fraction=0.875, resized to (height, width) and given a leading
+ axis via tf.expand_dims. Returns the resulting TensorFlow tensor.
+ """
+ repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
+ img_name = 'elephant-299.jpg'
+ image_url = os.path.join(repo_base, img_name)
+ img_path = download_testdata(image_url, img_name, module='data')
+ image = tf.io.read_file(img_path)
+ image = tf.image.decode_jpeg(image, channels=3)
+ with tf.name_scope('eval_image'):
+ if image.dtype != tf.float32:
+ image = tf.image.convert_image_dtype(image, dtype=tf.float32)
+ image = tf.image.central_crop(image, central_fraction=0.875)
+ # Resize the image to the specified height and width.
+ image = tf.image.resize(image, [height, width],
+ align_corners=False)
+ image = tf.expand_dims(image, axis=0)
+ return image
+
+
def get_real_image_object_detection(im_height, im_width):
repo_base = 'https://github.com/dmlc/web-data/raw/master/gluoncv/detection/'
img_name = 'street_small.jpg'
else:
raise RuntimeError("Unknown object type: %s" % type(o))
+
+def _quantize_keras_model(keras_model, representative_data_gen):
+    """Run a Keras model through the TFLite converter with quantization enabled.
+
+    representative_data_gen is handed to the converter as its representative
+    dataset. The converter is restricted to the TFLITE_BUILTINS_INT8 op set
+    and the model's input and output types are set to uint8. Returns the
+    result of converter.convert().
+    """
+    tflite_converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model)
+    tflite_converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
+    tflite_converter.representative_dataset = representative_data_gen
+    tflite_converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+    tflite_converter.inference_input_type = tf.uint8
+    tflite_converter.inference_output_type = tf.uint8
+    quantized_model = tflite_converter.convert()
+    return quantized_model
+
+
def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm',
out_names=None, mode='graph_runtime'):
""" Generic function to compile on relay and execute on tvm """
# Convolution
# -----------
+
+def _test_tflite2_quantized_convolution(input_shape, kernel_shape,
+                                        dilations, strides, padding, data_format):
+    """One iteration of TFLite2 quantized convolution with given shapes and attributes.
+
+    Builds a single Keras Conv2D layer (ReLU activation, no bias) with random
+    weights, quantizes it through _quantize_keras_model, and checks that TVM
+    and the TFLite interpreter agree on the output within 1e-2.
+
+    data_format is the TF layout string: "NHWC" selects Keras "channels_last",
+    anything else selects "channels_first". dilations is accepted but is not
+    forwarded to the Keras layer.
+    """
+    # Compare against the argument: a bare non-empty string literal is always
+    # truthy and would select "channels_last" regardless of the caller's value.
+    keras_data_format = "channels_last" if data_format == "NHWC" else "channels_first"
+    data = np.random.uniform(0, 1, input_shape).astype('float32')
+    kernel = np.random.uniform(0, 1, kernel_shape).astype('float32')
+
+    data_in = tf.keras.layers.Input(shape=data.shape[1:])
+    conv = tf.keras.layers.Conv2D(filters=kernel_shape[3],
+                                  kernel_size=(kernel_shape[0], kernel_shape[1]),
+                                  strides=strides,
+                                  padding=padding,
+                                  data_format=keras_data_format,
+                                  activation='relu',
+                                  use_bias=False)(data_in)
+    keras_model = tf.keras.models.Model(data_in, conv)
+    keras_model.layers[1].set_weights([kernel])
+
+    # To create quantized values with dynamic range of activations, needs representative dataset
+    def representative_data_gen():
+        for _ in range(1):
+            yield [data]
+
+    tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen)
+
+    # The same quantized model is executed by the TFLite interpreter and by TVM.
+    tflite_output = run_tflite_graph(tflite_model_quant, data)
+    tvm_output = run_tvm_graph(tflite_model_quant, data, data_in.name.replace(":0", ""))
+    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
+                                rtol=1e-2, atol=1e-2)
+
+
+def _test_tflite2_quantized_depthwise_convolution(input_shape, kernel_shape,
+                                                  dilations, strides, padding,
+                                                  data_format, depth_multiplier):
+    """One iteration of TFLite2 quantized depthwise convolution with given shapes and attributes.
+
+    Builds a single Keras DepthwiseConv2D layer (ReLU activation, no bias) with
+    random weights, quantizes it through _quantize_keras_model, and checks that
+    TVM and the TFLite interpreter agree on the output within 1e-2.
+
+    data_format is the TF layout string: "NHWC" selects Keras "channels_last",
+    anything else selects "channels_first". dilations is accepted but is not
+    forwarded to the Keras layer.
+    """
+    # Compare against the argument: a bare non-empty string literal is always
+    # truthy and would select "channels_last" regardless of the caller's value.
+    keras_data_format = "channels_last" if data_format == "NHWC" else "channels_first"
+    data = np.random.uniform(0, 1, input_shape).astype('float32')
+    kernel = np.random.uniform(0, 1, kernel_shape).astype('float32')
+
+    data_in = tf.keras.layers.Input(shape=data.shape[1:])
+    conv = tf.keras.layers.DepthwiseConv2D(kernel_size=(kernel_shape[0], kernel_shape[1]),
+                                           strides=strides,
+                                           padding=padding,
+                                           data_format=keras_data_format,
+                                           activation='relu',
+                                           use_bias=False,
+                                           depth_multiplier=depth_multiplier)(data_in)
+    keras_model = tf.keras.models.Model(data_in, conv)
+    keras_model.layers[1].set_weights([kernel])
+
+    # To create quantized values with dynamic range of activations, needs representative dataset
+    def representative_data_gen():
+        for _ in range(1):
+            yield [data]
+
+    tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen)
+
+    # The same quantized model is executed by the TFLite interpreter and by TVM.
+    tflite_output = run_tflite_graph(tflite_model_quant, data)
+    tvm_output = run_tvm_graph(tflite_model_quant, data, data_in.name.replace(":0", ""))
+    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
+                                rtol=1e-2, atol=1e-2)
+
+
def _test_convolution(tensor_in_sizes, filter_in_sizes,
dilations, strides, padding, data_format,
is_depthwise=False, quantized=False):
data_format=data_format)
if quantized:
- # For now only quantized conv2d is supported
- assert not is_depthwise
-
- # Quantized the inputs and feed them to the convolution
- inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=-100, max=100, name='inq_data')
- inq_filter = tf.quantization.fake_quant_with_min_max_args(in_filter, min=-100, max=100, name='inq_filter')
- out = nn_ops.conv2d(inq_data,
- inq_filter,
- strides=strides,
- padding=padding,
- data_format=data_format)
- out = tf.quantization.fake_quant_with_min_max_args(out, min=-200, max=200, name="out")
-
- # Set the input quantization range
- input_range = {'in_data': (-100, 100)} if quantized else None
-
- # Compare
- compare_tflite_with_tvm(data_array, 'in_data', [in_data], [out], quantized=quantized, input_range=input_range)
+ if is_depthwise:
+ # Quantized the inputs and feed them to the convolution
+ inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=-100, max=100, name='inq_data')
+ inq_filter = tf.quantization.fake_quant_with_min_max_args(in_filter, min=-100, max=100, name='inq_filter')
+ out = nn_ops.depthwise_conv2d_native(inq_data,
+ inq_filter,
+ strides=strides,
+ padding=padding,
+ data_format=data_format)
+ out = tf.quantization.fake_quant_with_min_max_args(out, min=-200, max=200, name="out")
+
+ # Set the input quantization range
+ input_range = {'in_data': (-100, 100)} if quantized else None
+
+ # Compare
+ compare_tflite_with_tvm(data_array, 'in_data', [in_data], [out], quantized=quantized, input_range=input_range)
+ else:
+ # Quantized the inputs and feed them to the convolution
+ inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=-100, max=100, name='inq_data')
+ inq_filter = tf.quantization.fake_quant_with_min_max_args(in_filter, min=-100, max=100, name='inq_filter')
+ out = nn_ops.conv2d(inq_data,
+ inq_filter,
+ strides=strides,
+ padding=padding,
+ data_format=data_format)
+ out = tf.quantization.fake_quant_with_min_max_args(out, min=-200, max=200, name="out")
+
+ # Set the input quantization range
+ input_range = {'in_data': (-100, 100)} if quantized else None
+
+ # Compare
+ compare_tflite_with_tvm(data_array, 'in_data', [in_data], [out], quantized=quantized, input_range=input_range)
else:
data_array = np.reshape(data_array, tensor_in_sizes).astype('float32')
compare_tflite_with_tvm(data_array, 'in_data', [in_data], [out])
_test_convolution([4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NHWC', quantized=quantized)
_test_convolution([4, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC', quantized=quantized)
- # depthwise convolution
- _test_convolution([4, 8, 8, 176], [1, 1, 176, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True)
- _test_convolution([4, 17, 17, 19], [3, 3, 19, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True)
- _test_convolution([4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True)
- _test_convolution([4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True)
- _test_convolution([4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC', True)
- # dephtwise convolution with single input channel
- _test_convolution([1, 76, 64, 1], [9, 5, 1, 96], [1, 1], [1, 1], 'SAME', 'NHWC', True)
+ # depthwise convolution
+ _test_convolution([4, 8, 8, 176], [1, 1, 176, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True, quantized=quantized)
+ _test_convolution([4, 17, 17, 19], [3, 3, 19, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True, quantized=quantized)
+ _test_convolution([4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True, quantized=quantized)
+ _test_convolution([4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True, quantized=quantized)
+ _test_convolution([4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC', True, quantized=quantized)
+ # depthwise convolution with single input channel
+ _test_convolution([1, 76, 64, 1], [9, 5, 1, 96], [1, 1], [1, 1], 'SAME', 'NHWC', True, quantized=quantized)
+
+ # TFLite2 quantized convolution testing
+ if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'):
+ _test_tflite2_quantized_convolution([1, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC')
+ _test_tflite2_quantized_convolution([1, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC')
+ _test_tflite2_quantized_convolution([1, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC')
+ _test_tflite2_quantized_convolution([1, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NHWC')
+
+ # depthwise convolution
+ _test_tflite2_quantized_depthwise_convolution([1, 8, 8, 128], [1, 1, 128, 1], [1, 1], [1, 1],
+ 'SAME', 'NHWC', 1)
+ _test_tflite2_quantized_depthwise_convolution([1, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2],
+ 'VALID', 'NHWC', 1)
+ _test_tflite2_quantized_depthwise_convolution([1, 24, 24, 3], [7, 7, 3, 8], [1, 1], [2, 2],
+ 'SAME', 'NHWC', 8)
+
#######################################################################
def _test_quantize_dequantize(data):
""" One iteration of quantize and dequantize """
- # Define a dummy model
+ # Keras model to force TFLite converter to insert 2 TFLite quantize ops.
+ # First TFLite quantize op converts float32 tensor to int8 tensor - Qnn quantize.
+ # Second TFLite quantize op converts int8 tensor to int8 tensor - Qnn requantize.
data_in = tf.keras.layers.Input(shape=data.shape[1:])
- act_func = tf.keras.layers.Activation('linear')
- keras_model = tf.keras.models.Model(data_in, act_func(data_in))
-
- # Load the model
- converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model)
+ relu = tf.keras.layers.ReLU()(data_in)
+ add = tf.keras.layers.Add()([data_in, relu])
+ concat = tf.keras.layers.Concatenate(axis=0)([relu, add])
+ keras_model = tf.keras.models.Model(inputs=data_in, outputs=concat)
+ # Keep only the part of the Keras tensor name before the colon; that is the
+ # input name handed to run_tvm_graph below.
+ input_name = data_in.name.split(":")[0]
# To create quantized values with dynamic range of activations, needs representative dataset
def representative_data_gen():
- for i in range(100):
+ for i in range(1):
yield [data]
- converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
- converter.representative_dataset = representative_data_gen
- converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
- converter.inference_input_type = tf.uint8
- converter.inference_output_type = tf.uint8
-
- # Convert the model to TensorFlow Lite format
- tflite_model_quant = converter.convert()
+ # The same quantized model is executed by the TFLite interpreter and by TVM.
+ tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen)
tflite_output = run_tflite_graph(tflite_model_quant, data)
- tvm_output = run_tvm_graph(tflite_model_quant, data, 'input_1')
+ tvm_output = run_tvm_graph(tflite_model_quant, data, input_name)
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
- rtol=1e-5, atol=1e-5)
+ rtol=1e-5, atol=1e-2)
def test_forward_quantize_dequantize():
""" Quantize Dequantize """
data = np.random.uniform(0, 1, (1, 4, 4, 3)).astype("float32")
- if package_version.parse(tf.VERSION) >= package_version.parse('2.0.0'):
+ # On TensorFlow older than 2.1.0 nothing below runs and the test passes trivially.
+ if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'):
_test_quantize_dequantize(data)
# ReLu
# ----
-def _test_relu(data):
+def _test_relu(data, quantized=False):
""" One iteration of ReLU """
- with tf.Graph().as_default():
- in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
- out = nn_ops.relu(in_data)
- compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+
+ if quantized:
+ if package_version.parse(tf.VERSION) < package_version.parse('2.1.0'):
+ pytest.skip("Testcase requires tflite version >= 2.1.0")
+ # Single-layer Keras ReLU model; the part of its input tensor name before
+ # the colon is the input name handed to run_tvm_graph.
+ data_in = tf.keras.layers.Input(shape=data.shape[1:])
+ relu = tf.keras.layers.ReLU()(data_in)
+ keras_model = tf.keras.models.Model(inputs=data_in, outputs=relu)
+ input_name = data_in.name.split(":")[0]
+
+ # To create quantized values with dynamic range of activations, needs representative dataset
+ def representative_data_gen():
+ for i in range(1):
+ yield [data]
+
+ tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen)
+
+ # The same quantized model is executed by the TFLite interpreter and by TVM.
+ tflite_output = run_tflite_graph(tflite_model_quant, data)
+ tvm_output = run_tvm_graph(tflite_model_quant, data, input_name)
+ tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
+ rtol=1e-5, atol=1e-5)
+ else:
+ # Float path: build the op in a TF graph and compare through compare_tflite_with_tvm.
+ with tf.Graph().as_default():
+ in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+ out = nn_ops.relu(in_data)
+ compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
def test_forward_relu():
""" ReLU """
_test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
+ # Same input again through the quantized Keras path.
+ _test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6)), quantized=True)
#######################################################################
# ReLU6
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
+def test_forward_tflite2_qnn_resnet50():
+    """Test the Quantized TFLite version 2.1.0 Resnet50 model.
+
+    Compares the top-3 predicted labels from the TFLite interpreter with those
+    from TVM. Does nothing on TensorFlow older than 2.1.0.
+    """
+    if package_version.parse(tf.VERSION) < package_version.parse('2.1.0'):
+        return
+
+    model_path = download_testdata(
+        "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/Quantized/resnet_50_quantized.tflite",
+        "resnet_50_quantized.tflite")
+    with open(model_path, "rb") as model_file:
+        model_buf = model_file.read()
+
+    image = pre_processed_image(224, 224)
+
+    # Top-3 labels, highest score first, from the TFLite interpreter.
+    reference_scores = np.squeeze(run_tflite_graph(model_buf, image))
+    reference_top3 = reference_scores.argsort()[-3:][::-1]
+
+    # Top-3 labels, highest score first, from TVM.
+    tvm_scores = np.squeeze(run_tvm_graph(model_buf, np.array(image), 'input_1'))
+    tvm_top3 = tvm_scores.argsort()[-3:][::-1]
+
+    tvm.testing.assert_allclose(tvm_top3, reference_top3)
+
+
+def test_forward_tflite2_qnn_inception_v1():
+    """Test the Quantized TFLite version 2.1.0 Inception V1 model.
+
+    Compares the top-3 predicted labels from the TFLite interpreter with those
+    from TVM. Does nothing on TensorFlow older than 2.1.0.
+    """
+    if package_version.parse(tf.VERSION) < package_version.parse('2.1.0'):
+        return
+
+    model_path = download_testdata(
+        "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/Quantized/inception_v1_quantized.tflite",
+        "inception_v1_quantized.tflite")
+    with open(model_path, "rb") as model_file:
+        model_buf = model_file.read()
+
+    image = pre_processed_image(224, 224)
+
+    # Top-3 labels, highest score first, from the TFLite interpreter.
+    reference_scores = np.squeeze(run_tflite_graph(model_buf, image))
+    reference_top3 = reference_scores.argsort()[-3:][::-1]
+
+    # Top-3 labels, highest score first, from TVM.
+    tvm_scores = np.squeeze(run_tvm_graph(model_buf, np.array(image), 'input_1'))
+    tvm_top3 = tvm_scores.argsort()[-3:][::-1]
+
+    tvm.testing.assert_allclose(tvm_top3, reference_top3)
+
+
+def test_forward_tflite2_qnn_mobilenet_v2():
+    """Test the Quantized TFLite version 2.1.0 Mobilenet V2 model.
+
+    Compares the top-3 predicted labels from the TFLite interpreter with those
+    from TVM. Does nothing on TensorFlow older than 2.1.0.
+    """
+    if package_version.parse(tf.VERSION) < package_version.parse('2.1.0'):
+        return
+
+    model_path = download_testdata(
+        "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/Quantized/mobilenet_v2_quantized.tflite",
+        "mobilenet_v2_quantized.tflite")
+    with open(model_path, "rb") as model_file:
+        model_buf = model_file.read()
+
+    image = pre_processed_image(224, 224)
+
+    # Top-3 labels, highest score first, from the TFLite interpreter.
+    reference_scores = np.squeeze(run_tflite_graph(model_buf, image))
+    reference_top3 = reference_scores.argsort()[-3:][::-1]
+
+    # Top-3 labels, highest score first, from TVM.
+    tvm_scores = np.squeeze(run_tvm_graph(model_buf, np.array(image), 'input_1'))
+    tvm_top3 = tvm_scores.argsort()[-3:][::-1]
+
+    tvm.testing.assert_allclose(tvm_top3, reference_top3)
+
+
#######################################################################
# Quantized SSD Mobilenet
# -----------------------
#with Tflite 1.15.2
test_forward_qnn_mobilenet_v3_net()
test_forward_qnn_coco_ssd_mobilenet_v1()
+
+ # TFLite 2.1.0 quantized tests
+ test_forward_tflite2_qnn_resnet50()
+ test_forward_tflite2_qnn_inception_v1()
+ test_forward_tflite2_qnn_mobilenet_v2()