topi.repeat
topi.tile
topi.shape
+ topi.ndarray_size
topi.layout_transform
topi.image.resize
topi.argsort
.. autofunction:: topi.repeat
.. autofunction:: topi.tile
.. autofunction:: topi.shape
+.. autofunction:: topi.ndarray_size
.. autofunction:: topi.layout_transform
.. autofunction:: topi.argsort
.. autofunction:: topi.topk
tvm.relay.collapse_sum_like
tvm.relay.slice_like
tvm.relay.shape_of
+ tvm.relay.contrib.ndarray_size
tvm.relay.layout_transform
tvm.relay.device_copy
tvm.relay.annotation.on_device
.. autofunction:: tvm.relay.collapse_sum_like
.. autofunction:: tvm.relay.slice_like
.. autofunction:: tvm.relay.shape_of
+.. autofunction:: tvm.relay.contrib.ndarray_size
.. autofunction:: tvm.relay.layout_transform
.. autofunction:: tvm.relay.device_copy
.. autofunction:: tvm.relay.annotation.on_device
std::string name;
/*! \brief optional tag of the operation */
std::string tag;
- /*! \brief addtitional attributes of the operation*/
+ /*! \brief additional attributes of the operation*/
Map<std::string, NodeRef> attrs;
/*! \return name of the operation */
const std::string& func_name() const final {
}
}; // struct SequenceMaskAttrs.
+/*! \brief Attributes for ndarray_size operator */
+struct NdarraySizeAttrs : public tvm::AttrsNode<NdarraySizeAttrs> {
+ DataType dtype;
+
+ TVM_DECLARE_ATTRS(NdarraySizeAttrs, "relay.attrs.NdarraySizeAttrs") {
+ TVM_ATTR_FIELD(dtype)
+ .describe("Target data type")
+ .set_default(NullValue<DataType>());
+ }
+};
+
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
The name hint of the tensor
tag: str, optional
- Additonal tag information about the compute.
+ Additional tag information about the compute.
attrs: dict, optional
The additional auxiliary attributes about the compute.
'Shape' : _shape(),
'Sigmoid' : AttrCvt('sigmoid'),
'Sign' : AttrCvt('sign'),
+ 'Size' : AttrCvt('ndarray_size'),
'Slice' : _slice(),
'Softmax' : _softmax(),
'Softplus' : _softplus(),
import topi
from .. import op as reg
-from ..op import OpPattern
+from ..op import schedule_injective, OpPattern
# adaptive_max_pool2d
return topi.generic.schedule_adaptive_pool(outs)
reg.register_pattern("contrib.adaptive_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
+
+# relay.contrib.ndarray_size
+reg.register_schedule("contrib.ndarray_size", schedule_injective)
"""
output_size = [] or output_size
return _make.adaptive_avg_pool2d(data, output_size, layout)
+
+def ndarray_size(data, dtype="int32"):
+ """Get number of elements of input tensor.
+
+ Parameters
+ ----------
+ data : tvm.relay.Expr
+ The input tensor.
+
+ dtype : str, optional
+ The target data type.
+
+ Returns
+ -------
+ result : tvm.relay.Expr
+ The number of elements of input tensor.
+ """
+ return _make.ndarray_size(data, dtype)
.set_support_level(10)
.set_attr<FTVMCompute>("FTVMCompute", ShapeOfCompute);
+
+TVM_REGISTER_NODE_TYPE(NdarraySizeAttrs);
+
+bool NdarraySizeRel(const Array<Type>& types,
+ int num_inputs,
+ const Attrs& attrs,
+ const TypeReporter& reporter) {
+ CHECK_EQ(num_inputs, 1);
+ auto tt = types[0].as<TensorTypeNode>();
+ CHECK(tt != nullptr);
+ const auto* param = attrs.as<NdarraySizeAttrs>();
+ CHECK(param != nullptr);
+ reporter->Assign(types[1], TensorTypeNode::make({1}, param->dtype));
+ return true;
+}
+
+Array<Tensor> NdarraySizeCompute(const Attrs& attrs,
+ const Array<Tensor>& inputs,
+ const Type& out_type,
+ const Target& target) {
+ CHECK_EQ(inputs.size(), 1);
+ const auto* param = attrs.as<NdarraySizeAttrs>();
+ CHECK(param != nullptr);
+ return Array<Tensor>{topi::ndarray_size(inputs[0], param->dtype)};
+}
+
+TVM_REGISTER_API("relay.op.contrib._make.ndarray_size")
+.set_body_typed<Expr(Expr, DataType)>([](Expr data, DataType dtype) {
+ auto attrs = make_node<NdarraySizeAttrs>();
+ attrs->dtype = dtype;
+ static const Op& op = Op::Get("contrib.ndarray_size");
+ return CallNode::make(op, {data}, Attrs(attrs), {});
+});
+
+RELAY_REGISTER_OP("contrib.ndarray_size")
+.describe(R"code(Returns a tensor representing the number of elements of input tensor.
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_attrs_type_key("relay.attrs.NdarraySizeAttrs")
+.add_argument("data", "Tensor", "The input tensor.")
+.add_type_rel("NdarraySize", NdarraySizeRel)
+.set_attr<TOpIsStateful>("TOpIsStateful", false)
+.set_attr<TOpPattern>("TOpPattern", kInjective)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+ElemwiseArbitraryLayout)
+.set_support_level(10)
+.set_attr<FTVMCompute>("FTVMCompute", NdarraySizeCompute);
+
} // namespace relay
} // namespace tvm
check_mean((10, 8, 16, 32), axis=(2, 3))
check_mean((10, 8, 16, 32), axis=(1, 2), keepdims=True)
+#######################################################################
+# Size
+# ----
+def test_forward_size():
+ def check_size(ishape):
+ np_input = np.random.uniform(size=ishape).astype(np.float32)
+ with tf.Graph().as_default():
+ input = tf.placeholder(shape=np_input.shape, dtype=np_input.dtype, name='input')
+ tf.size(input, name='size')
+ compare_tf_with_tvm([np_input], ['input:0'], 'size:0')
+
+ if tf.__version__ < LooseVersion('1.1'):
+ check_size((10, 8, 16, 32))
+ check_size((10,))
+ check_size(())
+
#######################################################################
# All, Max, Min
# -------------
test_forward_depthtospace()
test_forward_squeeze()
test_forward_pack()
+ test_forward_size()
test_forward_broadcast_to()
test_forward_fill()
test_forward_crop()
tvm.testing.assert_allclose(op_res.asnumpy(),
np.array(shape).astype('int32'))
+def test_ndarray_size():
+ def verify_ndarray_size(shape):
+ x = relay.var("x", shape=shape)
+ func = relay.Function([x], relay.op.contrib.ndarray_size(x))
+ func = run_infer_type(func)
+
+ x_data = np.random.uniform(size=shape).astype("float32")
+ ref_res = np.size(x_data)
+ for target, ctx in ctx_list():
+ for kind in ["graph", "debug"]:
+ intrp = relay.create_executor(kind, ctx=ctx, target=target)
+ op_res = intrp.evaluate(func)(x_data)
+ tvm.testing.assert_allclose(op_res.asnumpy(),
+ ref_res)
+ verify_ndarray_size((2, 3, 5))
+ verify_ndarray_size((2, 3, 5, 7))
+
def verify_adaptive_pool2d(dshape, out_size, pool_type, layout="NCHW", dtype="float32"):
def start_index(index, odim, idim):
return int(np.floor(index * idim / odim))
test_batch_matmul()
test_shape_of()
test_sequence_mask()
+ test_ndarray_size()
+
}, name, tag);
}
+/*!
+ * \brief Get the size of input tensor.
+ * \param src the input tensor.
+ * \param dtype the type of the elements in the tensor.
+ * \param name output tensor name.
+ * \param tag output tensor tag.
+ * \return Tensor of input shape.
+ */
+inline Tensor ndarray_size(const Tensor& src,
+ const Type& dtype,
+ const std::string& name = "ndarray_size",
+ const std::string& tag = kInjective) {
+ int ndim = static_cast<int>(src->shape.size());
+ Array<Expr> out_ndarray_size = {1};
+ return compute(out_ndarray_size, [&](const Array<Var>& indices) {
+ Expr ret = 1;
+ for (int i = 0; i < ndim; ++i) {
+ ret *= src->shape[i];
+ }
+ return tvm::cast(dtype, ret);
+ }, name, tag);
+}
+
} // namespace topi
#endif // TOPI_TRANSFORM_H_
Parameters
----------
array : tvm.Tensor
- The source tenosr.
+ The source tensor.
dtype : str, optional
The target data type.
"only support data.ndim >= 2, received data.shape = {}".format(data.shape)
assert axis == 0 or axis == 1, "only support axis = 0, 1, received axis = {}".format(axis)
return cpp.sequence_mask(data, valid_length, mask_value, axis)
+
+
+def ndarray_size(array, dtype="int32"):
+ """Get the number of elements of input array
+
+ Parameters
+ ----------
+ array : tvm.Tensor
+ The source tensor.
+
+ dtype : str, optional
+ The target data type.
+
+ Returns
+ -------
+ result : tvm.Tensor
+ The resulting tensor.
+ """
+ return cpp.ndarray_size(array, dtype)
*rv = shape(args[0], args[1]);
});
+TVM_REGISTER_GLOBAL("topi.ndarray_size")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+ *rv = ndarray_size(args[0], args[1]);
+});
+
TVM_REGISTER_GLOBAL("topi.split")
.set_body([](TVMArgs args, TVMRetValue *rv) {
if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt) {
for backend in get_all_backend():
check_device(backend)
+def test_ndarray_size():
+ in_shape = (5, 11, 7)
+ dtype = "int32"
+ A = tvm.placeholder(shape=in_shape, dtype="float32", name="A")
+ B = topi.ndarray_size(A, dtype)
+
+ input = np.random.uniform(size=in_shape).astype(A.dtype)
+ output = np.asarray(np.size(input)).astype(dtype)
+
+ def check_device(device):
+ ctx = tvm.context(device, 0)
+ if not ctx.exist:
+ print("Skip because %s is not enabled" % device)
+ return
+ tvm_input = tvm.nd.array(input, ctx=ctx)
+ tvm_output = tvm.nd.empty((1,), ctx=ctx, dtype=B.dtype)
+ print("Running on target: %s" % device)
+ with tvm.target.create(device):
+ s = topi.generic.schedule_injective(B)
+ f = tvm.build(s, [A, B], device, name="ndarray_size")
+ f(tvm_input, tvm_output)
+ tvm.testing.assert_allclose(tvm_output.asnumpy(), output)
+
+ for backend in get_all_backend():
+ check_device(backend)
+
+
if __name__ == "__main__":
test_strided_slice()
test_concatenate()
test_tile()
test_shape()
test_sequence_mask()
+ test_ndarray_size()