'FULLY_CONNECTED': self.convert_fully_connected,
'GREATER_EQUAL': self.convert_greater_equal,
'GREATER': self.convert_greater,
+ 'HARD_SWISH': self.convert_hard_swish,
'L2_NORMALIZATION': self.convert_l2_normalization,
'LESS_EQUAL': self.convert_less_equal,
'LESS': self.convert_less,
return out
+ def convert_hard_swish(self, op):
+ """Convert TFLite Hard swish"""
+ try:
+ from tflite.Operator import Operator
+ except ImportError:
+ raise ImportError("The tflite package must be installed")
+ assert isinstance(op, Operator)
+
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 1, "input tensors length should be 1"
+ input_tensor = input_tensors[0]
+ in_expr = self.get_expr(input_tensor.tensor_idx)
+
+ output_tensors = self.get_output_tensors(op)
+ assert len(output_tensors) == 1, "output tensors length should be 1"
+ output_tensor = output_tensors[0]
+
+ def _relu6(data):
+ return _op.tensor.clip(data, 0.0, 6.0)
+
+ def _hard_swish(data):
+ return data * _relu6(data + relay.const(3.0)) / relay.const(6.0)
+
+ # Dequantize if the input is quantized.
+ if input_tensor.qnn_params:
+ in_expr = self.dequantize(in_expr, input_tensor)
+
+ # Perform hardswish
+ out = _hard_swish(in_expr)
+
+ # Go back to integer dataype if the original operator was quantized.
+ if output_tensor.qnn_params:
+ out = self.quantize(out, output_tensor)
+
+ return out
+
def convert_concatenation(self, op):
"""Convert TFLite concatenation"""
try:
rtol=1e-5, atol=1e-5)
#######################################################################
+# Mobilenet V3
+# ------------
+
+def test_forward_mobilenet_v3():
+ """Test the Mobilenet V3 TF Lite model."""
+ # In MobilenetV3, some ops are not supported before tf 1.15 fbs schema
+ if package_version.parse(tf.VERSION) < package_version.parse('1.15.0'):
+ return
+ tflite_model_file = tf_testing.get_workload_official(
+ "https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-large_224_1.0_float.tgz",
+ "v3-large_224_1.0_float/v3-large_224_1.0_float.tflite")
+ with open(tflite_model_file, "rb") as f:
+ tflite_model_buf = f.read()
+ data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
+ tflite_output = run_tflite_graph(tflite_model_buf, data)
+ tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
+ tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
+ rtol=1e-5, atol=1e-5)
+
+#######################################################################
# Inception
# ---------
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
#######################################################################
+# Mobilenet V3 Quantized
+# ----------------------
+
+def test_forward_qnn_mobilenet_v3_net():
+ """Test the Quantized TFLite Mobilenet V3 model."""
+ # In MobilenetV3, some ops are not supported before tf 1.15 fbs schema
+ if package_version.parse(tf.VERSION) < package_version.parse('1.15.0'):
+ return
+
+ tflite_model_file = tf_testing.get_workload_official(
+ "https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-large_224_1.0_uint8.tgz",
+ "v3-large_224_1.0_uint8/v3-large_224_1.0_uint8.tflite")
+ with open(tflite_model_file, "rb") as f:
+ tflite_model_buf = f.read()
+
+ # Test image. Checking the labels because the requantize implementation is different between
+ # TFLite and Relay. This cause final output numbers to mismatch. So, testing accuracy via
+ # labels. Also, giving a real image, instead of random inputs.
+ data = get_real_image(224, 224)
+
+ tflite_output = run_tflite_graph(tflite_model_buf, data)
+ tflite_predictions = np.squeeze(tflite_output)
+ tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
+ tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
+ tvm_predictions = np.squeeze(tvm_output)
+ tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
+ tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
+
+#######################################################################
# SSD Mobilenet
# -------------
# End to End
test_forward_mobilenet_v1()
test_forward_mobilenet_v2()
+ test_forward_mobilenet_v3()
test_forward_inception_v3_net()
test_forward_inception_v4_net()
test_forward_ssd_mobilenet_v1()
test_forward_qnn_inception_v1_net()
test_forward_qnn_mobilenet_v1_net()
test_forward_qnn_mobilenet_v2_net()
+ test_forward_qnn_mobilenet_v3_net()