topi.logical_not
topi.arange
topi.stack
+ topi.repeat
+ topi.tile
topi.layout_transform
topi.image.resize
.. autofunction:: topi.less
.. autofunction:: topi.arange
.. autofunction:: topi.stack
+.. autofunction:: topi.repeat
+.. autofunction:: topi.tile
.. autofunction:: topi.layout_transform
topi.nn
tvm.relay.split
tvm.relay.arange
tvm.relay.stack
+ tvm.relay.repeat
+ tvm.relay.tile
**Level 4: Broadcast and Reductions**
.. autofunction:: tvm.relay.split
.. autofunction:: tvm.relay.arange
.. autofunction:: tvm.relay.stack
+.. autofunction:: tvm.relay.repeat
+.. autofunction:: tvm.relay.tile
Level 4 Definitions
}
}; // struct StackAttrs
+/*! \brief Attributes used in repeat operators */
+struct RepeatAttrs : public tvm::AttrsNode<RepeatAttrs> {
+ Integer repeats;
+ Integer axis;
+ TVM_DECLARE_ATTRS(RepeatAttrs, "relay.attrs.RepeatAttrs") {
+ TVM_ATTR_FIELD(repeats)
+ .describe("The number of repetitions for each element.");
+ TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
+ .describe(" The axis along which to repeat values.");
+ }
+}; // struct RepeatAttrs
+
+/*! \brief Attributes used in tile operators */
+struct TileAttrs : public tvm::AttrsNode<TileAttrs> {
+ Array<Integer> reps;
+ TVM_DECLARE_ATTRS(TileAttrs, "relay.attrs.TileAttrs") {
+ TVM_ATTR_FIELD(reps)
+ .describe("The number of times for repeating the tensor a."
+ "Each dim sizeof reps must be a positive integer.");
+ }
+}; // struct TileAttrs
+
/*! \brief Attributes used in squeeze operators */
struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
// use axis to make the name numpy compatible.
return _op.nn.dropout(inputs[0], rate=rate)
+def _mx_BlockGrad(inputs, attrs): #pylint: disable=unused-argument
+ return inputs
+
+
def _mx_batch_norm(inputs, attrs):
if attrs.get_bool("output_mean_var", False):
raise RuntimeError("batch_norm do not support output_mean_var")
return _op.arange(**new_attrs)
+def _mx_repeat(inputs, attrs):
+ assert len(inputs) == 1
+ new_attrs = {}
+ new_attrs["repeats"] = attrs.get_int("repeats")
+ new_attrs["axis"] = attrs.get_int("axis", 0)
+ return _op.repeat(inputs[0], **new_attrs)
+
+
+def _mx_tile(inputs, attrs):
+ assert len(inputs) == 1
+ new_attrs = {}
+ new_attrs["reps"] = attrs.get_int_tuple("reps")
+ return _op.tile(inputs[0], **new_attrs)
+
+
def _mx_roi_align(inputs, attrs):
new_attrs = {}
new_attrs["pooled_size"] = attrs.get_int_tuple("pooled_size")
"batch_dot" : _mx_batch_dot,
"LeakyReLU" : _mx_leaky_relu,
"_arange" : _mx_arange,
+ "repeat" : _mx_repeat,
+ "tile" : _mx_tile,
+ "BlockGrad" : _mx_BlockGrad,
"SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation,
# vision
_reg.register_schedule("full", schedule_injective)
_reg.register_schedule("full_like", schedule_injective)
_reg.register_schedule("arange", schedule_injective)
+_reg.register_schedule("repeat", schedule_broadcast)
+_reg.register_schedule("tile", schedule_broadcast)
_reg.register_schedule("cast", schedule_injective)
_reg.register_schedule("strided_slice", schedule_injective)
_reg.register_schedule("slice_like", schedule_injective)
return _make.stack(data, axis)
+def repeat(data, repeats, axis):
+ """Repeats elements of an array.
+ By default, repeat flattens the input array into 1-D and then repeats the elements.
+
+ repeats : int
+ The number of repetitions for each element.
+
+ axis: int
+ The axis along which to repeat values. The negative numbers are interpreted
+ counting from the backward. By default, use the flattened input array, and
+ return a flat output array.
+
+ Returns
+ -------
+ ret : relay.Expr
+ The computed result.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ x = [[1, 2], [3, 4]]
+ relay.repeat(x, repeats=2) = [1., 1., 2., 2., 3., 3., 4., 4.]
+
+ relay.repeat(x, repeats=2, axis=1) = [[1., 1., 2., 2.],
+ [3., 3., 4., 4.]]
+ """
+ return _make.repeat(data, repeats, axis)
+
+
+def tile(data, reps):
+ """Repeats the whole array multiple times.
+
+ Parameters
+ ----------
+ data : relay.Expr
+ The input data to the operator.
+
+ reps : tuple of int
+ The number of times repeating the tensor data.
+
+ .. note::
+ Each dim size of reps must be a positive integer. If reps has length d,
+ the result will have dimension of max(d, data.ndim); If data.ndim < d,
+ data is promoted to be d-dimensional by prepending new axes.
+ If data.ndim >= d, reps is promoted to a.ndim by pre-pending 1's to it.
+
+ Returns
+ -------
+ ret : relay.Expr
+ The computed result.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ x = [[1, 2], [3, 4]]
+ relay.tile(x, reps=(2,3)) = [[1., 2., 1., 2., 1., 2.],
+ [3., 4., 3., 4., 3., 4.],
+ [1., 2., 1., 2., 1., 2.],
+ [3., 4., 3., 4., 3., 4.]]
+
+ relay.tile(x, reps=(2,)) = [[1., 2., 1., 2.],
+ [3., 4., 3., 4.]]
+ """
+
+ return _make.tile(data, reps)
+
+
def where(condition, x, y):
"""Selecting elements from either x or y depending on the value of the
condition.
.set_attr<FTVMCompute>("FTVMCompute", ArangeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
+// repeat operator
+TVM_REGISTER_NODE_TYPE(RepeatAttrs);
+
+bool RepeatRel(const Array<Type>& types,
+ int num_inputs,
+ const Attrs& attrs,
+ const TypeReporter& reporter) {
+ // `types` contains: [data, result]
+ CHECK_EQ(types.size(), 2);
+ const auto* data = types[0].as<TensorTypeNode>();
+ if (data == nullptr) {
+ CHECK(types[0].as<IncompleteTypeNode>())
+ << "repeat: expect input type to be TensorType but get "
+ << types[0];
+ return false;
+ }
+ const auto* param = attrs.as<RepeatAttrs>();
+ const int ndim = static_cast<int>(data->shape.size());
+ const int repeats = param->repeats;
+ const int axis = param->axis;
+ CHECK(repeats >= 1)
+ << "repeat only accepts `repeats >= 1`"
+ << ", but got repeats = " << repeats;
+ CHECK(-ndim - 1 <= axis && axis <= ndim)
+ << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]"
+ << ", but got axis = " << axis
+ << ", and data.ndim = " << ndim;
+ const int pivot = axis < 0 ? ndim + axis : axis;
+ std::vector<IndexExpr> oshape;
+ oshape.reserve(ndim + repeats);
+ for (int i = 0; i < pivot; ++i) {
+ oshape.emplace_back(data->shape[i]);
+ }
+ oshape.emplace_back(data->shape[pivot] * repeats);
+ for (int i = pivot + 1; i < ndim; ++i) {
+ oshape.emplace_back(data->shape[i]);
+ }
+ reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+ return true;
+}
+
+Array<Tensor> RepeatCompute(const Attrs& attrs,
+ const Array<Tensor>& inputs,
+ const Type& out_type,
+ const Target& target) {
+ const RepeatAttrs *param = attrs.as<RepeatAttrs>();
+ CHECK(param != nullptr);
+ return { topi::repeat(inputs[0], param->repeats, param->axis) };
+}
+
+Expr MakeRepeat(Expr data,
+ int repeats,
+ int axis) {
+ auto attrs = make_node<RepeatAttrs>();
+ attrs->repeats = repeats;
+ attrs->axis = axis;
+ static const Op& op = Op::Get("repeat");
+ return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op._make.repeat")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+ runtime::detail::unpack_call<Expr, 3>(MakeRepeat, args, rv);
+});
+
+RELAY_REGISTER_OP("repeat")
+.describe(R"code(Repeat elements of an array `repeats` times along axis `axis`
+
+- **data**: The input data to the operator.
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_attrs_type_key("relay.attrs.Repeat")
+.add_argument("data", "Tensor", "The input tensor.")
+.set_support_level(1)
+.add_type_rel("Repeat", RepeatRel)
+.set_attr<FTVMCompute>("FTVMCompute", RepeatCompute)
+.set_attr<TOpPattern>("TOpPattern", kBroadcast);
+
+// tile operator
+TVM_REGISTER_NODE_TYPE(TileAttrs);
+
+bool TileRel(const Array<Type>& types,
+ int num_inputs,
+ const Attrs& attrs,
+ const TypeReporter& reporter) {
+ // `types` contains: [data, result]
+ CHECK_EQ(types.size(), 2);
+ const auto* data = types[0].as<TensorTypeNode>();
+ if (data == nullptr) {
+ CHECK(types[0].as<IncompleteTypeNode>())
+ << "tile: expect input type to be TensorType but get "
+ << types[0];
+ return false;
+ }
+ const auto* param = attrs.as<TileAttrs>();
+ const size_t ndim = data->shape.size();
+ const Array<Integer>& reps = param->reps;
+ // check dimension match
+ CHECK(!reps.defined())
+ << "repetition array is not defined. data.ndim = " << ndim;
+ const size_t rndim = reps.size();
+ size_t tndim = (ndim > rndim) ? ndim : rndim;
+ // re-construct data shape or reps shape
+ std::vector<IndexExpr> data_shape;
+ std::vector<IndexExpr> reps_shape;
+ data_shape.reserve(tndim);
+ reps_shape.reserve(tndim);
+ if (ndim == rndim) {
+ for (size_t i = 0; i < tndim; ++i) {
+ data_shape.emplace_back(data->shape[i]);
+ reps_shape.emplace_back(reps[i]);
+ }
+ } else if (ndim > rndim) {
+ for (size_t i = 0; i < ndim; ++i)
+ data_shape.emplace_back(data->shape[i]);
+ for (size_t i = 0; i < (ndim - rndim); ++i)
+ reps_shape.emplace_back(1);
+ for (size_t i = 0; i < rndim; ++i)
+ reps_shape.emplace_back(reps[i]);
+ } else {
+ for (size_t i = 0; i < rndim; ++i)
+ reps_shape.emplace_back(reps[i]);
+ }
+ std::vector<IndexExpr> oshape;
+ oshape.reserve(tndim);
+ for (size_t i = 0; i < tndim; ++i) {
+ oshape.emplace_back(data_shape[i] * reps_shape[i]);
+ }
+ reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+ return true;
+}
+
+Array<Tensor> TileCompute(const Attrs& attrs,
+ const Array<Tensor>& inputs,
+ const Type& out_type,
+ const Target& target) {
+ const TileAttrs *param = attrs.as<TileAttrs>();
+ CHECK(param != nullptr);
+ return { topi::tile(inputs[0], param->reps) };
+}
+
+Expr MakeTile(Expr data,
+ Array<Integer> reps) {
+ auto attrs = make_node<TileAttrs>();
+ attrs->reps = reps;
+ static const Op& op = Op::Get("tile");
+ return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op._make.tile")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+ runtime::detail::unpack_call<Expr, 2>(MakeTile, args, rv);
+});
+
+RELAY_REGISTER_OP("tile")
+.describe(R"code(Repeat the whole array multiple times.
+
+- **data**: The input data to the operator.
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_attrs_type_key("relay.attrs.Tile")
+.add_argument("data", "Tensor", "The input tensor.")
+.set_support_level(1)
+.add_type_rel("Tile", TileRel)
+.set_attr<FTVMCompute>("FTVMCompute", TileCompute)
+.set_attr<TOpPattern>("TOpPattern", kBroadcast);
+
// where operator
bool WhereRel(const Array<Type>& types,
int num_inputs,
return out;
}
+/*!
+* \brief Creates an operation to repeat elements of an array
+*
+* \param x The input tensor
+* \param repeats The number of repetitions for each element
+* \param axis The axis along which to repeat values (allows
+* negative indices as offsets from the last dimension)
+* \param name The name of the operation
+* \param tag The tag to mark the operation
+*
+* \return A Tensor whose op member is the repeat operation
+*/
+inline Tensor repeat(const Tensor& x,
+ int repeats,
+ int axis,
+ std::string name = "tensor",
+ std::string tag = kBroadcast) {
+ int ndim = static_cast<int>(x->shape.size());
+ CHECK(-ndim - 1 <= axis && axis <= ndim)
+ << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]"
+ << ", but got axis = " << axis
+ << ", and data.ndim = " << ndim;
+ CHECK(repeats >= 1)
+ << "repeat only accepts `repeats >= 1`"
+ << ", but got repeats = " << repeats;
+ if (axis < 0) {
+ // Calculate offset from last dimension
+ axis += ndim;
+ }
+ Array<Expr> new_shape;
+ for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
+ new_shape.push_back(x->shape[i]);
+ }
+ new_shape.push_back(repeats * x->shape[axis]);
+ for (size_t i = axis + 1; i < x->shape.size(); ++i) {
+ new_shape.push_back(x->shape[i]);
+ }
+
+ return compute(
+ new_shape, [&](const Array<Var>& indices) {
+ Array<Expr> idx;
+ for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
+ idx.push_back(indices[i]);
+ }
+ idx.push_back(indices[axis] / repeats);
+ for (size_t i = axis + 1; i < indices.size(); ++i) {
+ idx.push_back(indices[i]);
+ }
+ return x(idx);
+ }, name, tag);
+}
+
+/*!
+* \brief Creates an operation to tile elements of an array
+*
+* \param x The input tensor
+* \param reps The number of times for repeating the tensor
+* \param name The name of the operation
+* \param tag The tag to mark the operation
+*
+* \return A Tensor whose op member is the tile operation
+*/
+inline Tensor tile(const Tensor& x,
+ Array<Integer> reps,
+ std::string name = "tensor",
+ std::string tag = kBroadcast) {
+ size_t ndim = x->shape.size();
+ size_t rdim = reps.size();
+ size_t tdim = (ndim > rdim) ? ndim : rdim;
+ Array<Expr> data_shape;
+ Array<Expr> reps_shape;
+ Array<Expr> new_shape;
+ if (ndim == rdim) {
+ for (size_t i = 0; i < ndim; ++i) {
+ data_shape.push_back(x->shape[i]);
+ reps_shape.push_back(reps[i]);
+ }
+ } else if (ndim > rdim) {
+ for (size_t i = 0; i < ndim; ++i)
+ data_shape.push_back(x->shape[i]);
+ for (size_t i = 0; i < (ndim - rdim); ++i)
+ reps_shape.push_back(1);
+ for (size_t i = 0; i < rdim; ++i)
+ reps_shape.push_back(reps[i]);
+ } else {
+ for (size_t i = 0; i < (rdim - ndim); ++i)
+ data_shape.push_back(1);
+ for (size_t i = 0; i < ndim; ++i)
+ data_shape.push_back(x->shape[i]);
+ for (size_t i = 0; i < rdim; ++i)
+ reps_shape.push_back(reps[i]);
+ }
+ for (size_t i = 0; i < tdim; ++i)
+ new_shape.push_back(data_shape[i] * reps_shape[i]);
+
+ return compute(
+ new_shape, [&](const Array<Var>& indices) {
+ Array<Expr> idx;
+ if (ndim >= rdim) {
+ for (size_t i = 0; i < ndim; ++i)
+ idx.push_back(indices[i] % x->shape[i]);
+ } else {
+ for (size_t i = 0; i < ndim; ++i)
+ idx.push_back(indices[rdim - ndim + i] % x->shape[i]);
+ }
+ return x(idx);
+ }, name, tag);
+}
+
/*!
* \brief Gather elements from a n-dimension array.
*
return cpp.arange(start, stop, step, dtype)
+def repeat(a, repeats, axis):
+ """Repeats elements of an array.
+
+ Parameters
+ ----------
+ a : tvm.Tensor
+ The tensor to be repeated.
+
+ repeats: int, required
+ Number of repetitions for each element
+
+ axis: int, optional
+ The axis along which to repeat values
+
+ Returns
+ -------
+ ret : tvm.Tensor
+ """
+ return cpp.repeat(a, repeats, axis)
+
+
+def tile(a, reps):
+ """Repeats the whole array multiple times.
+
+ Parameters
+ ----------
+ a : tvm.Tensor
+ The tensor to be tiled.
+
+ reps: tuple of ints, required
+ The number of times for repeating the tensor
+
+ Returns
+ -------
+ ret : tvm.Tensor
+ """
+ return cpp.tile(a, reps)
+
+
def layout_transform(array, src_layout, dst_layout):
"""Transform the layout according to src_layout and dst_layout
*rv = arange(args[0], args[1], args[2], args[3]);
});
+TVM_REGISTER_GLOBAL("topi.repeat")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+ *rv = repeat(args[0], args[1], args[2]);
+});
+
+TVM_REGISTER_GLOBAL("topi.tile")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+ *rv = tile(args[0], args[1]);
+});
+
TVM_REGISTER_GLOBAL("topi.gather_nd")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = gather_nd(args[0], args[1]);
for device in get_all_backend():
check_device(device)
+def verify_repeat(in_shape, repeats, axis):
+ A = tvm.placeholder(shape=in_shape, name="A")
+ B = topi.repeat(A, repeats, axis)
+ def check_device(device):
+ ctx = tvm.context(device, 0)
+ if not ctx.exist:
+ print("Skip because %s is not enabled" % device)
+ return
+ print("Running on target: %s" % device)
+ with tvm.target.create(device):
+ s = topi.generic.schedule_broadcast(B)
+ foo = tvm.build(s, [A, B], device, name="repeat")
+ data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
+ out_npy = np.repeat(data_npy, repeats, axis)
+ data_nd = tvm.nd.array(data_npy, ctx)
+ out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(B.dtype), ctx)
+ foo(data_nd, out_nd)
+ tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)
+
+ for device in get_all_backend():
+ check_device(device)
+
+def verify_tile(in_shape, reps):
+ A = tvm.placeholder(shape=in_shape, name="A")
+ B = topi.tile(A, reps)
+ def check_device(device):
+ ctx = tvm.context(device, 0)
+ if not ctx.exist:
+ print("Skip because %s is not enabled" % device)
+ return
+ print("Running on target: %s" % device)
+ with tvm.target.create(device):
+ s = topi.generic.schedule_broadcast(B)
+ foo = tvm.build(s, [A, B], device, name="tile")
+ data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
+ out_npy = np.tile(data_npy, reps)
+ data_nd = tvm.nd.array(data_npy, ctx)
+ out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(B.dtype), ctx)
+ foo(data_nd, out_nd)
+ tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)
+
+ for device in get_all_backend():
+ check_device(device)
+
def test_strided_slice():
verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2])
verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1])
verify_arange(20, 1, -1)
verify_arange(20, 1, -1.5)
+def test_repeat():
+ verify_repeat((2,), 1, 0)
+ verify_repeat((3, 2), 2, 0)
+ verify_repeat((3, 2, 4), 3, 1)
+ verify_repeat((1, 3, 2, 4), 4, -1)
+
+def test_tile():
+ verify_tile((3, 2), (2, 3))
+ verify_tile((3, 2, 5), (2,))
+ verify_tile((3, ), (2, 3, 3))
def test_layout_transform():
in_shape = (1, 32, 8, 8)
test_gather_nd()
test_arange()
test_layout_transform()
+ test_repeat()
+ test_tile()