'FLOOR': self.convert_floor,
'FULLY_CONNECTED': self.convert_fully_connected,
'GATHER': self.convert_gather,
+ 'GATHER_ND' : self.convert_gather_nd,
'GREATER_EQUAL': self.convert_greater_equal,
'GREATER': self.convert_greater,
'HARD_SWISH': self.convert_hard_swish,
out = _op.take(data, indices, axis=axis, mode="fast")
return out
+ def convert_gather_nd(self, op):
+ """Method to Convert TFLite GATHER_ND operator"""
+ try:
+ from tflite.TensorType import TensorType
+ except ImportError:
+ raise ImportError("The tflite package must be installed")
+
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 2, "input tensors length should be 2"
+
+ for t in input_tensors:
+ assert not t.qnn_params, "Quantized input is not expected."
+
+ data = self.get_tensor_expr(input_tensors[0])
+ indices = self.get_tensor_expr(input_tensors[1])
+
+ indices_type = input_tensors[1].tensor.Type()
+ assert indices_type in (TensorType.INT32, TensorType.INT64)
+
+ indices_dims = len(_infer_shape(indices))
+ indices_t = _op.transpose(indices, axes=[-1] + list(range(indices_dims-1)))
+
+ out = _op.gather_nd(data, indices_t)
+ return out
+
def convert_strided_slice(self, op):
"""Method to Convert TFLite STRIDED_SLICE operator.
NOTE: Eventhough tensorflow supports begin_mask, end_mask, ellipsis_mask, new_axis_mask
_test_gather((1, 3, 3), [20, 20], 2, 'float32', quantized, oob=True)
#######################################################################
+# Gather_ND
+# ---------
+
+def _test_gather_nd(data, indices):
+ """ One iteration of GATHER_ND """
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(shape=data.shape, dtype=data.dtype, name="data")
+ indices_data = tf.placeholder(shape=indices.shape, dtype=indices.dtype,
+ name="indices")
+ out = tf.gather_nd(in_data, indices_data)
+
+ compare_tflite_with_tvm([data, indices], ['data:0', 'indices:0'],
+ [in_data, indices_data], [out])
+
+def test_forward_gather_nd():
+ """ GATHER_ND """
+ _test_gather_nd(
+ np.array([[[1.2, 2.0], [3.1, 4.1]], [[5.1, 6.1], [7.1, 8.1]]]).astype('float32'),
+ np.asarray([[0, 1], [1, 0]]).astype('int32')
+ )
+ _test_gather_nd(
+ np.reshape(np.arange(30), [5, 6]).astype('int32'),
+ np.asarray([[1, 2]]).astype('int32')
+ )
+ _test_gather_nd(
+ np.reshape(np.arange(12), [2, 3, 2]).astype('int32'),
+ np.asarray([[[0, 0], [0, 1]], [[1, 0], [1, 1]]]).astype('int32')
+ )
+
+#######################################################################
# StridedSlice
# ------------
test_forward_slice()
test_forward_topk()
test_forward_gather()
+ test_forward_gather_nd()
test_forward_stridedslice()
test_forward_depthtospace()
test_forward_spacetodepth()