[Relay][Topi][TensorFlow][ONNX][Lang] Add support for Any op (#4205)
authorJon Soifer <soiferj@gmail.com>
Wed, 30 Oct 2019 18:43:09 +0000 (11:43 -0700)
committerJared Roesch <roeschinc@gmail.com>
Wed, 30 Oct 2019 18:43:09 +0000 (11:43 -0700)
* Add support for Any op

* Support ONNX frontend

* Add doc

* Add to relay docs

* Dummy change to retrigger CI

17 files changed:
docs/api/python/topi.rst
docs/frontend/tensorflow.rst
docs/langref/relay_op.rst
include/tvm/expr_operator.h
python/tvm/relay/frontend/onnx.py
python/tvm/relay/frontend/tensorflow.py
python/tvm/relay/op/_reduce.py
python/tvm/relay/op/reduce.py
src/lang/expr_operator.cc
src/relay/op/tensor/reduce.cc
tests/python/frontend/onnx/test_forward.py
tests/python/frontend/tensorflow/test_forward.py
tests/python/relay/test_op_level4.py
topi/include/topi/reduction.h
topi/python/topi/reduction.py
topi/src/topi.cc
topi/tests/python/test_topi_reduce.py

index 3483668..0e203c1 100644 (file)
@@ -91,6 +91,7 @@ List of operators
    topi.greater_equal
    topi.less_equal
    topi.all
+   topi.any
    topi.logical_and
    topi.logical_or
    topi.logical_not
@@ -151,6 +152,7 @@ topi
 .. autofunction:: topi.full
 .. autofunction:: topi.full_like
 .. autofunction:: topi.all
+.. autofunction:: topi.any
 .. autofunction:: topi.max
 .. autofunction:: topi.sum
 .. autofunction:: topi.min
index 827f5d6..8782888 100644 (file)
@@ -116,6 +116,7 @@ Supported Ops
 - Abs
 - Add
 - All
+- Any
 - ArgMax
 - ArgMin
 - AvgPool
index 57325b5..db74120 100644 (file)
@@ -137,6 +137,7 @@ This level enables additional math and transform operators.
    tvm.relay.less
    tvm.relay.less_equal
    tvm.relay.all
+   tvm.relay.any
    tvm.relay.logical_and
    tvm.relay.logical_or
    tvm.relay.logical_not
@@ -300,6 +301,7 @@ Level 4 Definitions
 .. autofunction:: tvm.relay.less
 .. autofunction:: tvm.relay.less_equal
 .. autofunction:: tvm.relay.all
+.. autofunction:: tvm.relay.any
 .. autofunction:: tvm.relay.logical_and
 .. autofunction:: tvm.relay.logical_or
 .. autofunction:: tvm.relay.logical_not
index adc77a8..625ee8e 100644 (file)
@@ -520,6 +520,13 @@ TVM_DLL Expr sum(Expr source, Array<IterVar> axis);
 TVM_DLL Expr all(Expr source, Array<IterVar> axis);
 
 /*!
+ * \brief logical Or of of source expression over axis
+ * \param source The source expression.
+ * \param axis List of iteration variables that will be used for reduction.
+ */
+TVM_DLL Expr any(Expr source, Array<IterVar> axis);
+
+/*!
  * \brief max of of source expression over axis
  * \param source The source expression.
  * \param axis List of iteration variables that will be used for reduction.
index 41fafbc..a28b8f6 100644 (file)
@@ -989,6 +989,12 @@ class Where(OnnxOpConverter):
     def _impl_v9(cls, inputs, attr, params):
         return _op.where(inputs[0], inputs[1], inputs[2])
 
+class Or(Elemwise):
+    """ Operator converter for Or.
+    """
+    @classmethod
+    def _impl_v7(cls, inputs, attr, params):
+        return _op.logical_or(inputs[0], inputs[1])
 
 # compatible operators that do NOT require any conversion.
 _identity_list = []
@@ -1111,7 +1117,8 @@ def _get_convert_map(opset):
         'And': And.get_converter(opset),
         'Tile': Tile.get_converter(opset),
         'Erf': Erf.get_converter(opset),
-        'Where': Where.get_converter(opset)
+        'Where': Where.get_converter(opset),
+        'Or': Or.get_converter(opset)
     }
 
 
index 2ef8d15..648d7f4 100644 (file)
@@ -1330,6 +1330,7 @@ _convert_map = {
     'Abs'                               : AttrCvt('abs'),
     'Add'                               : _elemwise('add'),
     'All'                               : _reduce('all'),
+    'Any'                               : _reduce('any'),
     'ArgMax'                            : _argx(_op.argmax, 'argmax'),
     'ArgMin'                            : _argx(_op.argmin, 'argmin'),
     'Assert'                            : _assert(),
index 845ec4b..06d0d66 100644 (file)
@@ -31,6 +31,7 @@ _reg.register_schedule("argmax", _schedule_reduce)
 _reg.register_schedule("argmin", _schedule_reduce)
 _reg.register_schedule("sum", _schedule_reduce)
 _reg.register_schedule("all", _schedule_reduce)
+_reg.register_schedule("any", _schedule_reduce)
 _reg.register_schedule("max", _schedule_reduce)
 _reg.register_schedule("min", _schedule_reduce)
 _reg.register_schedule("prod", _schedule_reduce)
index 49193fd..baf896e 100644 (file)
@@ -166,6 +166,58 @@ def all(data, axis=None, keepdims=False, exclude=False):
     return _make.all(data, axis, keepdims, exclude)
 
 
+def any(data, axis=None, keepdims=False, exclude=False):
+    """Computes the logical OR of boolean array elements over given axes.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input boolean tensor
+
+    axis : None or int or tuple of int
+        Axis or axes along which a sum is performed. The default, axis=None,
+        will sum all of the elements of the input array. If axis is
+        negative it counts from the last to the first axis.
+
+    keepdims : bool
+        If this is set to True, the axes which are reduced are left in the result as
+        dimensions with size one. With this option, the result will broadcast
+        correctly against the input array.
+
+    exclude : bool
+        If `exclude` is true, reduction will be performed on the axes that are
+        NOT in axis instead.
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+
+    Examples
+    --------
+    .. code-block:: python
+
+    data = relay.Constant(tvm.nd.array([[[ True,  True,  True],
+                                         [ True,  True,  True],
+                                         [False,  True, False]],
+                                        [[ True, False, False],
+                                         [ True,  True, False],
+                                         [False,  True,  True]]]))
+
+    relay.any(data, axis=1)
+    # [[True, True, True],
+    # [True,  True, True]]
+
+    relay.any(data, axis=0)
+    # [[ True, True, True],
+    # [ True,  True, True],
+    # [False,  True, True]]
+
+    """
+    axis = [axis] if isinstance(axis, int) else axis
+    return _make.any(data, axis, keepdims, exclude)
+
+
 def max(data, axis=None, keepdims=False, exclude=False):
     """ Computes the max of array elements over given axes.
 
index 9c9100b..220d437 100644 (file)
@@ -486,6 +486,16 @@ Expr all(Expr source, Array<IterVar> rdom) {
   return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
 }
 
+Expr any(Expr source, Array<IterVar> rdom) {
+  CHECK(source.type().is_bool());
+  Var x("x", source.type()), y("y", source.type());
+  Expr result = ir::Or::make(x, y);
+  Expr identity_element = make_const(source.type(), false);
+  ir::CommReducer combiner =
+    ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
+  return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
+}
+
 Expr max(Expr source, Array<IterVar> rdom) {
   Var x("x", source.type()), y("y", source.type());
   Expr result = ir::Max::make(x, y);
index 51714bd..63524bc 100644 (file)
@@ -420,6 +420,43 @@ Example::
 .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 
+Array<Tensor> AnyCompute(const Attrs& attrs,
+                         const Array<Tensor>& inputs,
+                         const Type& out_type,
+                         const Target& target) {
+  return ReduceCompute(attrs, inputs, out_type, target, topi::any);
+}
+
+
+RELAY_REGISTER_REDUCE_OP("any")
+.describe(R"code(Computes the logical OR of boolean array elements over given axes.
+
+Example::
+
+  data = [[[ True,  True,  True],
+           [ True,  True,  True],
+           [False,  True, False]],
+          [[ True, False, False],
+           [ True,  True, False],
+           [False,  True,  True]]]
+
+  any(data, axis=1)
+  [[True,  True, True],
+   [True,  True, True]]
+
+  any(data, axis=0)
+  [[ True,  True, True],
+   [ True,  True, True],
+   [False,  True, True]]
+
+)code" TVM_ADD_FILELINE)
+.set_attrs_type<ReduceAttrs>()
+.set_support_level(4)
+.add_type_rel("Reduce", ReduceRel)
+.set_attr<FTVMCompute>("FTVMCompute", AnyCompute)
+.set_attr<TOpPattern>("TOpPattern", kCommReduce);
+
+
 Array<Tensor> MaxCompute(const Attrs& attrs,
                          const Array<Tensor>& inputs,
                          const Type& out_type,
index 2d2265b..5dfaee4 100644 (file)
@@ -1601,6 +1601,53 @@ def test_where():
     verify_where(condition, x, y, TensorProto.FLOAT, outdata)
 
 
+def verify_or(indata, dtype):
+    x = indata[0].astype(dtype)
+    y = indata[1].astype(dtype)
+    outdata = np.logical_or(x, y)
+
+    node = helper.make_node('Or', inputs=['in1', 'in2'], outputs=['out'], )
+
+    graph = helper.make_graph([node],
+                              'or_test',
+                              inputs=[helper.make_tensor_value_info("in1", TensorProto.BOOL, list(x.shape)),
+                                      helper.make_tensor_value_info("in2", TensorProto.BOOL, list(y.shape))],
+                              outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))])
+
+    model = helper.make_model(graph, producer_name='or_test')
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, [x, y], target, ctx, outdata.shape)
+        tvm.testing.assert_allclose(outdata, tvm_out)
+
+
+def test_or():
+    # 2d
+    x = (np.random.randn(3, 4) > 0)
+    y = (np.random.randn(3, 4) > 0)
+    verify_or(indata=[x, y], dtype=bool)
+
+    # 3d
+    x = (np.random.randn(3, 4, 5) > 0)
+    y = (np.random.randn(3, 4, 5) > 0)
+    verify_or(indata=[x, y], dtype=bool)
+
+    # 4d
+    x = (np.random.randn(3, 4, 5, 6) > 0)
+    y = (np.random.randn(3, 4, 5, 6) > 0)
+    verify_or(indata=[x, y], dtype=bool)
+
+    # 3d vs 1d
+    x = (np.random.randn(3, 4, 5) > 0)
+    y = (np.random.randn(5) > 0)
+    verify_or(indata=[x, y], dtype=bool)
+
+    # 3d vs 2d
+    x = (np.random.randn(3, 4, 5) > 0)
+    y = (np.random.randn(4, 5) > 0)
+    verify_or(indata=[x, y], dtype=bool)
+
+
 if __name__ == '__main__':
     test_flatten()
     test_reshape()
@@ -1651,3 +1698,4 @@ if __name__ == '__main__':
     test_tile()
     test_erf()
     test_where()
+    test_or()
index 11c6a7b..88787ef 100644 (file)
@@ -2198,7 +2198,7 @@ def test_forward_size():
     check_size((10,))
 
 #######################################################################
-# All, Max, Min
+# All, Any, Max, Min
 # -------------
 def test_forward_reduce_all():
     """Test the All operator."""
@@ -2208,6 +2208,14 @@ def test_forward_reduce_all():
     tf.reduce_all(in_data, name="all")
     compare_tf_with_tvm([np_data], ['in_data:0'], 'all:0')
 
+def test_forward_reduce_any():
+    """Test the Any operator."""
+    np_data = np.random.choice([True, False], size=(5, 7, 11))
+    tf.reset_default_graph()
+    in_data = tf.placeholder(tf.bool, (5, 7, 11), name="in_data")
+    tf.reduce_any(in_data, name="any")
+    compare_tf_with_tvm([np_data], ['in_data:0'], 'any:0')
+
 def test_forward_reduce_max():
     def check_max(ishape, axis, keepdims, dtype):
         tf.reset_default_graph()
@@ -2432,7 +2440,7 @@ if __name__ == '__main__':
     test_forward_mean()
     test_forward_reduce_prod()
     test_forward_reduce_all()
-    test_forward_reduce_max()
+    test_forward_reduce_any()
     test_forward_reduce_min()
 
     # General
index c34dddf..6a8a678 100644 (file)
@@ -145,7 +145,7 @@ def test_where():
 def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32"):
     test_func = funcs[0]
     ref_func = funcs[1]
-    dtype = "bool" if ref_func in [np.all] else dtype
+    dtype = "bool" if ref_func in [np.all, np.any] else dtype
 
     x = relay.var("x", relay.TensorType(data, dtype))
     z = test_func(x, axis, keepdims, exclude)
@@ -207,6 +207,7 @@ def test_reduce_functions():
                  [relay.std, np.std],
                  [relay.prod, np.prod],
                  [relay.all, np.all],
+                 [relay.any, np.any],
                  [relay.argmin, _with_keepdims(np.argmin)],
                  [relay.argmax, _with_keepdims(np.argmax)]]:
         verify_reduce(func, (d1, d2, d3, d4), None, False, False, ())
index 14dec77..b703677 100644 (file)
@@ -391,6 +391,27 @@ inline Tensor all(const Tensor& data,
 }
 
 /*!
+* \brief Creates an operation that computes the logical OR of elements
+* over a given axis
+*
+* \param data The input boolean tensor
+* \param axis The axes to reduce. If axis is empty, the operation will
+* perform logical OR over all elements of the array.
+* \param keepdims If this is set to true, the axes which are reduced are
+* left in the result as dimensions with size one. This enables the result
+* to broadcast correctly against the input array.
+* \param atleast1d Whether the output need to be atleast1d.
+*
+* \return A Tensor whose op member is the all operation
+*/
+inline Tensor any(const Tensor& data,
+                  const Array<Integer>& axis,
+                  bool keepdims = false,
+                  bool atleast1d = false) {
+  return CommReduce(data, axis, tvm::any, keepdims, atleast1d);
+}
+
+/*!
 * \brief Creates an operation that finds the minimum of elements over
 * a given axis.
 *
index 5079bf4..7c4e059 100644 (file)
@@ -90,6 +90,31 @@ def all(data, axis=None, keepdims=False):
     return cpp.all(data, axis, keepdims)
 
 
+def any(data, axis=None, keepdims=False):
+    """Logical OR of array elements over a given axis or a list of axes
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        The input tvm boolean tensor
+
+    axis : None or int or tuple of int
+        Axis or axes along which a logical OR is performed.
+        The default, axis=None, will perform logical OR over all elements of the input array.
+        If axis is negative it counts from the last to the first axis.
+
+    keepdims : bool
+        If this is set to True, the axes which are reduced are left in the result as dimensions
+        with size one.
+        With this option, the result will broadcast correctly against the input array.
+
+    Returns
+    -------
+    ret : tvm.Tensor
+    """
+    return cpp.any(data, axis, keepdims)
+
+
 def max(data, axis=None, keepdims=False):
     """Maximum of array elements over a given axis or a list of axes
 
index a0700bf..01fc598 100644 (file)
@@ -300,6 +300,11 @@ TVM_REGISTER_GLOBAL("topi.all")
   *rv = topi::all(args[0], ArrayOrInt(args[1]), args[2]);
   });
 
+TVM_REGISTER_GLOBAL("topi.any")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  *rv = topi::any(args[0], ArrayOrInt(args[1]), args[2]);
+  });
+
 /* Ops from transform.h */
 TVM_REGISTER_GLOBAL("topi.expand_dims")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
index 6e6470d..d266cfc 100644 (file)
@@ -52,6 +52,8 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32")
         B = topi.sum(A1, axis=axis, keepdims=keepdims)
     elif type == "all":
         B = topi.all(A, axis=axis, keepdims=keepdims)
+    elif type == "any":
+        B = topi.any(A, axis=axis, keepdims=keepdims)
     elif type == "max":
         B = topi.max(A1, axis=axis, keepdims=keepdims)
     elif type == "min":
@@ -86,6 +88,8 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32")
             out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims)
         elif type == "all" and dtype == 'bool':
             out_npy = in_npy_map.all(axis=axis, keepdims=keepdims)
+        elif type == "any" and dtype == "bool":
+            out_npy = in_npy_map.any(axis=axis, keepdims=keepdims)
         elif type == "max":
             out_npy = in_npy_map.max(axis=axis, keepdims=keepdims)
         elif type == "min":
@@ -173,6 +177,26 @@ def test_reduce_map():
                           keepdims=True,
                           type="sum",
                           dtype="float64")
+    verify_reduce_map_ele(in_shape=(2, 3),
+                          axis=None,
+                          keepdims=True,
+                          type="any",
+                          dtype="bool")
+    verify_reduce_map_ele(in_shape=(32, 128, 24),
+                          axis=None,
+                          keepdims=True,
+                          type="any",
+                          dtype="bool")
+    verify_reduce_map_ele(in_shape=(1, 4, 7),
+                          axis=1,
+                          keepdims=True,
+                          type="any",
+                          dtype="bool")
+    verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
+                          axis=2,
+                          keepdims=False,
+                          type="any",
+                          dtype="bool")
 
 if __name__ == "__main__":
     test_reduce_map()