topi.not_equal
topi.greater_equal
topi.less_equal
+ topi.all
topi.logical_and
topi.logical_or
topi.logical_not
.. autofunction:: topi.gather_nd
.. autofunction:: topi.full
.. autofunction:: topi.full_like
+.. autofunction:: topi.all
.. autofunction:: topi.max
.. autofunction:: topi.sum
.. autofunction:: topi.min
tvm.relay.greater_equal
tvm.relay.less
tvm.relay.less_equal
+ tvm.relay.all
tvm.relay.logical_and
tvm.relay.logical_or
tvm.relay.logical_not
.. autofunction:: tvm.relay.greater_equal
.. autofunction:: tvm.relay.less
.. autofunction:: tvm.relay.less_equal
+.. autofunction:: tvm.relay.all
.. autofunction:: tvm.relay.logical_and
.. autofunction:: tvm.relay.logical_or
.. autofunction:: tvm.relay.logical_not
TVM_DLL Expr sum(Expr source, Array<IterVar> axis);
/*!
+ * \brief logical AND of source expression over axis
+ * \param source The source expression.
+ * \param axis List of iteration variables that will be used for reduction.
+ */
+TVM_DLL Expr all(Expr source, Array<IterVar> axis);
+
+/*!
* \brief max of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
ignores=['name', 'Tidx'])([inputs[0]], attr)
return _impl
+
+def _reduce_all():
+ def _impl(inputs, attr, params):
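+ # TF passes the reduction axes as a constant second input; pop its value from params.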
+ axis = params.pop(inputs[1].name_hint).asnumpy()
+ axis = tuple(axis)
+ return AttrCvt(
+ op_name='all',
+ extras={'axis': axis},
+ transforms={'keep_dims':'keepdims'},
+ ignores=['name', 'Tidx'])([inputs[0]], attr)
+ return _impl
+
def _square():
def _impl(inputs, attr, params):
return _op.multiply(inputs[0], inputs[0])
# for N to 1 mapping, currently not supported(?)
_convert_map = {
'Add' : _elemwise('add'),
+ 'All' : _reduce_all(),
'ArgMax' : _argx(_op.argmax, 'argmax'),
'ArgMin' : _argx(_op.argmin, 'argmin'),
'AvgPool' : _pooling('avg_pool'),
_reg.register_schedule("argmax", _schedule_reduce)
_reg.register_schedule("argmin", _schedule_reduce)
_reg.register_schedule("sum", _schedule_reduce)
+_reg.register_schedule("all", _schedule_reduce)
_reg.register_schedule("max", _schedule_reduce)
_reg.register_schedule("min", _schedule_reduce)
_reg.register_schedule("prod", _schedule_reduce)
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
- NOT in axis instead.
+ NOT in axis instead.
Returns
-------
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
- NOT in axis instead.
+ NOT in axis instead.
Returns
-------
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
- NOT in axis instead.
+ NOT in axis instead.
Returns
-------
return _make.sum(data, axis, keepdims, exclude)
+def all(data, axis=None, keepdims=False, exclude=False):
+ """Computes the logical AND of boolean array elements over given axes.
+
+ Parameters
+ ----------
+ data : relay.Expr
+ The input boolean tensor
+
+ axis : None or int or tuple of int
+ Axis or axes along which a logical AND is performed. The default, axis=None,
+ will perform logical AND over all elements of the input array. If axis is
+ negative it counts from the last to the first axis.
+
+ keepdims : bool
+ If this is set to True, the axes which are reduced are left in the result as
+ dimensions with size one. With this option, the result will broadcast
+ correctly against the input array.
+
+ exclude : bool
+ If `exclude` is true, reduction will be performed on the axes that are
+ NOT in axis instead.
+
+ Returns
+ -------
+ result : relay.Expr
+ The computed result.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ data = relay.Constant(tvm.nd.array([[[ True, True, True],
+ [ True, True, True],
+ [False, True, False]],
+ [[ True, False, False],
+ [ True, True, False],
+ [False, True, True]]]))
+
+ relay.all(data, axis=1)
+ # [[False, True, False],
+ # [False, False, False]]
+
+ relay.all(data, axis=0)
+ # [[ True, False, False],
+ # [ True, True, False],
+ # [False, True, False]]
+
+ """
+ axis = [axis] if isinstance(axis, int) else axis
+ return _make.all(data, axis, keepdims, exclude)
+
+
def max(data, axis=None, keepdims=False, exclude=False):
""" Computes the max of array elements over given axes.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
- NOT in axis instead.
+ NOT in axis instead.
Returns
-------
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
- NOT in axis instead.
+ NOT in axis instead.
Returns
-------
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
- NOT in axis instead.
+ NOT in axis instead.
Returns
-------
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
- NOT in axis instead.
+ NOT in axis instead.
Returns
-------
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
}
+Expr all(Expr source, Array<IterVar> rdom) {
+ CHECK(source.type().is_bool());
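+ // Commutative AND reducer: elements are combined with x && y, and the
+ // identity element is true, so reducing an empty domain yields true.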
+ Var x("x", source.type()), y("y", source.type());
+ Expr result = ir::And::make(x, y);
+ Expr identity_element = make_const(source.type(), true);
+ ir::CommReducer combiner =
+ ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
+ return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
+}
+
Expr max(Expr source, Array<IterVar> rdom) {
Var x("x", source.type()), y("y", source.type());
Expr result = ir::Max::make(x, y);
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
+Array<Tensor> AllCompute(const Attrs& attrs,
+ const Array<Tensor>& inputs,
+ const Type& out_type,
+ const Target& target) {
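+ // Dispatch through the shared reduce helper, with topi::all as the reduction kernel.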
+ return ReduceCompute(attrs, inputs, out_type, target, topi::all);
+}
+
+
+RELAY_REGISTER_REDUCE_OP("all")
+.describe(R"code(Computes the logical AND of boolean array elements over given axes.
+
+Example::
+
+ data = [[[ True, True, True],
+ [ True, True, True],
+ [False, True, False]],
+ [[ True, False, False],
+ [ True, True, False],
+ [False, True, True]]]
+
+ all(data, axis=1)
+ [[False, True, False],
+ [False, False, False]]
+
+ all(data, axis=0)
+ [[ True, False, False],
+ [ True, True, False],
+ [False, True, False]]
+
+)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.ReduceAttrs")
+.set_support_level(4)
+.add_type_rel("Reduce", ReduceRel)
+.set_attr<FTVMCompute>("FTVMCompute", AllCompute)
+.set_attr<TOpPattern>("TOpPattern", kCommReduce);
+
+
Array<Tensor> MaxCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
check_mean((10, 8, 16, 32), axis=(1,2), keepdims=True)
#######################################################################
+# All
+# ---
+def test_forward_all():
+ """Test the All operator."""
+ np_data = np.random.choice([True, False], size=(5, 7, 11))
+ tf.reset_default_graph()
+ in_data = tf.placeholder(tf.bool, (5, 7, 11), name="in_data")
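+ # No axis argument: tf.reduce_all collapses every dimension to a scalar.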
+ tf.reduce_all(in_data, name="all")
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'all:0')
+
+#######################################################################
# Relational operators
# --------------------
def _test_forward_rel_op(data, func):
test_forward_reduce()
test_forward_mean()
test_forward_reduce_prod()
+ test_forward_all()
# General
test_forward_multi_input()
def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32"):
test_func = funcs[0]
ref_func = funcs[1]
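+ # np.all only operates on booleans, so override the requested dtype for it.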
+ dtype = "bool" if ref_func in [np.all] else dtype
x = relay.var("x", relay.TensorType(data, dtype))
z = test_func(x, axis, keepdims, exclude)
return
func = relay.Function([x], z)
- x_data = np.random.uniform(size=data).astype(dtype)
+ x_data = np.random.choice([True, False], size=data) if ref_func in [np.all] \
+ else np.random.uniform(size=data).astype(dtype)
+
if ref_func in [np.sum]:
ref_res = ref_func(x_data + 0, axis=axis, dtype=dtype, keepdims=keepdims)
elif ref_func in [np.max, np.min, np.mean, np.prod]:
[relay.min, np.min],
[relay.mean, np.mean],
[relay.prod, np.prod],
+ [relay.all, np.all],
[relay.argmin, _with_keepdims(np.argmin)],
[relay.argmax, _with_keepdims(np.argmax)]]:
verify_reduce(func, (d1, d2, d3, d4), None, False, False, ())
verify_reduce(func, (d1, d2, d3), (0, 1), True, False, (1, 1, d3))
verify_reduce(func, (2, 3, 4), 1, True, False, (2, 1, 4))
verify_reduce(func, (2, 3, 4), (1,), True, False, (2, 1, 4))
+ verify_reduce(func, (2, 3, 4), -1, True, False, (2, 3, 1))
verify_reduce(func, (2, 3, 4), (0, 1, 2), False, False, ())
verify_reduce(func, (4, 4, 3), None, False, False, ())
verify_reduce(func, (4, 4, 3), (0, 2), False, False, (4,))
}
/*!
+* \brief Creates an operation that computes the logical AND of elements
+* over a given axis
+*
+* \param data The input boolean tensor
+* \param axis The axes to reduce. If axis is empty, the operation will
+* perform logical AND over all elements of the array.
+* \param keepdims If this is set to true, the axes which are reduced are
+* left in the result as dimensions with size one. This enables the result
+* to broadcast correctly against the input array.
+ * \param atleast1d Whether the output needs to be at least 1-dimensional.
+*
+* \return A Tensor whose op member is the all operation
+*/
+inline Tensor all(const Tensor& data,
+ const Array<Integer>& axis,
+ bool keepdims = false,
+ bool atleast1d = false) {
+ return CommReduce(data, axis, tvm::all, keepdims, atleast1d);
+}
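+
+// A minimal usage sketch (illustrative only, not part of this header; the
+// tensor names are made up):
+//
+// Tensor data = tvm::placeholder({2, 3}, Bool(1), "data");
+// Tensor rows_ok = topi::all(data, {1}); // shape (2,): AND across each row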
+
+/*!
* \brief Creates an operation that finds the minimum of elements over
* a given axis.
*
return cpp.sum(data, axis, keepdims)
+def all(data, axis=None, keepdims=False):
+ """Logical AND of array elements over a given axis or a list of axes
+
+ Parameters
+ ----------
+ data : tvm.Tensor
+ The input tvm boolean tensor
+
+ axis : None or int or tuple of int
+ Axis or axes along which a logical AND is performed.
+ The default, axis=None, will perform logical AND over all elements of the input array.
+ If axis is negative it counts from the last to the first axis.
+
+ keepdims : bool
+ If this is set to True, the axes which are reduced are left in the result as dimensions
+ with size one.
+ With this option, the result will broadcast correctly against the input array.
+
+ Returns
+ -------
+ ret : tvm.Tensor
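+
+ Examples
+ --------
+ .. code-block:: python
+
+ # illustrative sketch, following the same pattern as the other reductions
+ data = tvm.placeholder((2, 3), name="data", dtype="bool")
+ out = topi.all(data, axis=1) # boolean Tensor of shape (2,)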
+ """
+ return cpp.all(data, axis, keepdims)
+
+
def max(data, axis=None, keepdims=False):
"""Maximum of array elements over a given axis or a list of axes
*rv = topi::prod(args[0], ArrayOrInt(args[1]), args[2]);
});
+TVM_REGISTER_GLOBAL("topi.all")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
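+ // args: (data, axis, keepdims); mirrors the other reduce wrappers above.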
+ *rv = topi::all(args[0], ArrayOrInt(args[1]), args[2]);
+ });
+
/* Ops from transform.h */
TVM_REGISTER_GLOBAL("topi.expand_dims")
.set_body([](TVMArgs args, TVMRetValue *rv) {
out_dtype = dtype
if type == "sum":
B = topi.sum(A1, axis=axis, keepdims=keepdims)
+ elif type == "all":
+ B = topi.all(A, axis=axis, keepdims=keepdims)
elif type == "max":
B = topi.max(A1, axis=axis, keepdims=keepdims)
elif type == "min":
foo = tvm.build(s, [A, B], device, name=type)
# Test
- in_npy = np.random.uniform(size=in_shape).astype(dtype)
- in_npy_map = np.sqrt(np.exp(in_npy)).astype(dtype)
+ if dtype == 'bool':
+ in_npy_map = in_npy = np.random.choice([True, False], size=in_shape)
+ else:
+ in_npy = np.random.uniform(-1, 1, size=in_shape).astype(dtype)
+ in_npy_map = np.sqrt(np.exp(in_npy)).astype(dtype)
+
if type == "sum":
out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims)
+ elif type == "all" and dtype == 'bool':
+ out_npy = in_npy_map.all(axis=axis, keepdims=keepdims)
elif type == "max":
out_npy = in_npy_map.max(axis=axis, keepdims=keepdims)
elif type == "min":
def test_reduce_map():
verify_reduce_map_ele(in_shape=(32,),
axis=0,
keepdims=False,
type="argmax")
verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
- axis=(1, 2, 3),
- keepdims=True,
- type="sum")
+ axis=(1, 2, 3),
+ keepdims=True,
+ type="sum")
+ verify_reduce_map_ele(in_shape=(2, 3),
+ axis=None,
+ keepdims=True,
+ type="all",
+ dtype='bool')
verify_reduce_map_ele(in_shape=(128, 24 * 128 * 24),
- axis=(1,),
- keepdims=False,
- type="max")
+ axis=(1,),
+ keepdims=False,
+ type="max")
+ verify_reduce_map_ele(in_shape=(32, 128, 24),
+ axis=None,
+ keepdims=True,
+ type="sum")
verify_reduce_map_ele(in_shape=(32, 128, 24),
- axis=None,
- keepdims=True,
- type="sum")
+ axis=None,
+ keepdims=True,
+ dtype='bool',
+ type="all")
verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
- axis=(0, 2),
- keepdims=False,
- type="min")
+ axis=(0, 2),
+ keepdims=False,
+ type="min")
verify_reduce_map_ele(in_shape=(32, 128),
axis=1,
keepdims=True,