[TOPI] operator support: logical_and, logical_or, logical_not (#3929)
authorNeo Chien <cchung100m@cs.ccu.edu.tw>
Mon, 16 Sep 2019 17:42:47 +0000 (01:42 +0800)
committerYao Wang <kevinthesunwy@gmail.com>
Mon, 16 Sep 2019 17:42:47 +0000 (10:42 -0700)
* [TOPI] operator support: logical_and, logical_or, logical_not

* [TOPI] operator support: logical_and, logical_or, logical_not

* [TOPI] fix test cases for operator support: logical_and, logical_or, logical_not

* [TOPI] fix test cases for operator support: logical_not

docs/api/python/topi.rst
topi/python/topi/broadcast.py
topi/python/topi/math.py
topi/src/topi.cc
topi/tests/python/test_topi_broadcast.py

index 123c1d0..3b05803 100644 (file)
@@ -175,6 +175,9 @@ topi
 .. autofunction:: topi.topk
 .. autofunction:: topi.sequence_mask
 .. autofunction:: topi.one_hot
+.. autofunction:: topi.logical_and
+.. autofunction:: topi.logical_or
+.. autofunction:: topi.logical_not
 
 topi.nn
 ~~~~~~~
index 0b6dbb5..4075198 100644 (file)
@@ -18,6 +18,7 @@
 from __future__ import absolute_import as _abs
 from .import cpp as _cpp
 
+
 def broadcast_to(data, shape):
     """Broadcast the src to the target shape
 
@@ -341,3 +342,57 @@ def less_equal(lhs, rhs):
         Otherwise returns Tensor.
     """
     return _cpp.less_equal(lhs, rhs)
+
+
+def logical_and(lhs, rhs):
+    """Compute element-wise logical and of data.
+
+    Parameters
+    ----------
+    lhs : tvm.Tensor or Expr
+          The left operand
+    rhs : tvm.Tensor or Expr
+          The right operand
+
+    Returns
+    -------
+    ret : tvm.Tensor or Expr
+          Returns Expr if both operands are Expr.
+          Otherwise returns Tensor.
+    """
+    return _cpp.logical_and(lhs, rhs)
+
+
+def logical_or(lhs, rhs):
+    """Compute element-wise logical or of data.
+
+    Parameters
+    ----------
+    lhs : tvm.Tensor or Expr
+          The left operand
+    rhs : tvm.Tensor or Expr
+          The right operand
+
+    Returns
+    -------
+    ret : tvm.Tensor or Expr
+          Returns Expr if both operands are Expr.
+          Otherwise returns Tensor.
+    """
+    return _cpp.logical_or(lhs, rhs)
+
+
+def logical_not(data):
+    """Compute element-wise logical not of data.
+
+    Parameters
+    ----------
+    data : tvm.Tensor or Expr
+
+    Returns
+    -------
+    ret : tvm.Tensor or Expr
+          Returns Expr if the operand are Expr.
+          Otherwise returns Tensor.
+    """
+    return _cpp.logical_not(data)
index 6f44b85..5aebb82 100644 (file)
@@ -21,6 +21,7 @@ import tvm
 from . import tag
 from . import cpp
 
+
 @tvm.tag_scope(tag=tag.ELEMWISE)
 def identity(x):
     """Take identity of input x.
@@ -107,6 +108,7 @@ def tanh(x):
     """
     return tvm.compute(x.shape, lambda *i: tvm.tanh(x(*i)))
 
+
 @tvm.tag_scope(tag=tag.ELEMWISE)
 def cos(x):
     """Take cos of input x.
@@ -123,6 +125,7 @@ def cos(x):
     """
     return tvm.compute(x.shape, lambda *i: tvm.cos(x(*i)))
 
+
 @tvm.tag_scope(tag=tag.ELEMWISE)
 def sin(x):
     """Take sin of input x.
@@ -139,6 +142,7 @@ def sin(x):
     """
     return tvm.compute(x.shape, lambda *i: tvm.sin(x(*i)))
 
+
 @tvm.tag_scope(tag=tag.ELEMWISE)
 def floor(x):
     """Take floor of input x.
@@ -172,6 +176,7 @@ def ceil(x):
     """
     return tvm.compute(x.shape, lambda *i: tvm.ceil(x(*i)))
 
+
 def sign(x):
     """Returns -1, 0, 1 based on sign of x.
 
@@ -187,6 +192,7 @@ def sign(x):
     """
     return cpp.sign(x)
 
+
 @tvm.tag_scope(tag=tag.ELEMWISE)
 def trunc(x):
     """Take truncated value of the input of x, element-wise.
@@ -254,6 +260,7 @@ def log(x):
     """
     return tvm.compute(x.shape, lambda *i: tvm.log(x(*i)))
 
+
 @tvm.tag_scope(tag=tag.ELEMWISE)
 def sqrt(x):
     """Take square root of input x.
@@ -391,6 +398,7 @@ def cast(x, dtype):
             x.shape, lambda *i: x(*i).astype(dtype), tag=tag.ELEMWISE)
     return tvm.make._cast(dtype, x)
 
+
 def reinterpret(x, dtype):
     """Reinterpret input to specified data type.
 
index 5649ad2..2068f15 100644 (file)
@@ -118,11 +118,6 @@ TVM_REGISTER_GLOBAL("topi.TEST_create_target")
       }                                                                 \
     });                                                                 \
 
