[Frontend][Tensorflow] Gather nd bug fix for one dim support in tensorflow (#5588)
authorANSHUMAN TRIPATHY <anshuman.t@huawei.com>
Tue, 19 May 2020 19:45:02 +0000 (01:15 +0530)
committerGitHub <noreply@github.com>
Tue, 19 May 2020 19:45:02 +0000 (12:45 -0700)
* [Frontend][Tensorflow] Gather_nd one dim support added

* Test case added

* Doc error handled

* Review comment handled: reverting new attr introduced

* Check added at mxnet frontend

* Doc error handled

* TFLite test case failure resolved

python/tvm/relay/frontend/mxnet.py
python/tvm/relay/op/op_attrs.py
src/relay/op/tensor/transform.cc
tests/python/frontend/mxnet/test_forward.py
tests/python/frontend/tensorflow/test_forward.py
tests/python/frontend/tflite/test_forward.py
topi/include/topi/transform.h

index 4c3144c..e6384f7 100644 (file)
@@ -693,6 +693,10 @@ def _mx_take(inputs, attrs):
     axis = attrs.get_int("axis", 0)
     return _op.take(inputs[0], inputs[1].astype("int32"), axis, mode)
 
+def _mx_gather_nd(inputs, attrs):
+    assert len(inputs) == 2
+    assert len(_infer_shape(inputs[1])) > 1, "index tensor to have at least 2 dimensions"
+    return _op.gather_nd(inputs[0], inputs[1])
 
 def _mx_reverse(inputs, attrs):
     assert len(inputs) == 1
@@ -1770,7 +1774,6 @@ _identity_list = [
     "zeros_like",
     "ones_like",
     "where",
-    "gather_nd",
     "cos",
     "cosh",
     "sin",
@@ -1918,6 +1921,7 @@ _convert_map = {
     "pad"           : _mx_pad,
     "Pad"           : _mx_pad,
     "take"          : _mx_take,
+    "gather_nd"     : _mx_gather_nd,
     "reverse"       : _mx_reverse,
     "squeeze"       : _mx_squeeze,
     "broadcast_axis": _mx_broadcast_axis,
index a1c73ef..fee213c 100644 (file)
@@ -189,7 +189,6 @@ class TransposeAttrs(Attrs):
 class ReshapeAttrs(Attrs):
     """Attributes for transform.reshape"""
 
-
 @tvm._ffi.register_object("relay.attrs.TakeAttrs")
 class TakeAttrs(Attrs):
     """Attributes for transform.take"""
index 892f3a4..8ec094e 100644 (file)
@@ -2248,6 +2248,9 @@ bool GatherNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   Array<IndexExpr> oshape;
   for (size_t i = 1; i < kdim + 1; ++i) oshape.push_back(indices->shape[i]);
   for (size_t i = mdim->value; i < ndim; ++i) oshape.push_back(data->shape[i]);
+  if (oshape.size() == 0) {
+    oshape.push_back(tir::make_const(DataType::Int(32), 1));
+  }
   reporter->Assign(types[2], TensorType(oshape, data->dtype));
   return true;
 }
index 9dd8506..8113271 100644 (file)
@@ -608,7 +608,7 @@ def test_forward_take():
     verify((3,4), [-1, 5], 1, mode="wrap")
 
 def test_forward_gather_nd():
-    def verify(xshape, yshape, y_data):
+    def verify(xshape, yshape, y_data, error=False):
         x_data = np.random.uniform(size=xshape).astype("float32")
         ref_res = mx.nd.gather_nd(mx.nd.array(x_data), mx.nd.array(y_data))
         mx_sym = mx.sym.gather_nd(mx.sym.var("x_data"), mx.sym.var("y_data"))
@@ -618,10 +618,12 @@ def test_forward_gather_nd():
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                 op_res = intrp.evaluate()(x_data, y_data)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+
     verify((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]])
     verify((2, 2, 2), (2, 2), [[0, 1], [1, 0]])
     verify((3, 2, 2), (2, 2), [[0, 1], [1, 0]])
     verify((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]])
+    verify((1, 4), (1, 1), [[0]])
 
 def test_forward_bilinear_resize():
     # add tests including scale_height and scale_width when mxnet is updated to version 1.5
index c3313b6..c6a285c 100644 (file)
@@ -1379,7 +1379,7 @@ def test_forward_truncatemod():
 
 
 #######################################################################
