[TFLITE]TOP_K op parser support (#5051)
authorSamuel <siju.samuel@huawei.com>
Mon, 30 Mar 2020 20:51:01 +0000 (02:21 +0530)
committerGitHub <noreply@github.com>
Mon, 30 Mar 2020 20:51:01 +0000 (13:51 -0700)
* [TFLITE]TOP_K op parser support

* Testcase updated

python/tvm/relay/frontend/tflite.py
tests/python/frontend/tflite/test_forward.py

index aa51570..7f7ae30 100644 (file)
@@ -129,6 +129,7 @@ class OperatorConverter(object):
             'TAN': self.convert_tan,
             'TANH':self.convert_tanh,
             'TILE': self.convert_tile,
+            'TOPK_V2': self.convert_topk_v2,
             'TRANSPOSE_CONV': self.convert_transpose_conv,
             'TRANSPOSE': self.convert_transpose,
             'UNPACK': self.convert_unpack,
@@ -1550,6 +1551,24 @@ class OperatorConverter(object):
 
         return out
 
+    def convert_topk_v2(self, op):
+        """ Convert TFLite TOPK_v2 """
+        try:
+            from tflite.Operator import Operator
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        assert isinstance(op, Operator)
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "input tensors length should be 2"
+        input_tensor = input_tensors[0]
+        input_tensor_idx = input_tensor.tensor_idx
+        in_expr = self.get_expr(input_tensor_idx)
+        k = self.get_tensor_value(input_tensors[1])
+        out = _op.topk(in_expr, int(k))
+
+        return out
+
     def convert_pool2d(self, op, pool_type):
         """pool2d implementation."""
         try:
index 5f0c444..42726b7 100644 (file)
@@ -273,6 +273,24 @@ def test_forward_slice():
         _test_slice(np.arange(5, dtype=np.int32).reshape((5, )), begin=[4], size=[-1])
 
 #######################################################################
+# Topk
+# ----
+def _test_topk(in_shape, k=1):
+    """ One iteration of TOPK """
+    data = np.random.uniform(size=in_shape).astype('float32')
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+        out = nn_ops.top_k(in_data, k, name='TopK')
+        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out[0]])
+
+def test_forward_topk():
+    """ TOPK """
+    _test_topk((3,), 1)
+    _test_topk((3,), 3)
+    _test_topk((3, 5, 7), 3)
+    _test_topk((3, 5, 7), 3)
+
+#######################################################################
 # transpose
 # ---------
 
@@ -1775,6 +1793,7 @@ if __name__ == '__main__':
     test_all_resize()
     test_forward_squeeze()
     test_forward_slice()
+    test_forward_topk()
     test_forward_depthtospace()
     test_forward_spacetodepth()