[topi][relay] Add operation gather to relay. (#5716)
authornotoraptor <notoraptor@users.noreply.github.com>
Fri, 12 Jun 2020 15:24:56 +0000 (11:24 -0400)
committerGitHub <noreply@github.com>
Fri, 12 Jun 2020 15:24:56 +0000 (08:24 -0700)
14 files changed:
docs/api/python/topi.rst
docs/langref/relay_op.rst
include/tvm/relay/attrs/transform.h
python/tvm/relay/op/_transform.py
python/tvm/relay/op/op_attrs.py
python/tvm/relay/op/transform.py
src/relay/op/tensor/transform.cc
tests/python/relay/test_op_level3.py
topi/include/topi/transform.h
topi/python/topi/testing/__init__.py
topi/python/topi/testing/gather_python.py [new file with mode: 0644]
topi/python/topi/transform.py
topi/src/transform.cc
topi/tests/python/test_topi_transform.py

index 960d946..65f2375 100644 (file)
@@ -55,6 +55,7 @@ List of operators
    topi.concatenate
    topi.split
    topi.take
+   topi.gather
    topi.gather_nd
    topi.full
    topi.full_like
@@ -160,6 +161,7 @@ topi
 .. autofunction:: topi.concatenate
 .. autofunction:: topi.split
 .. autofunction:: topi.take
+.. autofunction:: topi.gather
 .. autofunction:: topi.gather_nd
 .. autofunction:: topi.full
 .. autofunction:: topi.full_like
index b3fdf1c..cef96ef 100644 (file)
@@ -120,6 +120,7 @@ This level enables additional math and transform operators.
    tvm.relay.zeros_like
    tvm.relay.ones
    tvm.relay.ones_like
+   tvm.relay.gather
    tvm.relay.gather_nd
    tvm.relay.full
    tvm.relay.full_like
index b0d7de5..cbc6034 100644 (file)
@@ -101,6 +101,16 @@ struct ScatterAttrs : public tvm::AttrsNode<ScatterAttrs> {
   }
 };
 
+struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {
+  Integer axis;
+
+  TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherAttrs") {
+    TVM_ATTR_FIELD(axis)
+        .set_default(NullValue<Integer>())
+        .describe("The axis over which to select values.");
+  }
+};
+
 struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
   Integer axis;
   std::string mode;
index b1cfe50..f134b82 100644 (file)
@@ -51,6 +51,7 @@ _reg.register_injective_schedule("take")
 _reg.register_injective_schedule("transpose")
 _reg.register_injective_schedule("stack")
 _reg.register_injective_schedule("_contrib_reverse_reshape")
+_reg.register_injective_schedule("gather")
 _reg.register_injective_schedule("gather_nd")
 _reg.register_injective_schedule("sequence_mask")
 _reg.register_injective_schedule("one_hot")
index 8a7ab48..429c4f1 100644 (file)
@@ -189,6 +189,10 @@ class TransposeAttrs(Attrs):
 class ReshapeAttrs(Attrs):
     """Attributes for transform.reshape"""
 
+@tvm._ffi.register_object("relay.attrs.GatherAttrs")
+class GatherAttrs(Attrs):
+    """Attributes for transform.gather"""
+
 @tvm._ffi.register_object("relay.attrs.TakeAttrs")
 class TakeAttrs(Attrs):
     """Attributes for transform.take"""
index 0458b9a..05958fc 100644 (file)
@@ -800,6 +800,43 @@ def reverse_reshape(data, newshape):
     return _make._contrib_reverse_reshape(data, list(newshape))
 
 
