"""Backend compiler related feature registration"""
# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks, too-many-local-variables, too-many-arguments
from __future__ import absolute_import
+
+from tvm.runtime import convert
from tvm.te.hybrid import script
from .. import op as _reg
_reg.register_injective_schedule("dyn.reshape")
+_reg.register_broadcast_schedule("dyn.tile")
@script
def _reshape_shape_func_input_data(data, newshape, ndim):
@_reg.register_shape_func("dyn.reshape", True)
def dynamic_reshape_shape_func(attrs, inputs, out_ndims):
return [_reshape_shape_func_input_data(*inputs, out_ndims[0])]
+
+
+@script
+def _tile_shape_func(data, reps, ndim, tndim, rndim):
+ out = output_tensor((tndim,), "int64")
+
+ if ndim == rndim:
+ for i in const_range(tndim):
+ out[i] = int64(data.shape[i] * reps[i])
+ elif ndim > rndim:
+ ngap = ndim - rndim
+ for i in const_range(ndim):
+ if i < ngap:
+ out[i] = int64(data.shape[i])
+ else:
+ out[i] = int64(data.shape[i] * reps[i - ngap])
+ else:
+ rgap = rndim - ndim
+ for i in const_range(rndim):
+ if i < rgap:
+ out[i] = int64(reps[i])
+ else:
+ out[i] = int64(reps[i] * data.shape[i - rgap])
+ return out
+
+
+@_reg.register_shape_func("dyn.tile", True)
+def tile_shape_func(attrs, inputs, _):
+ """
+ Shape function for dyn.tile op.
+ """
+ reps = inputs[1]
+ ndim = len(inputs[0].shape)
+ rndim = inputs[1].shape[0].value
+ tndim = ndim if ndim > rndim else rndim
+ return [_tile_shape_func(inputs[0], reps, convert(ndim),
+ convert(tndim), convert(rndim))]
data : relay.Expr
The input data to the operator.
- reps : tuple of int
+ reps : tuple of int or relay.Expr
The number of times repeating the tensor data.
Returns
data is promoted to be d-dimensional by prepending new axes.
If data.ndim >= d, reps is promoted to a.ndim by pre-pending 1's to it.
"""
-
+ if isinstance(reps, Expr):
+ return _dyn_make.tile(data, reps)
return _make.tile(data, reps)
#include <tvm/relay/op_attr_types.h>
#include <tvm/runtime/registry.h>
+#include <vector>
+
namespace tvm {
namespace relay {
namespace dyn {
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
+// tile operator
+// TVM_REGISTER_NODE_TYPE(TileAttrs);
+
+bool TileRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+ const TypeReporter& reporter) {
+ // `types` contains: [data, reps, result]
+ CHECK_EQ(types.size(), 3);
+ const auto* data = types[0].as<TensorTypeNode>();
+ const auto* reps = types[1].as<TensorTypeNode>();
+ if (data == nullptr) {
+ CHECK(types[0].as<IncompleteTypeNode>())
+ << "tile: expect input type to be TensorType but get " << types[0];
+ return false;
+ }
+ if (reps == nullptr) {
+ CHECK(types[1].as<IncompleteTypeNode>())
+ << "tile: expect input type to be TensorType but get " << types[1];
+ return false;
+ }
+ const IntImmNode* reps_shape = reps->shape[0].as<IntImmNode>();
+ CHECK(reps_shape) << "Parameter reps must have static shape";
+ const size_t ndim = data->shape.size();
+ const size_t rndim = reps_shape->value;
+ size_t tndim = (ndim > rndim) ? ndim : rndim;
+ std::vector<IndexExpr> oshape;
+ oshape.reserve(tndim);
+ for (size_t i = 0; i < tndim; ++i) {
+ oshape.emplace_back(Any());
+ }
+ reporter->Assign(types[2], TensorType(oshape, data->dtype));
+ return true;
+}
+
+Array<te::Tensor> TileCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+ const Type& out_type) {
+ CHECK_EQ(inputs.size(), 2);
+ const auto* out_ttype = out_type.as<TensorTypeNode>();
+ size_t rndim = inputs[1]->shape[0].as<IntImmNode>()->value;
+ return {topi::dyn_tile(inputs[0], out_ttype->shape, rndim)};
+}
+
+Expr MakeTile(Expr data, Expr reps) {
+ auto attrs = make_object<TileAttrs>();
+ static const Op& op = Op::Get("dyn.tile");
+ return Call(op, {data, reps}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.dyn._make.tile").set_body_typed(MakeTile);
+
+RELAY_REGISTER_OP("dyn.tile")
+ .describe(R"code(Repeat the whole array multiple times.
+
+- **data**: The input data to the operator.
+- **reps**: The number of times to repeat the operator.
+
+)code" TVM_ADD_FILELINE)
+ .set_num_inputs(2)
+ .set_attrs_type<TileAttrs>()
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("reps", "Tensor", "The number of times to repeat the input on each axis.")
+ .set_support_level(3)
+ .add_type_rel("DynamicTile", TileRel)
+ .set_attr<FTVMCompute>("FTVMCompute", TileCompute)
+ .set_attr<TOpPattern>("TOpPattern", kInjective);
+
} // namespace dyn
} // namespace relay
} // namespace tvm
class DynamicToStaticMutator : public MixedModeMutator {
public:
- DynamicToStaticMutator() : dyn_reshape_op_(Op::Get("dyn.reshape")) {}
+ DynamicToStaticMutator()
+ : dyn_reshape_op_(Op::Get("dyn.reshape")), dyn_tile_op_(Op::Get("dyn.tile")) {}
private:
Expr Rewrite_(const CallNode* pre, const Expr& post) override {
static const Op& reshape = Op::Get("reshape");
return Call(reshape, {call_node->args[0]}, Attrs(attrs), {});
}
+ } else if (call_node->op == dyn_tile_op_) {
+ if (const ConstantNode* reps = call_node->args[1].as<ConstantNode>()) {
+ auto attrs = make_object<TileAttrs>();
+ CHECK_EQ(reps->data->ndim, 1);
+ attrs->reps = ToVector(reps->data);
+ static const Op& op = Op::Get("tile");
+ return Call(op, {call_node->args[0]}, Attrs(attrs), {});
+ }
}
return post;
}
}
const Op& dyn_reshape_op_;
+ const Op& dyn_tile_op_;
};
Expr DynamicToStatic(Function f, IRModule m) {
verify_reshape((2, 3, 4), (8, 3), (8, 3))
verify_reshape((4, 7), (2, 7, 2), (2, 7, 2))
+def test_dyn_tile():
+ def verify_tile(dshape, reps):
+ x = relay.var("x", relay.TensorType(dshape, "float32"))
+ r = relay.var("reps", relay.TensorType((len(reps), ), "float32"))
+ z = relay.tile(x, r)
+
+ func = relay.Function([x, r], z)
+ x_data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32")
+ ref_res = np.tile(x_data, reps=reps)
+ verify_func(func, [x_data, np.array(reps).astype("float32")], ref_res)
+ verify_tile((2, 3, 4), (3, 2, 1))
+ verify_tile((2, 3, 4), (1, 2))
+ verify_tile((2, 3), (3, 2, 1))
+
if __name__ == "__main__":
test_dyn_reshape()
test_dyn_shape_reshape()
+ test_dyn_tile()
verify_reshape((2, 3, 4), (8, 3))
verify_reshape((4, 7), (2, 7, 2))
+def test_dynamic_to_static_tile():
+ def verify_tile(shape, reps, oshape):
+ x = relay.var("x", relay.TensorType(shape, "float32"))
+ y = relay.var("y", relay.TensorType(reps, "float32"))
+ z = relay.tile(x, relay.shape_of(y))
+ func = run_infer_type(relay.Function([x, y], z))
+ func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+
+ zz = func2.body
+ assert isinstance(zz, relay.Call)
+ assert zz.op == relay.op.get("tile")
+ assert zz.checked_type == relay.ty.TensorType(oshape, "float32")
+
+ x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
+ y_data = np.random.uniform(low=-1, high=1, size=reps).astype("float32")
+ ref_res = np.tile(x_data, reps)
+ verify_func(func2, [x_data, y_data], ref_res)
+
+ verify_tile((2, 3, 4), (2, 1, 5), (4, 3, 20))
+ verify_tile((4, 7), (4, 2), (16, 14))
+
if __name__=="__main__":
test_dynamic_to_static_reshape()
test_dynamic_to_static_double_reshape()
test_dynamic_to_static_quad_reshape()
+ test_dynamic_to_static_tile()
}
/*!
+ * \brief Creates an operation to tile elements of an array
+ *
+ * \param x The input tensor
+ * \param new_shape The shape of the output after tiling
+ * \param rdim The rank of the reps, provided by caller
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the tile operation
+ */
+inline Tensor dyn_tile(const Tensor& x, Array<PrimExpr> new_shape, size_t rdim,
+ std::string name = "T_tile", std::string tag = kBroadcast) {
+ size_t ndim = x->shape.size();
+ if (is_empty_shape(new_shape)) {
+ return compute(
+ new_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag);
+ } else {
+ return compute(
+ new_shape,
+ [&](const Array<Var>& indices) {
+ Array<PrimExpr> idx;
+ if (ndim >= rdim) {
+ for (size_t i = 0; i < ndim; ++i) {
+ idx.push_back(indexmod(indices[i], x->shape[i]));
+ }
+ } else {
+ for (size_t i = 0; i < ndim; ++i) {
+ idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
+ }
+ }
+ return x(idx);
+ },
+ name, tag);
+ }
+}
+
+/*!
* \brief Gather values along given axis from given indices.
*
* \param data The input data to the operator.