REGISTER_MAKE_BINARY_OP(_OpMod, operator%);
REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
+REGISTER_MAKE_BINARY_OP(_OpPow, pow);
REGISTER_MAKE_BINARY_OP(_OpMin, min);
REGISTER_MAKE_BINARY_OP(_OpMax, max);
REGISTER_MAKE_BINARY_OP(_OpEQ, operator==);
REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
REGISTER_MAKE_BIT_OP(right_shift, operator>>);
+TVM_REGISTER_API("make._OpIfThenElse")
+.set_body_typed<Expr(Expr, Expr, Expr)>([] (Expr cond, Expr true_value, Expr false_value) {
+ return if_then_else(cond, true_value, false_value);
+});
} // namespace ir
} // namespace tvm
# under the License.
import tvm
+def check_throws(f):
+ try:
+ f()
+ except tvm.TVMError:
+ pass
+ else:
+ raise AssertionError("Should have raised an exception but didn't.")
+
+
def test_const_fold():
def check(f, *args):
x = f(*[tvm.const(x, "int32") for x in args])
assert isinstance((1 / x), tvm.expr.Div)
def test_const_fold3():
- def check_throws(f):
- try:
- f()
- except tvm.TVMError:
- pass
- else:
- raise AssertionError("Should have raised an exception but didn't.")
-
# Test that using ints with logic operations is forbidden
x = tvm.var("x")
for val in [0, 1]:
assert isinstance(y, tvm.expr.IntImm) and y.value == 6
+def test_binary_dtype_match():
+ def verify_general_dtype_support(f, is_conditional=False):
+ rules = [[('bool', 'int32'), 'int32'],
+ [('int32', 'float32'), 'float32'],
+ [('int32', 'int64'), 'int64'],
+ [('uint32', 'int32'), 'int32']]
+ for (lhs_dtype, rhs_dtype), out_dtype in rules:
+ lhs = tvm.var('lhs', dtype=lhs_dtype)
+ rhs = tvm.var('rhs', dtype=rhs_dtype)
+ out = f(lhs, rhs)
+ if not is_conditional:
+ assert out.dtype == out_dtype
+ else:
+ assert out.dtype == 'bool'
+ if hasattr(out, 'a'):
+ assert out.a.dtype == out_dtype
+ assert out.b.dtype == out_dtype
+ elif hasattr(out, 'args'):
+ # CallOp
+ assert out.args[0].dtype == out_dtype
+ assert out.args[1].dtype == out_dtype
+ else:
+ raise ValueError('Unknown binary op format!')
+
+ def verify_callop_float_only(f):
+ for lhs_dtype in ['int32', 'float32', 'float64']:
+ for rhs_dtype in ['int32', 'float32', 'float64']:
+ lhs = tvm.var('lhs', dtype=lhs_dtype)
+ rhs = tvm.var('rhs', dtype=rhs_dtype)
+ if 'float' not in lhs_dtype and 'float' not in rhs_dtype:
+ check_throws(lambda: f(lhs, rhs))
+ elif 'float' in lhs_dtype and 'float' in rhs_dtype and lhs_dtype != rhs_dtype:
+ check_throws(lambda: f(lhs, rhs))
+ elif 'float' in lhs_dtype:
+ out = f(lhs, rhs)
+ assert out.dtype == lhs_dtype
+ assert out.args[0].dtype == lhs_dtype
+ assert out.args[1].dtype == lhs_dtype
+ else:
+ out = f(lhs, rhs)
+ assert out.dtype == rhs_dtype
+ assert out.args[0].dtype == rhs_dtype
+ assert out.args[1].dtype == rhs_dtype
+
+ verify_general_dtype_support(lambda a, b: a + b)
+ verify_general_dtype_support(lambda a, b: a * b)
+ verify_general_dtype_support(lambda a, b: a >= b, is_conditional=True)
+ verify_general_dtype_support(lambda a, b: a <= b, is_conditional=True)
+ verify_callop_float_only(lambda a, b: tvm.power(a, b))
+
+
+def test_if_then_else():
+ cases = [[(tvm.var('cond', dtype='bool'), 'bool', 'int32'), 'int32'],
+ [(True, 'int32', 'float32'), 'float32'],
+ [(False, 'int32', 'int64'), 'int64'],
+ [(tvm.var('cond', dtype='bool'), 'uint32', 'int32'), 'int32'],
+ [(tvm.var('cond', dtype='int32'), 'uint32', 'int32'), 'int32']]
+ for (cond, lhs_dtype, rhs_dtype), out_dtype in cases:
+ lhs = tvm.var('lhs', dtype=lhs_dtype)
+ rhs = tvm.var('rhs', dtype=rhs_dtype)
+ if cond is True or cond is False:
+ out = tvm.if_then_else(cond, lhs, rhs)
+ out2 = tvm.if_then_else(not cond, rhs, lhs)
+ out3 = tvm.if_then_else(not cond, lhs, rhs)
+ assert tvm.ir_pass.Equal(out, out2) == 1
+ if cond:
+ assert tvm.ir_pass.Equal(out, lhs.astype(out_dtype)) == 1
+ assert tvm.ir_pass.Equal(out3, rhs.astype(out_dtype)) == 1
+ else:
+ assert tvm.ir_pass.Equal(out, rhs.astype(out_dtype)) == 1
+ assert tvm.ir_pass.Equal(out3, lhs.astype(out_dtype)) == 1
+ elif cond.dtype == 'bool':
+ out = tvm.if_then_else(cond, lhs, rhs)
+ assert out.dtype == out_dtype
+ assert out.args[1].dtype == out_dtype
+ assert out.args[2].dtype == out_dtype
+ elif cond.dtype != 'bool':
+ check_throws(lambda: tvm.if_then_else(cond, lhs, rhs))
+ else:
+ raise ValueError('Unknown combinations')
+
+
if __name__ == "__main__":
test_const_fold()
test_const_fold2()
test_const_fold3()
test_const_fold4()
+ test_binary_dtype_match()
+ test_if_then_else()