+def gather(data, axis, indices):
+    """Gather values along given axis from given indices.
+
+    E.g. for a 3D tensor, output is computed as:
+
+    .. code-block:: python
+
+        out[i][j][k] = data[indices[i][j][k]][j][k]  # if axis == 0
+        out[i][j][k] = data[i][indices[i][j][k]][k]  # if axis == 1
+        out[i][j][k] = data[i][j][indices[i][j][k]]  # if axis == 2
+
+    ``indices`` must have same shape as ``data``, except at dimension ``axis``
+    which must just be not null. Output will have same shape as ``indices``.
+
+    Parameters
+    ----------
+    data: relay.Expr
+        The input data to the operator.
+
+    axis: int
+        The axis along which to index.
+
+    indices: relay.Expr
+        The indices of values to gather.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        data = [[1, 2], [3, 4]]
+        axis = 1
+        indices = [[0, 0], [1, 0]]
+        relay.gather(data, axis, indices) = [[1, 1], [4, 3]]
+    """
+    return _make.gather(data, axis, indices)
+
+
 def gather_nd(data, indices):
     """Gather elements or slices from data and store to a tensor whose shape is
     defined by indices.
index 222a38d..2a7e4e2 100644 (file)
@@ -2397,6 +2397,88 @@ example below::
     .set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
     .set_attr<TOpPattern>("TOpPattern", kInjective);
 
+// gather operator
+TVM_REGISTER_NODE_TYPE(GatherAttrs);
+
+bool GatherRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+               const TypeReporter& reporter) {
+  // `types` contains: [data, indices, result]
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* indices = types[1].as<TensorTypeNode>();
+  if (data == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "Gather: expect input data type to be TensorType but get " << types[0];
+    return false;
+  }
+  if (indices == nullptr) {
+    CHECK(types[1].as<IncompleteTypeNode>())
+        << "Gather: expect indices type to be TensorType but get " << types[1];
+    return false;
+  }
+  CHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer";
+  const auto param = attrs.as<GatherAttrs>();
+  CHECK(param != nullptr);
+  CHECK(param->axis.defined());
+
+  const auto ndim_data = data->shape.size();
+  const auto ndim_indices = indices->shape.size();
+  int axis = param->axis->value;
+  CHECK_EQ(ndim_data, ndim_indices);
+  CHECK_GE(axis, 0);
+  CHECK_LT(axis, ndim_data);
+
+  std::vector<IndexExpr> oshape;
+  oshape.reserve(ndim_data);
+  for (size_t i = 0; i < ndim_data; ++i) {
+    if (i == (size_t)axis) {
+      const int64_t* indice_shape_i = tir::as_const_int(indices->shape[i]);
+      CHECK_GE(*indice_shape_i, 1);
+    } else {
+      CHECK(reporter->AssertEQ(indices->shape[i], data->shape[i]));
+    }
+    oshape.emplace_back(indices->shape[i]);
+  }
+  reporter->Assign(types[2], TensorType(oshape, data->dtype));
+  return true;
+}
+
+Array<te::Tensor> GatherCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                const Type& out_type) {
+  const auto* param = attrs.as<GatherAttrs>();
+  return {topi::gather(inputs[0], param->axis, inputs[1])};
+}
+
+Expr MakeGather(Expr data, Integer axis, Expr indices) {
+  auto attrs = make_object<GatherAttrs>();
+  attrs->axis = std::move(axis);
+  static const Op& op = Op::Get("gather");
+  return Call(op, {data, indices}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.gather").set_body_typed(MakeGather);
+
+RELAY_REGISTER_OP("gather")
+    .describe(R"code(Gather values along given axis from given indices.
+
+E.g. for a 3D tensor, output is computed as:
+
+       out[i][j][k] = data[indices[i][j][k]][j][k]  # if axis == 0
+       out[i][j][k] = data[i][indices[i][j][k]][k]  # if axis == 1
+       out[i][j][k] = data[i][j][indices[i][j][k]]  # if axis == 2
+
+``indices`` must have same shape as ``data``, except at dimension ``axis``
+which must just be not null. Output will have same shape as ``indices``.
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<GatherAttrs>()
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input data to the operator.")
+    .add_argument("indices", "Tensor", "The indices of values to gather.")
+    .set_support_level(3)
+    .add_type_rel("Gather", GatherRel)
+    .set_attr<FTVMCompute>("FTVMCompute", GatherCompute)
+    .set_attr<TOpPattern>("TOpPattern", kInjective);
+
 // gather_nd operator
 bool GatherNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                  const TypeReporter& reporter) {
index d778312..f50a692 100644 (file)
@@ -711,6 +711,58 @@ def test_scatter():
     verify_scatter((16, 16, 4, 5), (16, 16, 4, 5), 3)
 
 
+def test_gather():
+    def verify_gather(data, axis, indices, ref_res):
+        data = np.asarray(data, dtype='float32')
+        indices = np.asarray(indices, dtype='int32')
+        ref_res = np.asarray(ref_res)
+
+        d = relay.var("x", relay.TensorType(data.shape, "float32"))
+        i = relay.var("y", relay.TensorType(indices.shape, "int32"))
+        z = relay.gather(d, axis, i)
+
+        func = relay.Function([d, i], z)
+
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(func)(data, indices)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res,
+                                            rtol=1e-5)
+
+    verify_gather([[1, 2], [3, 4]],
+                  1,
+                  [[0, 0], [1, 0]],
+                  [[1, 1], [4, 3]])
+    verify_gather([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]],
+                  0,
+                  [[[1, 0, 1], [1, 1, 0]]],
+                  [[[6, 1, 8], [9, 10, 5]]])
+    verify_gather([[[-0.2321, -0.2024, -1.7624], [-0.3829, -0.4246, 0.2448],
+                    [0.1822, 0.2360, -0.8965], [0.4497, -0.2224, 0.6103]],
+                   [[0.0408, -0.7667, -0.4303], [-0.3216, 0.7489, -0.1502],
+                    [0.0144, -0.4699, -0.0064], [-0.0768, -1.6064, 1.3390]]],
+                  1,
+                  [[[2, 2, 0], [1, 0, 3]], [[3, 2, 0], [1, 0, 0]]],
+                  [[[0.1822, 0.2360, -1.7624], [-0.3829, -0.2024, 0.6103]],
+                   [[-0.0768, -0.4699, -0.4303], [-0.3216, -0.7667, -0.4303]]])
+    verify_gather([[[0.3050, 1.6986, 1.1034], [0.7020, -0.6960, -2.1818],
+                    [0.3116, -0.5773, -0.9912], [0.0835, -1.3915, -1.0720]],
+                   [[0.1694, -0.6091, -0.6539], [-0.5234, -0.1218, 0.5084],
+                    [0.2374, -1.9537, -2.0078], [-0.5700, -1.0302, 0.1558]]],
+                  2,
+                  [[[1, 1, 0, 1], [0, 0, 2, 2], [1, 2, 1, 2], [2, 2, 1, 0]],
+                   [[0, 0, 1, 2], [2, 2, 1, 0], [1, 2, 0, 0], [0, 2, 0, 2]]],
+                  [[[1.6986, 1.6986, 0.3050, 1.6986],
+                    [0.7020, 0.7020, -2.1818, -2.1818],
+                    [-0.5773, -0.9912, -0.5773, -0.9912],
+                    [-1.0720, -1.0720, -1.3915, 0.0835]],
+                   [[0.1694, 0.1694, -0.6091, -0.6539],
+                    [0.5084, 0.5084, -0.1218, -0.5234],
+                    [-1.9537, -2.0078, 0.2374, 0.2374],
+                    [-0.5700, 0.1558, -0.5700, 0.1558]]])
+
+
 def test_gather_nd():
     def verify_gather_nd(xshape, yshape, y_data):
         x = relay.var("x", relay.TensorType(xshape, "float32"))
index e830e09..7947967 100644 (file)
@@ -989,6 +989,54 @@ inline Tensor tile(const Tensor& x, Array<Integer> reps, std::string name = "T_t
 }
 
 /*!
+ * \brief Gather values along given axis from given indices.
+ *
+ * \param data The input data to the operator.
+ * \param axis The axis along which to index.
+ * \param indices The indices of values to gather.
+ * \param name The name of the operation.
+ * \param tag The tag to mark the operation.
+ *
+ * \return A Tensor whose op member is the gather operation
+ */
+inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
+                     std::string name = "T_gather", std::string tag = kInjective) {
+  size_t ndim_d = data->shape.size();
+  size_t ndim_i = indices->shape.size();
+  CHECK_GE(ndim_d, 1) << "Cannot gather from a scalar.";
+  CHECK_EQ(ndim_d, ndim_i);
+  CHECK_GE(axis, 0);
+  CHECK_LT(axis, ndim_d);
+  size_t indices_dim_i = static_cast<size_t>(GetConstInt(indices->shape[axis]));
+  CHECK_GE(indices_dim_i, 1);
+  CHECK(indices->dtype.is_int());
+
+  Array<PrimExpr> out_shape;
+  for (size_t i = 0; i < ndim_i; ++i) {
+    out_shape.push_back(indices->shape[i]);
+  }
+
+  return compute(
+      out_shape,
+      [&](const Array<Var>& out_index) {
+        Array<PrimExpr> indices_position;
+        for (size_t i = 0; i < ndim_i; ++i) {
+          indices_position.push_back(out_index[i]);
+        }
+        Array<PrimExpr> real_indices;
+        for (size_t i = 0; i < ndim_i; ++i) {
+          if (i == (size_t)axis) {
+            real_indices.push_back(indices(indices_position));
+          } else {
+            real_indices.push_back(indices_position[i]);
+          }
+        }
+        return data(real_indices);
+      },
+      name, tag);
+}
+
+/*!
  * \brief Gather elements from a n-dimension array.
  *
  * \param data The source array.
index bd9825a..70ee8e9 100644 (file)
@@ -43,6 +43,7 @@ from .roi_align_python import roi_align_nchw_python
 from .roi_pool_python import roi_pool_nchw_python
 from .lrn_python import lrn_python
 from .l2_normalize_python import l2_normalize_python
+from .gather_python import gather_python
 from .gather_nd_python import gather_nd_python
 from .strided_slice_python import strided_slice_python, strided_set_python
 from .batch_matmul import batch_matmul
diff --git a/topi/python/topi/testing/gather_python.py b/topi/python/topi/testing/gather_python.py
new file mode 100644 (file)
index 0000000..0f3573c
--- /dev/null
@@ -0,0 +1,46 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
+"""gather in python"""
+import numpy as np
+
+def gather_python(data, axis, indices):
+    """ Python version of Gather operator
+
+    Parameters
+    ----------
+    data : numpy.ndarray
+        Numpy array
+
+    axis: int
+        integer
+
+    indices : numpy.ndarray
+        Numpy array
+
+    Returns
+    -------
+    b_np : numpy.ndarray
+        Numpy array
+    """
+    shape_indices = indices.shape
+    out = np.zeros(shape_indices, dtype=data.dtype)
+    for index in np.ndindex(*shape_indices):
+        new_index = list(index)
+        new_index[axis] = indices[index]
+        out[index] = data[tuple(new_index)]
+    return out
index 5a0bf11..f1bcccd 100644 (file)
@@ -374,6 +374,38 @@ def take(a, indices, axis=None, mode="clip"):
     return cpp.take(a, indices, int(axis), mode)
 
 
+def gather(data, axis, indices):
+    """Gather values along given axis from given indices.
+
+    E.g. for a 3D tensor, output is computed as:
+
+    .. code-block:: python
+
+        out[i][j][k] = data[indices[i][j][k]][j][k]  # if axis == 0
+        out[i][j][k] = data[i][indices[i][j][k]][k]  # if axis == 1
+        out[i][j][k] = data[i][j][indices[i][j][k]]  # if axis == 2
+
+    ``indices`` must have same shape as ``data``, except at dimension ``axis``
+    which must just be not null. Output will have same shape as ``indices``.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The input data to the operator.
+
+    axis: int
+        The axis along which to index.
+
+    indices : tvm.te.Tensor
+        The indices of the values to extract.
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+    """
+    return cpp.gather(data, axis, indices)
+
+
 def gather_nd(a, indices):
     """Gather elements from a n-dimension array..
 
