[TFLITE]Hard Swish & MobilnetV3 model testing (#5239)
authorSamuel <siju.samuel@huawei.com>
Tue, 7 Apr 2020 08:14:02 +0000 (13:44 +0530)
committerGitHub <noreply@github.com>
Tue, 7 Apr 2020 08:14:02 +0000 (16:14 +0800)
* [TFLITE]Hard Swish & MobilnetV3 model testing

* CI Failure addressed

python/tvm/relay/frontend/tflite.py
tests/python/frontend/tflite/test_forward.py

index 7f7ae30..caf8f92 100644 (file)
@@ -84,6 +84,7 @@ class OperatorConverter(object):
             'FULLY_CONNECTED': self.convert_fully_connected,
             'GREATER_EQUAL': self.convert_greater_equal,
             'GREATER': self.convert_greater,
+            'HARD_SWISH': self.convert_hard_swish,
             'L2_NORMALIZATION': self.convert_l2_normalization,
             'LESS_EQUAL': self.convert_less_equal,
             'LESS': self.convert_less,
@@ -595,6 +596,42 @@ class OperatorConverter(object):
 
         return out
 
+    def convert_hard_swish(self, op):
+        """Convert TFLite Hard swish"""
+        try:
+            from tflite.Operator import Operator
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+        assert isinstance(op, Operator)
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 1, "input tensors length should be 1"
+        input_tensor = input_tensors[0]
+        in_expr = self.get_expr(input_tensor.tensor_idx)
+
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 1, "output tensors length should be 1"
+        output_tensor = output_tensors[0]
+
+        def _relu6(data):
+            return _op.tensor.clip(data, 0.0, 6.0)
+
+        def _hard_swish(data):
+            return data * _relu6(data + relay.const(3.0)) / relay.const(6.0)
+
+        # Dequantize if the input is quantized.
+        if input_tensor.qnn_params:
+            in_expr = self.dequantize(in_expr, input_tensor)
+
+        # Perform hardswish
+        out = _hard_swish(in_expr)
+
+        # Go back to integer dataype if the original operator was quantized.
+        if output_tensor.qnn_params:
+            out = self.quantize(out, output_tensor)
+
+        return out
+
     def convert_concatenation(self, op):
         """Convert TFLite concatenation"""
         try:
index 42726b7..831b021 100644 (file)
@@ -1626,6 +1626,26 @@ def test_forward_mobilenet_v2():
                                 rtol=1e-5, atol=1e-5)
 
 #######################################################################
+# Mobilenet V3
+# ------------
+
+def test_forward_mobilenet_v3():
+    """Test the Mobilenet V3 TF Lite model."""
+    # In MobilenetV3, some ops are not supported before tf 1.15 fbs schema
+    if package_version.parse(tf.VERSION) < package_version.parse('1.15.0'):
+        return
+    tflite_model_file = tf_testing.get_workload_official(
+        "https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-large_224_1.0_float.tgz",
+        "v3-large_224_1.0_float/v3-large_224_1.0_float.tflite")
+    with open(tflite_model_file, "rb") as f:
+        tflite_model_buf = f.read()
+    data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
+    tflite_output = run_tflite_graph(tflite_model_buf, data)
+    tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
+    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
+                                rtol=1e-5, atol=1e-5)
+
+#######################################################################
 # Inception
 # ---------
 
@@ -1724,6 +1744,35 @@ def test_forward_qnn_mobilenet_v2_net():
     tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
 
 #######################################################################
+# Mobilenet V3 Quantized
+# ----------------------
+
+def test_forward_qnn_mobilenet_v3_net():
+    """Test the Quantized TFLite Mobilenet V3 model."""
+    # In MobilenetV3, some ops are not supported before tf 1.15 fbs schema
+    if package_version.parse(tf.VERSION) < package_version.parse('1.15.0'):
+        return
+
+    tflite_model_file = tf_testing.get_workload_official(
+        "https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-large_224_1.0_uint8.tgz",
+        "v3-large_224_1.0_uint8/v3-large_224_1.0_uint8.tflite")
+    with open(tflite_model_file, "rb") as f:
+        tflite_model_buf = f.read()
+
+    # Test image. Checking the labels because the requantize implementation is different between
+    # TFLite and Relay. This cause final output numbers to mismatch. So, testing accuracy via
+    # labels. Also, giving a real image, instead of random inputs.
+    data = get_real_image(224, 224)
+
+    tflite_output = run_tflite_graph(tflite_model_buf, data)
+    tflite_predictions = np.squeeze(tflite_output)
+    tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
+    tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
+    tvm_predictions = np.squeeze(tvm_output)
+    tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
+    tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
+
+#######################################################################
 # SSD Mobilenet
 # -------------
 
@@ -1831,6 +1880,7 @@ if __name__ == '__main__':
     # End to End
     test_forward_mobilenet_v1()
     test_forward_mobilenet_v2()
+    test_forward_mobilenet_v3()
     test_forward_inception_v3_net()
     test_forward_inception_v4_net()
     test_forward_ssd_mobilenet_v1()
@@ -1840,3 +1890,4 @@ if __name__ == '__main__':
     test_forward_qnn_inception_v1_net()
     test_forward_qnn_mobilenet_v1_net()
     test_forward_qnn_mobilenet_v2_net()
+    test_forward_qnn_mobilenet_v3_net()