[Relay][TOPI] operator All (#3124)
authorYong Wu <ywu118@alumni.jh.edu>
Mon, 20 May 2019 18:56:22 +0000 (11:56 -0700)
committerHaichen Shen <shenhaichen@gmail.com>
Mon, 20 May 2019 18:56:22 +0000 (11:56 -0700)
* [Relay][TOPI] operator All

* Update tests/python/frontend/tensorflow/test_forward.py

Co-Authored-By: yongwww <55wuyong@163.com>
* fix comments

* change to level 4

14 files changed:
docs/api/python/topi.rst
docs/langref/relay_op.rst
include/tvm/expr_operator.h
python/tvm/relay/frontend/tensorflow.py
python/tvm/relay/op/_reduce.py
python/tvm/relay/op/reduce.py
src/lang/expr_operator.cc
src/relay/op/tensor/reduce.cc
tests/python/frontend/tensorflow/test_forward.py
tests/python/relay/test_op_level4.py
topi/include/topi/reduction.h
topi/python/topi/reduction.py
topi/src/topi.cc
topi/tests/python/test_topi_reduce.py

index eaa5dac..0b217d4 100644 (file)
@@ -88,6 +88,7 @@ List of operators
    topi.not_equal
    topi.greater_equal
    topi.less_equal
+   topi.all
    topi.logical_and
    topi.logical_or
    topi.logical_not
@@ -140,6 +141,7 @@ topi
 .. autofunction:: topi.gather_nd
 .. autofunction:: topi.full
 .. autofunction:: topi.full_like
+.. autofunction:: topi.all
 .. autofunction:: topi.max
 .. autofunction:: topi.sum
 .. autofunction:: topi.min
index cd56772..836f8f3 100644 (file)
@@ -135,6 +135,7 @@ This level enables additional math and transform operators.
    tvm.relay.greater_equal
    tvm.relay.less
    tvm.relay.less_equal
+   tvm.relay.all
    tvm.relay.logical_and
    tvm.relay.logical_or
    tvm.relay.logical_not
@@ -277,6 +278,7 @@ Level 4 Definitions
 .. autofunction:: tvm.relay.greater_equal
 .. autofunction:: tvm.relay.less
 .. autofunction:: tvm.relay.less_equal
+.. autofunction:: tvm.relay.all
 .. autofunction:: tvm.relay.logical_and
 .. autofunction:: tvm.relay.logical_or
 .. autofunction:: tvm.relay.logical_not
index 2e1348e..f289bdd 100644 (file)
@@ -429,6 +429,13 @@ TVM_DLL Expr abs(Expr x);
 TVM_DLL Expr sum(Expr source, Array<IterVar> axis);
 
 /*!
+ * \brief logical And of of source expression over axis
+ * \param source The source expression.
+ * \param axis List of iteration variables that will be used for reduction.
+ */
+TVM_DLL Expr all(Expr source, Array<IterVar> axis);
+
+/*!
  * \brief max of of source expression over axis
  * \param source The source expression.
  * \param axis List of iteration variables that will be used for reduction.
index 11026b9..7fe82ea 100644 (file)
@@ -767,6 +767,17 @@ def _sum():
             ignores=['name', 'Tidx'])([inputs[0]], attr)
     return _impl
 
+def _reduce_all():
+    def _impl(inputs, attr, params):
+        axis = params.pop(inputs[1].name_hint).asnumpy()
+        axis = tuple(axis)
+        return AttrCvt(
+            op_name='all',
+            extras={'axis': axis},
+            transforms={'keep_dims':'keepdims'},
+            ignores=['name', 'Tidx'])([inputs[0]], attr)
+    return _impl
+
 def _square():
     def _impl(inputs, attr, params):
         return _op.multiply(inputs[0], inputs[0])
@@ -1180,6 +1191,7 @@ _identity_list = []
 # for N to 1 mapping, currently not supported(?)
 _convert_map = {
     'Add'                               : _elemwise('add'),
+    'All'                               : _reduce_all(),
     'ArgMax'                            : _argx(_op.argmax, 'argmax'),
     'ArgMin'                            : _argx(_op.argmin, 'argmin'),
     'AvgPool'                           : _pooling('avg_pool'),
index b97e3a8..b7c9a79 100644 (file)
@@ -30,6 +30,7 @@ def _schedule_reduce(_, outs, target):
 _reg.register_schedule("argmax", _schedule_reduce)
 _reg.register_schedule("argmin", _schedule_reduce)
 _reg.register_schedule("sum", _schedule_reduce)
+_reg.register_schedule("all", _schedule_reduce)
 _reg.register_schedule("max", _schedule_reduce)
 _reg.register_schedule("min", _schedule_reduce)
 _reg.register_schedule("prod", _schedule_reduce)
index 9d58a92..0f25946 100644 (file)
@@ -39,7 +39,7 @@ def argmax(data, axis=None, keepdims=False, exclude=False):
 
     exclude : bool
         If `exclude` is true, reduction will be performed on the axes that are
-      NOT in axis instead.
+        NOT in axis instead.
 
     Returns
     -------
@@ -69,7 +69,7 @@ def argmin(data, axis=None, keepdims=False, exclude=False):
 
     exclude : bool
         If `exclude` is true, reduction will be performed on the axes that are
-      NOT in axis instead.
+        NOT in axis instead.
 
     Returns
     -------
@@ -100,7 +100,7 @@ def sum(data, axis=None, keepdims=False, exclude=False):
 
     exclude : bool
         If `exclude` is true, reduction will be performed on the axes that are
-      NOT in axis instead.
+        NOT in axis instead.
 
     Returns
     -------
@@ -111,6 +111,58 @@ def sum(data, axis=None, keepdims=False, exclude=False):
     return _make.sum(data, axis, keepdims, exclude)
 
 
+def all(data, axis=None, keepdims=False, exclude=False):
+    """Computes the logical AND of boolean array elements over given axes.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input boolean tensor
+
+    axis : None or int or tuple of int
+        Axis or axes along which a sum is performed. The default, axis=None,
+        will sum all of the elements of the input array. If axis is
+        negative it counts from the last to the first axis.
+
+    keepdims : bool
+        If this is set to True, the axes which are reduced are left in the result as
+        dimensions with size one. With this option, the result will broadcast
+        correctly against the input array.
+
+    exclude : bool
+        If `exclude` is true, reduction will be performed on the axes that are
+        NOT in axis instead.
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+
+    Examples
+    --------
+    .. code-block:: python
+
+    data = relay.Constant(tvm.nd.array([[[ True,  True,  True],
+                                         [ True,  True,  True],
+                                         [False,  True, False]],
+                                        [[ True, False, False],
+                                         [ True,  True, False],
+                                         [False,  True,  True]]]))
+
+    relay.all(data, axis=1)
+    # [[False,  True, False],
+    # [False, False, False]]
+
+    relay.all(data, axis=0)
+    # [[ True, False, False],
+    # [ True,  True, False],
+    # [False,  True, False]]
+
+    """
+    axis = [axis] if axis and isinstance(axis, int) else axis
+    return _make.all(data, axis, keepdims, exclude)
+
+
 def max(data, axis=None, keepdims=False, exclude=False):
     """ Computes the max of array elements over given axes.
 
@@ -131,7 +183,7 @@ def max(data, axis=None, keepdims=False, exclude=False):
 
     exclude : bool
         If `exclude` is true, reduction will be performed on the axes that are
-      NOT in axis instead.
+        NOT in axis instead.
 
     Returns
     -------
@@ -163,7 +215,7 @@ def min(data, axis=None, keepdims=False, exclude=False):
 
     exclude : bool
         If `exclude` is true, reduction will be performed on the axes that are
-      NOT in axis instead.
+        NOT in axis instead.
 
     Returns
     -------
@@ -194,7 +246,7 @@ def mean(data, axis=None, keepdims=False, exclude=False):
 
     exclude : bool
         If `exclude` is true, reduction will be performed on the axes that are
-      NOT in axis instead.
+        NOT in axis instead.
 
     Returns
     -------
@@ -225,7 +277,7 @@ def prod(data, axis=None, keepdims=False, exclude=False):
 
     exclude : bool
         If `exclude` is true, reduction will be performed on the axes that are
-      NOT in axis instead.
+        NOT in axis instead.
 
     Returns
     -------
index 4504ee2..8537f17 100644 (file)
@@ -393,6 +393,16 @@ Expr sum(Expr source, Array<IterVar> rdom) {
   return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
 }
 
+Expr all(Expr source, Array<IterVar> rdom) {
+  CHECK(source.type().is_bool());
+  Var x("x", source.type()), y("y", source.type());
+  Expr result = ir::And::make(x, y);
+  Expr identity_element = make_const(source.type(), true);
+  ir::CommReducer combiner =
+    ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
+  return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
+}
+
 Expr max(Expr source, Array<IterVar> rdom) {
   Var x("x", source.type()), y("y", source.type());
   Expr result = ir::Max::make(x, y);
index a4ebd1e..647e4d0 100644 (file)
@@ -355,6 +355,43 @@ Example::
 .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 
+Array<Tensor> AllCompute(const Attrs& attrs,
+                         const Array<Tensor>& inputs,
+                         const Type& out_type,
+                         const Target& target) {
+  return ReduceCompute(attrs, inputs, out_type, target, topi::all);
+}
+
+
+RELAY_REGISTER_REDUCE_OP("all")
+.describe(R"code(Computes the logical AND of boolean array elements over given axes.
+
+Example::
+
+  data = [[[ True,  True,  True],
+           [ True,  True,  True],
+           [False,  True, False]],
+          [[ True, False, False],
+           [ True,  True, False],
+           [False,  True,  True]]]
+
+  all(data, axis=1)
+  [[False,  True, False],
+   [False, False, False]]
+
+  all(data, axis=0)
+  [[ True, False, False],
+   [ True,  True, False],
+   [False,  True, False]]
+
+)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.ReduceAttrs")
+.set_support_level(4)
+.add_type_rel("Reduce", ReduceRel)
+.set_attr<FTVMCompute>("FTVMCompute", AllCompute)
+.set_attr<TOpPattern>("TOpPattern", kCommReduce);
+
+
 Array<Tensor> MaxCompute(const Attrs& attrs,
                          const Array<Tensor>& inputs,
                          const Type& out_type,
index e4626e0..023cdf5 100644 (file)
@@ -1598,6 +1598,17 @@ def test_forward_mean():
     check_mean((10, 8, 16, 32), axis=(1,2), keepdims=True)
 
 #######################################################################
+# All
+# ---
+def test_forward_all():
+    """Test the All operator."""
+    np_data = np.random.choice([True, False], size=(5, 7, 11))
+    tf.reset_default_graph()
+    in_data = tf.placeholder(tf.bool, (5, 7, 11), name="in_data")
+    tf.reduce_all(in_data, name="all")
+    compare_tf_with_tvm([np_data], ['in_data:0'], 'all:0')
+
+#######################################################################
 # Relational operators
 # --------------------
 def _test_forward_rel_op(data, func):
@@ -1718,6 +1729,7 @@ if __name__ == '__main__':
     test_forward_reduce()
     test_forward_mean()
     test_forward_reduce_prod()
+    test_forward_all()
 
     # General
     test_forward_multi_input()
index 0e44bf8..aac4a6d 100644 (file)
@@ -138,6 +138,7 @@ def test_where():
 def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32"):
     test_func = funcs[0]
     ref_func = funcs[1]
+    dtype = "bool" if ref_func in [np.all] else dtype
 
     x = relay.var("x", relay.TensorType(data, dtype))
     z = test_func(x, axis, keepdims, exclude)
@@ -155,7 +156,9 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32")
         return
 
     func = relay.Function([x], z)
-    x_data = np.random.uniform(size=data).astype(dtype)
+    x_data = np.random.choice([True, False], size=data) if ref_func in [np.all] \
+        else np.random.uniform(size=data).astype(dtype)
+
     if ref_func in [np.sum]:
         ref_res = ref_func(x_data + 0, axis=axis, dtype=dtype, keepdims=keepdims)
     elif ref_func in [np.max, np.min, np.mean, np.prod]:
@@ -194,6 +197,7 @@ def test_reduce_functions():
                  [relay.min, np.min],
                  [relay.mean, np.mean],
                  [relay.prod, np.prod],
+                 [relay.all, np.all],
                  [relay.argmin, _with_keepdims(np.argmin)],
                  [relay.argmax, _with_keepdims(np.argmax)]]:
         verify_reduce(func, (d1, d2, d3, d4), None, False, False, ())
@@ -203,6 +207,7 @@ def test_reduce_functions():
         verify_reduce(func, (d1, d2, d3), (0, 1), True, False, (1, 1, d3))
         verify_reduce(func, (2, 3, 4), 1, True, False, (2, 1, 4))
         verify_reduce(func, (2, 3, 4), (1,), True, False, (2, 1, 4))
+        verify_reduce(func, (2, 3, 4), -1, True, False, (2, 3, 1))
         verify_reduce(func, (2, 3, 4), (0, 1, 2), False, False, ())
         verify_reduce(func, (4, 4, 3), None, False, False, ())
         verify_reduce(func, (4, 4, 3), (0, 2), False, False, (4,))
index b24c457..09d1b4b 100644 (file)
@@ -369,6 +369,27 @@ inline Tensor collapse_sum(const Tensor& data, Array<Expr> target_shape) {
 }
 
 /*!
+* \brief Creates an operation that computes the logical AND of elements
+* over a given axis
+*
+* \param data The input boolean tensor
+* \param axis The axes to reduce. If axis is empty, the operation will
+* perform logical AND over all elements of the array.
+* \param keepdims If this is set to true, the axes which are reduced are
+* left in the result as dimensions with size one. This enables the result
+* to broadcast correctly against the input array.
+* \param atleast1d Whether the output need to be atleast1d.
+*
+* \return A Tensor whose op member is the all operation
+*/
+inline Tensor all(const Tensor& data,
+                  const Array<Integer>& axis,
+                  bool keepdims = false,
+                  bool atleast1d = false) {
+  return CommReduce(data, axis, tvm::all, keepdims, atleast1d);
+}
+
+/*!
 * \brief Creates an operation that finds the minimum of elements over
 * a given axis.
 *
index ce1326b..5079bf4 100644 (file)
@@ -65,6 +65,31 @@ def sum(data, axis=None, keepdims=False):
     return cpp.sum(data, axis, keepdims)
 
 
+def all(data, axis=None, keepdims=False):
+    """Logical AND of array elements over a given axis or a list of axes
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        The input tvm boolean tensor
+
+    axis : None or int or tuple of int
+        Axis or axes along which a logical AND is performed.
+        The default, axis=None, will perform logical AND over all elements of the input array.
+        If axis is negative it counts from the last to the first axis.
+
+    keepdims : bool
+        If this is set to True, the axes which are reduced are left in the result as dimensions
+        with size one.
+        With this option, the result will broadcast correctly against the input array.
+
+    Returns
+    -------
+    ret : tvm.Tensor
+    """
+    return cpp.all(data, axis, keepdims)
+
+
 def max(data, axis=None, keepdims=False):
     """Maximum of array elements over a given axis or a list of axes
 
index 1585d87..d3e0bc9 100644 (file)
@@ -265,6 +265,11 @@ TVM_REGISTER_GLOBAL("topi.prod")
   *rv = topi::prod(args[0], ArrayOrInt(args[1]), args[2]);
   });
 
+TVM_REGISTER_GLOBAL("topi.all")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  *rv = topi::all(args[0], ArrayOrInt(args[1]), args[2]);
+  });
+
 /* Ops from transform.h */
 TVM_REGISTER_GLOBAL("topi.expand_dims")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
index 1882cbd..6e6470d 100644 (file)
@@ -50,6 +50,8 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32")
     out_dtype = dtype
     if type == "sum":
         B = topi.sum(A1, axis=axis, keepdims=keepdims)
+    elif type == "all":
+        B = topi.all(A, axis=axis, keepdims=keepdims)
     elif type == "max":
         B = topi.max(A1, axis=axis, keepdims=keepdims)
     elif type == "min":
@@ -74,10 +76,16 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32")
 
         foo = tvm.build(s, [A, B], device, name=type)
         # Test
-        in_npy = np.random.uniform(size=in_shape).astype(dtype)
-        in_npy_map = np.sqrt(np.exp(in_npy)).astype(dtype)
+        if dtype == 'bool':
+            in_npy_map = in_npy = np.random.choice([True, False], size=in_shape)
+        else:
+            in_npy = np.random.uniform(-1, 1, size=in_shape).astype(dtype)
+            in_npy_map = np.sqrt(np.exp(in_npy)).astype(dtype)
+
         if type == "sum":
             out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims)