index 5300973..2791ff7 100644 (file)
@@ -112,6 +112,10 @@ TVM_REGISTER_GLOBAL("topi.tile").set_body([](TVMArgs args, TVMRetValue* rv) {
   *rv = tile(args[0], args[1]);
 });
 
+TVM_REGISTER_GLOBAL("topi.gather").set_body([](TVMArgs args, TVMRetValue* rv) {
+  *rv = gather(args[0], args[1], args[2]);
+});
+
 TVM_REGISTER_GLOBAL("topi.gather_nd").set_body([](TVMArgs args, TVMRetValue* rv) {
   *rv = gather_nd(args[0], args[1]);
 });
index 47ea8d7..96df101 100644 (file)
@@ -402,6 +402,35 @@ def verify_strided_set(in_shape, v_shape, begin, end, strides=None):
     for device in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
         check_device(device)
 
+def verify_gather(data, axis, indices):
+    data = np.asarray(data)
+    indices = np.asarray(indices)
+
+    var_data = te.placeholder(shape=data.shape, dtype=data.dtype.name, name="data")
+    var_indices = te.placeholder(shape=indices.shape, dtype=indices.dtype.name, name="indices")
+    out_tensor = topi.gather(var_data, axis, var_indices)
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            s = topi.testing.get_injective_schedule(device)(out_tensor)
+
+        func = tvm.build(s, [var_data, var_indices, out_tensor] , device, name="gather")
+        out_npys = topi.testing.gather_python(data, axis, indices)
+
+        data_nd = tvm.nd.array(data, ctx)
+        indices_nd = tvm.nd.array(indices, ctx)
+        out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=data.dtype.name)
+        func(data_nd, indices_nd, out_nd)
+        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npys)
+
+    for device in get_all_backend():
+        check_device(device)
+
 def verify_gather_nd(src_shape, indices_src, indices_dtype):
     src_dtype = "float32"
     indices_src = np.array(indices_src, dtype=indices_dtype)
