[FRONTEND][TENSORFLOW] Fix gather_nd indices (#5279)
authorMORITA Kazutaka <morita.kazutaka@gmail.com>
Fri, 10 Apr 2020 14:47:53 +0000 (23:47 +0900)
committerGitHub <noreply@github.com>
Fri, 10 Apr 2020 14:47:53 +0000 (07:47 -0700)
* [FRONTEND][TENSORFLOW] Fix gather_nd indices

* retrigger CI

python/tvm/relay/frontend/tensorflow.py
tests/python/frontend/tensorflow/test_forward.py

index 77dbcb5..8a72423 100644 (file)
@@ -1127,9 +1127,11 @@ def _gather():
 def _gather_nd():
     """GatherNd"""
     def _impl(inputs, attr, params, mod):
+        indices_dims = len(_infer_shape(inputs[1], mod))
+        indices = _op.transpose(inputs[1], axes=[-1] + list(range(indices_dims-1)))
         return AttrCvt(op_name="gather_nd",
                        ignores=['Tindices', 'Tparams',\
-                                'Taxis', '_class'])(inputs, attr)
+                                'Taxis', '_class'])([inputs[0], indices], attr)
     return _impl
 
 def _stridedSlice():
index 35a3466..fdb8912 100644 (file)
@@ -1365,11 +1365,11 @@ def test_forward_gather():
 
 def test_forward_gather_nd():
     """test operator GatherNd"""
-    np_data = np.random.uniform(1, 100, size=(2, 2)).astype(np.float32)
+    np_data = np.random.uniform(1, 100, size=(2, 2, 2)).astype(np.float32)
     tf.reset_default_graph()
     with tf.Graph().as_default():
-        in_data = tf.placeholder(tf.float32, (2, 2), name="in_data")
-        tf.gather_nd(in_data, indices=[[1, 0], [0, 1]], name="gather_nd")
+        in_data = tf.placeholder(tf.float32, (2, 2, 2), name="in_data")
+        tf.gather_nd(in_data, indices=[[1, 0, 0], [0, 0, 0]], name="gather_nd")
         compare_tf_with_tvm([np_data], ['in_data:0'], 'gather_nd:0')