-TVM_REGISTER_GLOBAL("topi.broadcast_to")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
-  *rv = broadcast_to(args[0], args[1]);
-  });
-
 TOPI_REGISTER_BCAST_OP("topi.add", topi::add);
 TOPI_REGISTER_BCAST_OP("topi.subtract", topi::subtract);
 TOPI_REGISTER_BCAST_OP("topi.multiply", topi::multiply);
@@ -142,6 +137,16 @@ TOPI_REGISTER_BCAST_OP("topi.not_equal", topi::not_equal);
 TOPI_REGISTER_BCAST_OP("topi.greater_equal", topi::greater_equal);
 TOPI_REGISTER_BCAST_OP("topi.less_equal", topi::less_equal);
 
+TVM_REGISTER_GLOBAL("topi.broadcast_to")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  *rv = broadcast_to(args[0], args[1]);
+  });
+
+TVM_REGISTER_GLOBAL("topi.logical_not")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  *rv = logical_not(args[0]);
+  });
+
 /* Ops from elemwise.h */
 TVM_REGISTER_GLOBAL("topi.exp")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
index 3701630..ea84a43 100644 (file)
@@ -20,6 +20,7 @@ import numpy as np
 import tvm
 import topi
 
+
 def verify_broadcast_to_ele(in_shape, out_shape, fbcast):
     # Build the logic and compile the function
     A = tvm.placeholder(shape=in_shape, name="A")
@@ -99,18 +100,21 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
         check_device(target)
     check_device("sdaccel")
 
+
 def test_broadcast_to():
     verify_broadcast_to_ele((1,), (10,), topi.broadcast_to)
     verify_broadcast_to_ele((), (10,), topi.broadcast_to)
     verify_broadcast_to_ele((1, 1, 5, 4), (3, 4, 4, 4, 5, 4), topi.broadcast_to)
     verify_broadcast_to_ele((1, 128, 1, 32), (64, 128, 64, 32), topi.broadcast_to)
 
+
 def test_add():
     verify_broadcast_binary_ele(
         (), (), topi.add, np.add)
     verify_broadcast_binary_ele(
         (5, 2, 3), (2, 1), topi.add, np.add)
 
+
 def test_subtract():
     verify_broadcast_binary_ele(
         (5, 2, 3), (), topi.subtract, np.subtract)
@@ -121,10 +125,12 @@ def test_subtract():
     verify_broadcast_binary_ele(
         (1, 32), (64, 32), topi.subtract, np.subtract)
 
+
 def test_multiply():
     verify_broadcast_binary_ele(
         (5, 64, 128), (2, 5, 64, 1), topi.multiply, np.multiply)
 
+
 def test_divide():
     verify_broadcast_binary_ele(
         None, (10,), topi.divide, np.divide, rhs_min=0.0001)
@@ -133,32 +139,41 @@ def test_divide():
     verify_broadcast_binary_ele(
         (2, 3, 1, 32), (64, 32), topi.divide, np.divide, rhs_min=0.0001)
 
+
 def test_maximum_minmum():
     verify_broadcast_binary_ele(
         (32,), (64, 32), topi.maximum, np.maximum)
     verify_broadcast_binary_ele(
         (1, 2, 2, 1, 32), (64, 32), topi.minimum, np.minimum)
 
+
 def test_power():
     verify_broadcast_binary_ele(
         (1, 2, 2), (2,), topi.power, np.power, lhs_min=0.001, rhs_min=0.001, rhs_max=2)
 
