'TAN': self.convert_tan,
'TANH':self.convert_tanh,
'TILE': self.convert_tile,
+ 'TOPK_V2': self.convert_topk_v2,
'TRANSPOSE_CONV': self.convert_transpose_conv,
'TRANSPOSE': self.convert_transpose,
'UNPACK': self.convert_unpack,
return out
+ def convert_topk_v2(self, op):
+ """ Convert TFLite TOPK_v2 """
+ try:
+ from tflite.Operator import Operator
+ except ImportError:
+ raise ImportError("The tflite package must be installed")
+
+ assert isinstance(op, Operator)
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 2, "input tensors length should be 2"
+ input_tensor = input_tensors[0]
+ input_tensor_idx = input_tensor.tensor_idx
+ in_expr = self.get_expr(input_tensor_idx)
+ k = self.get_tensor_value(input_tensors[1])
+ out = _op.topk(in_expr, int(k))
+
+ return out
+
def convert_pool2d(self, op, pool_type):
"""pool2d implementation."""
try:
_test_slice(np.arange(5, dtype=np.int32).reshape((5, )), begin=[4], size=[-1])
#######################################################################
+# Topk
+# ----
+def _test_topk(in_shape, k=1):
+ """ One iteration of TOPK """
+ data = np.random.uniform(size=in_shape).astype('float32')
+ with tf.Graph().as_default():
+ in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+ out = nn_ops.top_k(in_data, k, name='TopK')
+ compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out[0]])
+
+def test_forward_topk():
+ """ TOPK """
+ _test_topk((3,), 1)
+ _test_topk((3,), 3)
+ _test_topk((3, 5, 7), 3)
+ _test_topk((3, 5, 7), 3)
+
+#######################################################################
# transpose
# ---------
test_all_resize()
test_forward_squeeze()
test_forward_slice()
+ test_forward_topk()
test_forward_depthtospace()
test_forward_spacetodepth()