-# Gather, GatherV2, GatherNd
+# Gather, GatherV2
 # --------------------------
 
 def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype):
@@ -1418,16 +1418,32 @@ def test_forward_gather():
     _test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 2, 'int32')
     _test_gather((4, 3, 5, 6), (1, 4), [[2, 1, 0, 0]], 0, 'float32')
 
+#######################################################################
+# GatherND
+# --------------------------
 
-def test_forward_gather_nd():
+def _test_gather_nd(ip_shape, indice_value, dtype):
     """test operator GatherNd"""
-    np_data = np.random.uniform(1, 100, size=(2, 2, 2)).astype(np.float32)
+    np_data = np.random.uniform(1, 100, size=ip_shape).astype(dtype)
     tf.reset_default_graph()
     with tf.Graph().as_default():
-        in_data = tf.placeholder(tf.float32, (2, 2, 2), name="in_data")
-        tf.gather_nd(in_data, indices=[[1, 0, 0], [0, 0, 0]], name="gather_nd")
+        in_data = tf.placeholder(dtype, ip_shape, name="in_data")
+        tf.gather_nd(in_data, indices=indice_value, name="gather_nd")
         compare_tf_with_tvm([np_data], ['in_data:0'], 'gather_nd:0')
 
+def test_forward_gather_nd():
+    """test operator GatherNd"""
+    _test_gather_nd((2, 2), [[0, 0], [1, 1]], 'float32')
+    _test_gather_nd((2, 2, 2), [[1, 0, 0], [0, 0, 0]], 'float32')
+    _test_gather_nd((4,), [1], 'float32')
+    _test_gather_nd((4,), [1], 'int32')
+    _test_gather_nd((1, 4), [0, 3], 'int32')
+    _test_gather_nd((2, 2), [[[1, 0], [0, 1]]], 'int32')
+    _test_gather_nd((2, 2), [[[1, 0], [0, 1]]], 'float32')
+    _test_gather_nd((3, 3, 3),  [[[1, 0]]], 'int32')
+    _test_gather_nd((3, 3, 3), [[[1, 0]]], 'int32')
+    _test_gather_nd((4, 3, 5, 6),  [[2, 1, 0, 0]], 'float32')
+    _test_gather_nd((3, 3, 3), [[[2, 1]]], 'int32')
 
 #######################################################################
 # BiasAdd
index 2319904..7a8437a 100644 (file)
@@ -383,6 +383,18 @@ def test_forward_gather_nd():
         np.reshape(np.arange(12), [2, 3, 2]).astype('int32'),
         np.asarray([[[0, 0], [0, 1]], [[1, 0], [1, 1]]]).astype('int32')
     )
+    _test_gather_nd(
+        np.reshape(np.arange(4), [4]).astype('float32'),
+        np.asarray([1]).astype('int32')
+    )
+    _test_gather_nd(
+        np.reshape(np.arange(4), [1, 4]).astype('float32'),
+        np.asarray([0]).astype('int32')
+    )
+    _test_gather_nd(
+        np.reshape(np.arange(4), [1, 4]).astype('float32'),
+        np.asarray([0, 3]).astype('int32')
+    )
 
 #######################################################################
 # StridedSlice
index e21fc2a..400cd1e 100644 (file)
@@ -996,7 +996,7 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, std::string n
                         std::string tag = kInjective) {
   size_t ndim_d = data->shape.size();
   size_t ndim_i = indices->shape.size();
-  CHECK_GT(ndim_i, 1) << "indices tensor must have at least 2 dimensions";
+  CHECK_GE(ndim_i, 1) << "indices tensor must have at least 1 dimensions";
   size_t indices_dim0 = static_cast<size_t>(GetConstInt(indices->shape[0]));
   CHECK_LE(indices_dim0, ndim_d) << "dim 0 of indices tensor must be no more "
                                  << "than dimensions of data tensor";
@@ -1027,6 +1027,9 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, std::string n
             real_indices.push_back(tvm::cast(tvm::DataType::Int(32), indices(indices_position)));
           }
         }
+        if (real_indices.size() == ndim_d) {
+          return data(real_indices);
+        }
         for (size_t i = ndim_i - 1; i < out_index.size(); ++i) {
           real_indices.push_back(out_index[i]);
         }