+
 def test_mod():
     verify_broadcast_binary_ele(
         (1, 2, 2), (2,), topi.mod, np.mod, lhs_min=0.001, rhs_min=1, dtype="int32")
 
+
 def test_cmp():
     # explicit specify the output type
     def greater(x, y):
         return topi.greater(x, y).astype("int8")
+
     def less(x, y):
         return topi.less(x, y).astype("int8")
+
     def equal(x, y):
         return topi.equal(x, y).astype("int8")
+
     def not_equal(x, y):
         return topi.not_equal(x, y).astype("int8")
+
     def greater_equal(x, y):
         return topi.greater_equal(x, y).astype("int8")
+
     def less_equal(x, y):
         return topi.less_equal(x, y).astype("int8")
     verify_broadcast_binary_ele(
@@ -178,6 +193,7 @@ def test_cmp():
         (7, 1, 5), (7, 3, 1), less_equal, np.less_equal,
         lhs_min=-3, lhs_max=3, rhs_min=-3, rhs_max=3, dtype='int32')
 
+
 def test_shift():
     # explicit specify the output type
     verify_broadcast_binary_ele(
@@ -193,6 +209,90 @@ def test_shift():
         dtype="int8", rhs_min=0, rhs_max=32)
 
 
+def test_logical_single_ele():
+    def test_apply(
+            func,
+            name,
+            f_numpy,
+            indata,
+            dtype="bool",
+    ):
+        # Build the logic and compile the function
+        A = tvm.placeholder(shape=indata.shape, name="A", dtype=dtype)
+        B = func(A)
+        if isinstance(A, tvm.expr.Expr):
+            assert (isinstance(B, tvm.expr.Expr))
+            return
+
+        def check_device(device):
+            ctx = tvm.context(device, 0)
+            if not ctx.exist:
+                print("Skip because %s is not enabled" % device)
+                return
+            print("Running on target: %s" % device)
+            with tvm.target.create(device):
+                s = topi.generic.schedule_broadcast(B)
+            foo = tvm.build(s, [A, B], device, name=name)
+
+            data_npy = indata.astype(A.dtype)
+            data_nd = tvm.nd.array(data_npy, ctx)
+
+            out_npy = f_numpy(indata)
+            out_nd = tvm.nd.array(np.empty(data_npy.shape).astype(B.dtype), ctx)
+            foo(data_nd, out_nd)
+            tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)
+
+        for device in get_all_backend():
+            check_device(device)
+
+    test_apply(topi.logical_not, "logical_not", np.logical_not, np.array([True, False, 0, 1]))
+    test_apply(topi.logical_not, "logical_not", np.logical_not, np.array(np.arange(5) < 3))
+
+
+def test_logical_binary_ele():
+    def test_apply(
+            func,
+            name,
+            f_numpy,
+            lhs,
+            rhs,
+            dtype="bool",
+    ):
+        # Build the logic and compile the function
+        A = (tvm.var("A", dtype=dtype))
+        B = (tvm.var("B", dtype=dtype))
+        C = func(A, B)
+        if isinstance(A, tvm.expr.Expr) and isinstance(B, tvm.expr.Expr):
+            assert (isinstance(C, tvm.expr.Expr))
+            return
+
+        def check_device(device):
+            ctx = tvm.context(device, 0)
+            if not ctx.exist:
+                print("Skip because %s is not enabled" % device)
+                return
+            print("Running on target: %s" % device)
+            with tvm.target.create(device):
+                s = topi.generic.schedule_broadcast(C)
+            foo = tvm.build(s, [A, B, C], device, name=name)
+
+            lhs_nd = tvm.nd.array(lhs, ctx)
+            rhs_nd = tvm.nd.array(rhs, ctx)
+
+            out_npy = f_numpy(lhs, rhs)
+            out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx)
+            foo(lhs_nd, rhs_nd, out_nd)
+            tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)
+
+        for device in get_all_backend():
+            check_device(device)
+
+    test_apply(topi.logical_and, "logical_and", np.logical_and, True, False)
+    test_apply(topi.logical_and, "logical_and", np.logical_and, [True, False], [False, False])
+    test_apply(topi.logical_or, "logical_or", np.logical_or, True, False)
+    test_apply(topi.logical_or, "logical_or", np.logical_or, [True, False], [False, False])
+
+
 if __name__ == "__main__":
     test_add()
     test_shift()
@@ -204,3 +304,5 @@ if __name__ == "__main__":
     test_maximum_minmum()
     test_power()
     test_broadcast_to()
+    test_logical_single_ele()
+    test_logical_binary_ele()