import tvm
import topi
+
def verify_broadcast_to_ele(in_shape, out_shape, fbcast):
# Build the logic and compile the function
A = tvm.placeholder(shape=in_shape, name="A")
check_device(target)
check_device("sdaccel")
+
def test_broadcast_to():
verify_broadcast_to_ele((1,), (10,), topi.broadcast_to)
verify_broadcast_to_ele((), (10,), topi.broadcast_to)
verify_broadcast_to_ele((1, 1, 5, 4), (3, 4, 4, 4, 5, 4), topi.broadcast_to)
verify_broadcast_to_ele((1, 128, 1, 32), (64, 128, 64, 32), topi.broadcast_to)
+
def test_add():
verify_broadcast_binary_ele(
(), (), topi.add, np.add)
verify_broadcast_binary_ele(
(5, 2, 3), (2, 1), topi.add, np.add)
+
def test_subtract():
verify_broadcast_binary_ele(
(5, 2, 3), (), topi.subtract, np.subtract)
verify_broadcast_binary_ele(
(1, 32), (64, 32), topi.subtract, np.subtract)
+
def test_multiply():
verify_broadcast_binary_ele(
(5, 64, 128), (2, 5, 64, 1), topi.multiply, np.multiply)
+
def test_divide():
verify_broadcast_binary_ele(
None, (10,), topi.divide, np.divide, rhs_min=0.0001)
verify_broadcast_binary_ele(
(2, 3, 1, 32), (64, 32), topi.divide, np.divide, rhs_min=0.0001)
+
def test_maximum_minmum():
verify_broadcast_binary_ele(
(32,), (64, 32), topi.maximum, np.maximum)
verify_broadcast_binary_ele(
(1, 2, 2, 1, 32), (64, 32), topi.minimum, np.minimum)
+
def test_power():
verify_broadcast_binary_ele(
(1, 2, 2), (2,), topi.power, np.power, lhs_min=0.001, rhs_min=0.001, rhs_max=2)
+
def test_mod():
verify_broadcast_binary_ele(
(1, 2, 2), (2,), topi.mod, np.mod, lhs_min=0.001, rhs_min=1, dtype="int32")
+
def test_cmp():
# explicit specify the output type
def greater(x, y):
return topi.greater(x, y).astype("int8")
+
def less(x, y):
return topi.less(x, y).astype("int8")
+
def equal(x, y):
return topi.equal(x, y).astype("int8")
+
def not_equal(x, y):
return topi.not_equal(x, y).astype("int8")
+
def greater_equal(x, y):
return topi.greater_equal(x, y).astype("int8")
+
def less_equal(x, y):
return topi.less_equal(x, y).astype("int8")
verify_broadcast_binary_ele(
(7, 1, 5), (7, 3, 1), less_equal, np.less_equal,
lhs_min=-3, lhs_max=3, rhs_min=-3, rhs_max=3, dtype='int32')
+
def test_shift():
# explicit specify the output type
verify_broadcast_binary_ele(
dtype="int8", rhs_min=0, rhs_max=32)
+def test_logical_single_ele():
+ def test_apply(
+ func,
+ name,
+ f_numpy,
+ indata,
+ dtype="bool",
+ ):
+ # Build the logic and compile the function
+ A = tvm.placeholder(shape=indata.shape, name="A", dtype=dtype)
+ B = func(A)
+ if isinstance(A, tvm.expr.Expr):
+ assert (isinstance(B, tvm.expr.Expr))
+ return
+
+ def check_device(device):
+ ctx = tvm.context(device, 0)
+ if not ctx.exist:
+ print("Skip because %s is not enabled" % device)
+ return
+ print("Running on target: %s" % device)
+ with tvm.target.create(device):
+ s = topi.generic.schedule_broadcast(B)
+ foo = tvm.build(s, [A, B], device, name=name)
+
+ data_npy = indata.astype(A.dtype)
+ data_nd = tvm.nd.array(data_npy, ctx)
+
+ out_npy = f_numpy(indata)
+ out_nd = tvm.nd.array(np.empty(data_npy.shape).astype(B.dtype), ctx)
+ foo(data_nd, out_nd)
+ tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)
+
+ for device in get_all_backend():
+ check_device(device)
+
+ test_apply(topi.logical_not, "logical_not", np.logical_not, np.array([True, False, 0, 1]))
+ test_apply(topi.logical_not, "logical_not", np.logical_not, np.array(np.arange(5) < 3))
+
+
+def test_logical_binary_ele():
+ def test_apply(
+ func,
+ name,
+ f_numpy,
+ lhs,
+ rhs,
+ dtype="bool",
+ ):
+ # Build the logic and compile the function
+ A = (tvm.var("A", dtype=dtype))
+ B = (tvm.var("B", dtype=dtype))
+ C = func(A, B)
+ if isinstance(A, tvm.expr.Expr) and isinstance(B, tvm.expr.Expr):
+ assert (isinstance(C, tvm.expr.Expr))
+ return
+
+ def check_device(device):
+ ctx = tvm.context(device, 0)
+ if not ctx.exist:
+ print("Skip because %s is not enabled" % device)
+ return
+ print("Running on target: %s" % device)
+ with tvm.target.create(device):
+ s = topi.generic.schedule_broadcast(C)
+ foo = tvm.build(s, [A, B, C], device, name=name)
+
+ lhs_nd = tvm.nd.array(lhs, ctx)
+ rhs_nd = tvm.nd.array(rhs, ctx)
+
+ out_npy = f_numpy(lhs, rhs)
+ out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx)
+ foo(lhs_nd, rhs_nd, out_nd)
+ tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)
+
+ for device in get_all_backend():
+ check_device(device)
+
+ test_apply(topi.logical_and, "logical_and", np.logical_and, True, False)
+ test_apply(topi.logical_and, "logical_and", np.logical_and, [True, False], [False, False])
+ test_apply(topi.logical_or, "logical_or", np.logical_or, True, False)
+ test_apply(topi.logical_or, "logical_or", np.logical_or, [True, False], [False, False])
+
+
if __name__ == "__main__":
test_add()
test_shift()
test_maximum_minmum()
test_power()
test_broadcast_to()
+ test_logical_single_ele()
+ test_logical_binary_ele()