topi.concatenate
topi.split
topi.take
+ topi.gather
topi.gather_nd
topi.full
topi.full_like
.. autofunction:: topi.concatenate
.. autofunction:: topi.split
.. autofunction:: topi.take
+.. autofunction:: topi.gather
.. autofunction:: topi.gather_nd
.. autofunction:: topi.full
.. autofunction:: topi.full_like
tvm.relay.zeros_like
tvm.relay.ones
tvm.relay.ones_like
+ tvm.relay.gather
tvm.relay.gather_nd
tvm.relay.full
tvm.relay.full_like
}
};
+struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {
+ Integer axis;
+
+ TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherAttrs") {
+ TVM_ATTR_FIELD(axis)
+ .set_default(NullValue<Integer>())
+ .describe("The axis over which to select values.");
+ }
+};
+
struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
Integer axis;
std::string mode;
_reg.register_injective_schedule("transpose")
_reg.register_injective_schedule("stack")
_reg.register_injective_schedule("_contrib_reverse_reshape")
+_reg.register_injective_schedule("gather")
_reg.register_injective_schedule("gather_nd")
_reg.register_injective_schedule("sequence_mask")
_reg.register_injective_schedule("one_hot")
class ReshapeAttrs(Attrs):
"""Attributes for transform.reshape"""
+@tvm._ffi.register_object("relay.attrs.GatherAttrs")
+class GatherAttrs(Attrs):
+ """Attributes for transform.gather"""
+
@tvm._ffi.register_object("relay.attrs.TakeAttrs")
class TakeAttrs(Attrs):
"""Attributes for transform.take"""
return _make._contrib_reverse_reshape(data, list(newshape))
+def gather(data, axis, indices):
+ """Gather values along given axis from given indices.
+
+ E.g. for a 3D tensor, output is computed as:
+
+ .. code-block:: python
+
+ out[i][j][k] = data[indices[i][j][k]][j][k] # if axis == 0
+ out[i][j][k] = data[i][indices[i][j][k]][k] # if axis == 1
+ out[i][j][k] = data[i][j][indices[i][j][k]] # if axis == 2
+
+ ``indices`` must have same shape as ``data``, except at dimension ``axis``
+ which must just be not null. Output will have same shape as ``indices``.
+
+ Parameters
+ ----------
+ data: relay.Expr
+ The input data to the operator.
+
+ axis: int
+ The axis along which to index.
+
+ indices: relay.Expr
+ The indices of values to gather.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ data = [[1, 2], [3, 4]]
+ axis = 1
+ indices = [[0, 0], [1, 0]]
+ relay.gather(data, axis, indices) = [[1, 1], [4, 3]]
+ """
+ return _make.gather(data, axis, indices)
+
+
def gather_nd(data, indices):
"""Gather elements or slices from data and store to a tensor whose shape is
defined by indices.
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
+// gather operator
+TVM_REGISTER_NODE_TYPE(GatherAttrs);
+
+bool GatherRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+ const TypeReporter& reporter) {
+ // `types` contains: [data, indices, result]
+ CHECK_EQ(types.size(), 3);
+ const auto* data = types[0].as<TensorTypeNode>();
+ const auto* indices = types[1].as<TensorTypeNode>();
+ if (data == nullptr) {
+ CHECK(types[0].as<IncompleteTypeNode>())
+ << "Gather: expect input data type to be TensorType but get " << types[0];
+ return false;
+ }
+ if (indices == nullptr) {
+ CHECK(types[1].as<IncompleteTypeNode>())
+ << "Gather: expect indices type to be TensorType but get " << types[1];
+ return false;
+ }
+ CHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer";
+ const auto param = attrs.as<GatherAttrs>();
+ CHECK(param != nullptr);
+ CHECK(param->axis.defined());
+
+ const auto ndim_data = data->shape.size();
+ const auto ndim_indices = indices->shape.size();
+ int axis = param->axis->value;
+ CHECK_EQ(ndim_data, ndim_indices);
+ CHECK_GE(axis, 0);
+ CHECK_LT(axis, ndim_data);
+
+ std::vector<IndexExpr> oshape;
+ oshape.reserve(ndim_data);
+ for (size_t i = 0; i < ndim_data; ++i) {
+ if (i == (size_t)axis) {
+ const int64_t* indice_shape_i = tir::as_const_int(indices->shape[i]);
+ CHECK_GE(*indice_shape_i, 1);
+ } else {
+ CHECK(reporter->AssertEQ(indices->shape[i], data->shape[i]));
+ }
+ oshape.emplace_back(indices->shape[i]);
+ }
+ reporter->Assign(types[2], TensorType(oshape, data->dtype));
+ return true;
+}
+
+Array<te::Tensor> GatherCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+ const Type& out_type) {
+ const auto* param = attrs.as<GatherAttrs>();
+ return {topi::gather(inputs[0], param->axis, inputs[1])};
+}
+
+Expr MakeGather(Expr data, Integer axis, Expr indices) {
+ auto attrs = make_object<GatherAttrs>();
+ attrs->axis = std::move(axis);
+ static const Op& op = Op::Get("gather");
+ return Call(op, {data, indices}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.gather").set_body_typed(MakeGather);
+
+RELAY_REGISTER_OP("gather")
+ .describe(R"code(Gather values along given axis from given indices.
+
+E.g. for a 3D tensor, output is computed as:
+
+ out[i][j][k] = data[indices[i][j][k]][j][k] # if axis == 0
+ out[i][j][k] = data[i][indices[i][j][k]][k] # if axis == 1
+ out[i][j][k] = data[i][j][indices[i][j][k]] # if axis == 2
+
+``indices`` must have same shape as ``data``, except at dimension ``axis``
+which must just be not null. Output will have same shape as ``indices``.
+)code" TVM_ADD_FILELINE)
+ .set_attrs_type<GatherAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input data to the operator.")
+ .add_argument("indices", "Tensor", "The indices of values to gather.")
+ .set_support_level(3)
+ .add_type_rel("Gather", GatherRel)
+ .set_attr<FTVMCompute>("FTVMCompute", GatherCompute)
+ .set_attr<TOpPattern>("TOpPattern", kInjective);
+
// gather_nd operator
bool GatherNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
verify_scatter((16, 16, 4, 5), (16, 16, 4, 5), 3)
+def test_gather():
+ def verify_gather(data, axis, indices, ref_res):
+ data = np.asarray(data, dtype='float32')
+ indices = np.asarray(indices, dtype='int32')
+ ref_res = np.asarray(ref_res)
+
+ d = relay.var("x", relay.TensorType(data.shape, "float32"))
+ i = relay.var("y", relay.TensorType(indices.shape, "int32"))
+ z = relay.gather(d, axis, i)
+
+ func = relay.Function([d, i], z)
+
+ for target, ctx in ctx_list():
+ for kind in ["graph", "debug"]:
+ intrp = relay.create_executor(kind, ctx=ctx, target=target)
+ op_res = intrp.evaluate(func)(data, indices)
+ tvm.testing.assert_allclose(op_res.asnumpy(), ref_res,
+ rtol=1e-5)
+
+ verify_gather([[1, 2], [3, 4]],
+ 1,
+ [[0, 0], [1, 0]],
+ [[1, 1], [4, 3]])
+ verify_gather([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]],
+ 0,
+ [[[1, 0, 1], [1, 1, 0]]],
+ [[[6, 1, 8], [9, 10, 5]]])
+ verify_gather([[[-0.2321, -0.2024, -1.7624], [-0.3829, -0.4246, 0.2448],
+ [0.1822, 0.2360, -0.8965], [0.4497, -0.2224, 0.6103]],
+ [[0.0408, -0.7667, -0.4303], [-0.3216, 0.7489, -0.1502],
+ [0.0144, -0.4699, -0.0064], [-0.0768, -1.6064, 1.3390]]],
+ 1,
+ [[[2, 2, 0], [1, 0, 3]], [[3, 2, 0], [1, 0, 0]]],
+ [[[0.1822, 0.2360, -1.7624], [-0.3829, -0.2024, 0.6103]],
+ [[-0.0768, -0.4699, -0.4303], [-0.3216, -0.7667, -0.4303]]])
+ verify_gather([[[0.3050, 1.6986, 1.1034], [0.7020, -0.6960, -2.1818],
+ [0.3116, -0.5773, -0.9912], [0.0835, -1.3915, -1.0720]],
+ [[0.1694, -0.6091, -0.6539], [-0.5234, -0.1218, 0.5084],
+ [0.2374, -1.9537, -2.0078], [-0.5700, -1.0302, 0.1558]]],
+ 2,
+ [[[1, 1, 0, 1], [0, 0, 2, 2], [1, 2, 1, 2], [2, 2, 1, 0]],
+ [[0, 0, 1, 2], [2, 2, 1, 0], [1, 2, 0, 0], [0, 2, 0, 2]]],
+ [[[1.6986, 1.6986, 0.3050, 1.6986],
+ [0.7020, 0.7020, -2.1818, -2.1818],
+ [-0.5773, -0.9912, -0.5773, -0.9912],
+ [-1.0720, -1.0720, -1.3915, 0.0835]],
+ [[0.1694, 0.1694, -0.6091, -0.6539],
+ [0.5084, 0.5084, -0.1218, -0.5234],
+ [-1.9537, -2.0078, 0.2374, 0.2374],
+ [-0.5700, 0.1558, -0.5700, 0.1558]]])
+
+
def test_gather_nd():
def verify_gather_nd(xshape, yshape, y_data):
x = relay.var("x", relay.TensorType(xshape, "float32"))
}
/*!
+ * \brief Gather values along given axis from given indices.
+ *
+ * \param data The input data to the operator.
+ * \param axis The axis along which to index.
+ * \param indices The indices of values to gather.
+ * \param name The name of the operation.
+ * \param tag The tag to mark the operation.
+ *
+ * \return A Tensor whose op member is the gather operation
+ */
+inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
+ std::string name = "T_gather", std::string tag = kInjective) {
+ size_t ndim_d = data->shape.size();
+ size_t ndim_i = indices->shape.size();
+ CHECK_GE(ndim_d, 1) << "Cannot gather from a scalar.";
+ CHECK_EQ(ndim_d, ndim_i);
+ CHECK_GE(axis, 0);
+ CHECK_LT(axis, ndim_d);
+ size_t indices_dim_i = static_cast<size_t>(GetConstInt(indices->shape[axis]));
+ CHECK_GE(indices_dim_i, 1);
+ CHECK(indices->dtype.is_int());
+
+ Array<PrimExpr> out_shape;
+ for (size_t i = 0; i < ndim_i; ++i) {
+ out_shape.push_back(indices->shape[i]);
+ }
+
+ return compute(
+ out_shape,
+ [&](const Array<Var>& out_index) {
+ Array<PrimExpr> indices_position;
+ for (size_t i = 0; i < ndim_i; ++i) {
+ indices_position.push_back(out_index[i]);
+ }
+ Array<PrimExpr> real_indices;
+ for (size_t i = 0; i < ndim_i; ++i) {
+ if (i == (size_t)axis) {
+ real_indices.push_back(indices(indices_position));
+ } else {
+ real_indices.push_back(indices_position[i]);
+ }
+ }
+ return data(real_indices);
+ },
+ name, tag);
+}
+
+/*!
* \brief Gather elements from a n-dimension array.
*
* \param data The source array.
from .roi_pool_python import roi_pool_nchw_python
from .lrn_python import lrn_python
from .l2_normalize_python import l2_normalize_python
+from .gather_python import gather_python
from .gather_nd_python import gather_nd_python
from .strided_slice_python import strided_slice_python, strided_set_python
from .batch_matmul import batch_matmul
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
+"""gather in python"""
+import numpy as np
+
+def gather_python(data, axis, indices):
+ """ Python version of Gather operator
+
+ Parameters
+ ----------
+ data : numpy.ndarray
+ Numpy array
+
+ axis: int
+ integer
+
+ indices : numpy.ndarray
+ Numpy array
+
+ Returns
+ -------
+ b_np : numpy.ndarray
+ Numpy array
+ """
+ shape_indices = indices.shape
+ out = np.zeros(shape_indices, dtype=data.dtype)
+ for index in np.ndindex(*shape_indices):
+ new_index = list(index)
+ new_index[axis] = indices[index]
+ out[index] = data[tuple(new_index)]
+ return out
return cpp.take(a, indices, int(axis), mode)
+def gather(data, axis, indices):
+ """Gather values along given axis from given indices.
+
+ E.g. for a 3D tensor, output is computed as:
+
+ .. code-block:: python
+
+ out[i][j][k] = data[indices[i][j][k]][j][k] # if axis == 0
+ out[i][j][k] = data[i][indices[i][j][k]][k] # if axis == 1
+ out[i][j][k] = data[i][j][indices[i][j][k]] # if axis == 2
+
+ ``indices`` must have same shape as ``data``, except at dimension ``axis``
+ which must just be not null. Output will have same shape as ``indices``.
+
+ Parameters
+ ----------
+ data : tvm.te.Tensor
+ The input data to the operator.
+
+ axis: int
+ The axis along which to index.
+
+ indices : tvm.te.Tensor
+ The indices of the values to extract.
+
+ Returns
+ -------
+ ret : tvm.te.Tensor
+ """
+ return cpp.gather(data, axis, indices)
+
+
def gather_nd(a, indices):
"""Gather elements from a n-dimension array..
*rv = tile(args[0], args[1]);
});
+TVM_REGISTER_GLOBAL("topi.gather").set_body([](TVMArgs args, TVMRetValue* rv) {
+ *rv = gather(args[0], args[1], args[2]);
+});
+
TVM_REGISTER_GLOBAL("topi.gather_nd").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = gather_nd(args[0], args[1]);
});
for device in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
check_device(device)
+def verify_gather(data, axis, indices):
+ data = np.asarray(data)
+ indices = np.asarray(indices)
+
+ var_data = te.placeholder(shape=data.shape, dtype=data.dtype.name, name="data")
+ var_indices = te.placeholder(shape=indices.shape, dtype=indices.dtype.name, name="indices")
+ out_tensor = topi.gather(var_data, axis, var_indices)
+
+ def check_device(device):
+ ctx = tvm.context(device, 0)
+ if not ctx.exist:
+ print("Skip because %s is not enabled" % device)
+ return
+ print("Running on target: %s" % device)
+ with tvm.target.create(device):
+ s = topi.testing.get_injective_schedule(device)(out_tensor)
+
+ func = tvm.build(s, [var_data, var_indices, out_tensor] , device, name="gather")
+ out_npys = topi.testing.gather_python(data, axis, indices)
+
+ data_nd = tvm.nd.array(data, ctx)
+ indices_nd = tvm.nd.array(indices, ctx)
+ out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=data.dtype.name)
+ func(data_nd, indices_nd, out_nd)
+ tvm.testing.assert_allclose(out_nd.asnumpy(), out_npys)
+
+ for device in get_all_backend():
+ check_device(device)
+
def verify_gather_nd(src_shape, indices_src, indices_dtype):
src_dtype = "float32"
indices_src = np.array(indices_src, dtype=indices_dtype)
verify_take((3,4), [0, 2], axis=0, mode="fast")
verify_take((3,4), [0, 2], axis=1, mode="fast")
+def test_gather():
+ verify_gather([[1, 2], [3, 4]], 1, [[0, 0], [1, 0]])
+ verify_gather(np.random.randn(4, 7, 5), 0, np.random.randint(low=0, high=4, size=(1, 7, 5)))
+ verify_gather(np.random.randn(4, 7, 5), 0, np.random.randint(low=0, high=4, size=(4, 7, 5)))
+ verify_gather(np.random.randn(4, 7, 5), 1, np.random.randint(low=0, high=7, size=(4, 10, 5)))
+ verify_gather(np.random.randn(4, 7, 5), 1, np.random.randint(low=0, high=7, size=(4, 10, 5)))
+ verify_gather(np.random.randn(4, 7, 5), 2, np.random.randint(low=0, high=5, size=(4, 7, 2)))
+ verify_gather(np.random.randn(4, 7, 5), 2, np.random.randint(low=0, high=5, size=(4, 7, 10)))
+
def test_gather_nd():
for indices_dtype in ['int32', 'float32']:
verify_gather_nd((4,), [[1.8]], indices_dtype)