[TFLITE]GATHER_ND (#5508)
authorDhruva Ray <dhruvaray@gmail.com>
Mon, 18 May 2020 02:48:17 +0000 (08:18 +0530)
committerGitHub <noreply@github.com>
Mon, 18 May 2020 02:48:17 +0000 (11:48 +0900)
Signed-off-by: Dhruva Ray <dhruvaray@gmail.com>
python/tvm/relay/frontend/tflite.py
tests/python/frontend/tflite/test_forward.py

index 5a645c6..cb10ce5 100644 (file)
@@ -86,6 +86,7 @@ class OperatorConverter(object):
             'FLOOR': self.convert_floor,
             'FULLY_CONNECTED': self.convert_fully_connected,
             'GATHER': self.convert_gather,
+            'GATHER_ND' : self.convert_gather_nd,
             'GREATER_EQUAL': self.convert_greater_equal,
             'GREATER': self.convert_greater,
             'HARD_SWISH': self.convert_hard_swish,
@@ -1113,6 +1114,31 @@ class OperatorConverter(object):
         out = _op.take(data, indices, axis=axis, mode="fast")
         return out
 
+    def convert_gather_nd(self, op):
+        """Method to Convert TFLite GATHER_ND operator"""
+        try:
+            from tflite.TensorType import TensorType
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "input tensors length should be 2"
+
+        for t in input_tensors:
+            assert not t.qnn_params, "Quantized input is not expected."
+
+        data = self.get_tensor_expr(input_tensors[0])
+        indices = self.get_tensor_expr(input_tensors[1])
+
+        indices_type = input_tensors[1].tensor.Type()
+        assert indices_type in (TensorType.INT32, TensorType.INT64)
+
+        indices_dims = len(_infer_shape(indices))
+        indices_t = _op.transpose(indices, axes=[-1] + list(range(indices_dims-1)))
+
+        out = _op.gather_nd(data, indices_t)
+        return out
+
     def convert_strided_slice(self, op):
         """Method to Convert TFLite STRIDED_SLICE operator.
            NOTE: Eventhough tensorflow supports begin_mask, end_mask, ellipsis_mask, new_axis_mask
index 9963479..2319904 100644 (file)
@@ -355,6 +355,36 @@ def test_forward_gather():
         _test_gather((1, 3, 3), [20, 20], 2, 'float32', quantized, oob=True)
 
 #######################################################################
+# Gather_ND
+# ---------
+
+def _test_gather_nd(data, indices):
+    """ One iteration of GATHER_ND """
+    with tf.Graph().as_default():
+        in_data = tf.placeholder(shape=data.shape, dtype=data.dtype, name="data")
+        indices_data = tf.placeholder(shape=indices.shape, dtype=indices.dtype,
+                                        name="indices")
+        out = tf.gather_nd(in_data, indices_data)
+
+        compare_tflite_with_tvm([data, indices], ['data:0', 'indices:0'],
+                                  [in_data, indices_data], [out])
+
+def test_forward_gather_nd():
+    """ GATHER_ND """
+    _test_gather_nd(
+        np.array([[[1.2, 2.0], [3.1, 4.1]], [[5.1, 6.1], [7.1, 8.1]]]).astype('float32'),
+        np.asarray([[0, 1], [1, 0]]).astype('int32')
+    )
+    _test_gather_nd(
+        np.reshape(np.arange(30), [5, 6]).astype('int32'),
+        np.asarray([[1, 2]]).astype('int32')
+    )
+    _test_gather_nd(
+        np.reshape(np.arange(12), [2, 3, 2]).astype('int32'),
+        np.asarray([[[0, 0], [0, 1]], [[1, 0], [1, 1]]]).astype('int32')
+    )
+
+#######################################################################
 # StridedSlice
 # ------------
 
@@ -2217,6 +2247,7 @@ if __name__ == '__main__':
     test_forward_slice()
     test_forward_topk()
     test_forward_gather()
+    test_forward_gather_nd()
     test_forward_stridedslice()
     test_forward_depthtospace()
     test_forward_spacetodepth()