+        elif type == "all" and dtype == 'bool':
+            out_npy = in_npy_map.all(axis=axis, keepdims=keepdims)
         elif type == "max":
             out_npy = in_npy_map.max(axis=axis, keepdims=keepdims)
         elif type == "min":
@@ -113,26 +121,37 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32")
 
 
 def test_reduce_map():
+
     verify_reduce_map_ele(in_shape=(32,),
                           axis=0,
                           keepdims=False,
                           type="argmax")
     verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
-                        axis=(1, 2, 3),
-                        keepdims=True,
-                        type="sum")
+                          axis=(1, 2, 3),
+                          keepdims=True,
+                          type="sum")
+    verify_reduce_map_ele(in_shape=(2, 3),
+                          axis=None,
+                          keepdims=True,
+                          type="all",
+                          dtype='bool')
     verify_reduce_map_ele(in_shape=(128, 24 * 128 * 24),
-                        axis=(1,),
-                        keepdims=False,
-                        type="max")
+                          axis=(1,),
+                          keepdims=False,
+                          type="max")
+    verify_reduce_map_ele(in_shape=(32, 128, 24),
+                          axis=None,
+                          keepdims=True,
+                          type="sum")
     verify_reduce_map_ele(in_shape=(32, 128, 24),
-                        axis=None,
-                        keepdims=True,
-                        type="sum")
+                          axis=None,
+                          keepdims=True,
+                          dtype='bool',
+                          type="all")
     verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
-                        axis=(0, 2),
-                        keepdims=False,
-                        type="min")
+                          axis=(0, 2),
+                          keepdims=False,
+                          type="min")
     verify_reduce_map_ele(in_shape=(32, 128),
                           axis=1,
                           keepdims=True,