_reg.register_schedule("repeat", schedule_broadcast)
_reg.register_schedule("tile", schedule_broadcast)
_reg.register_schedule("cast", schedule_injective)
+_reg.register_schedule("cast_like", schedule_injective)
_reg.register_schedule("reinterpret", schedule_injective)
_reg.register_schedule("strided_slice", schedule_injective)
_reg.register_schedule("slice_like", schedule_injective)
return _relay_make.cast(data, dtype)
+def cast_like(data, dtype_like):
+ """Cast input tensor to data type of another tensor.
+ Parameters
+ ----------
+ data : relay.Expr
+ The input data to the operator.
+ dtype_like: relay.Expr
+ The tensor to cast to.
+ Returns
+ -------
+ result : relay.Expr
+ The casted result.
+ """
+ from .. import _make as _relay_make
+ return _relay_make.cast_like(data, dtype_like)
+
+
def reinterpret(data, dtype):
"""Reinterpret input tensor to data type.
.set_attr<TOpPattern>("TOpPattern", kElemWise)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
+
+// relay.cast_like
+bool CastLikeRel(const Array<Type>& types,
+ int num_inputs,
+ const Attrs& attrs,
+ const TypeReporter& reporter) {
+ CHECK_EQ(types.size(), 3);
+ const auto* data = types[0].as<TensorTypeNode>();
+ if (data == nullptr) {
+ CHECK(types[0].as<IncompleteTypeNode>())
+ << "cast: expect input type to be TensorType but get "
+ << types[0];
+ return false;
+ }
+ const auto* dtype_like = types[1].as<TensorTypeNode>();
+ if (dtype_like == nullptr) {
+ CHECK(types[1].as<IncompleteTypeNode>())
+ << "cast: expect input type to be TensorType but get "
+ << types[1];
+ return false;
+ }
+ reporter->Assign(types[2], TensorTypeNode::make(data->shape, dtype_like->dtype));
+ return true;
+}
+
+
+Array<Tensor> CastLikeCompute(const Attrs& attrs,
+ const Array<Tensor>& inputs,
+ const Type& out_type,
+ const Target& target) {
+ return { topi::cast(inputs[0], inputs[1]->dtype) };
+}
+
+
+Expr MakeCastLike(Expr data,
+ Expr dtype_like) {
+ static const Op& op = Op::Get("cast_like");
+ return CallNode::make(op, {data, dtype_like}, Attrs(), {});
+}
+
+
+TVM_REGISTER_API("relay._make.cast_like")
+.set_body_typed(MakeCastLike);
+
+RELAY_REGISTER_OP("cast_like")
+.describe(R"code(Cast the data into the type of another tensor.
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(2)
+.add_argument("data", "Tensor", "The input tensor.")
+.add_argument("dtype_like", "Tensor", "The tensor to cast to.")
+.set_support_level(3)
+.add_type_rel("CastLike", CastLikeRel)
+.set_attr<FTVMCompute>("FTVMCompute", CastLikeCompute)
+.set_attr<TOpPattern>("TOpPattern", kElemWise)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
+
+
Array<Tensor> ReinterpretCompute(const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_type, const Target& target) {
const CastAttrs* param = attrs.as<CastAttrs>();