@@ -773,6 +802,15 @@ def test_take():
     verify_take((3,4), [0, 2], axis=0, mode="fast")
     verify_take((3,4), [0, 2], axis=1, mode="fast")
 
+def test_gather():
+    verify_gather([[1, 2], [3, 4]], 1, [[0, 0], [1, 0]])
+    verify_gather(np.random.randn(4, 7, 5), 0, np.random.randint(low=0, high=4, size=(1, 7, 5)))
+    verify_gather(np.random.randn(4, 7, 5), 0, np.random.randint(low=0, high=4, size=(4, 7, 5)))
+    verify_gather(np.random.randn(4, 7, 5), 1, np.random.randint(low=0, high=7, size=(4, 10, 5)))
+    verify_gather(np.random.randn(4, 7, 5), 1, np.random.randint(low=0, high=7, size=(4, 10, 5)))
+    verify_gather(np.random.randn(4, 7, 5), 2, np.random.randint(low=0, high=5, size=(4, 7, 2)))
+    verify_gather(np.random.randn(4, 7, 5), 2, np.random.randint(low=0, high=5, size=(4, 7, 10)))
+
 def test_gather_nd():
     for indices_dtype in ['int32', 'float32']:
         verify_gather_nd((4,), [[1.8]], indices_dtype)