BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
+ if (pb) CHECK(pb->value >= 0 && pb->value < rtype.bits()) <<
+ "Shift amount must be non-negative and less than " << rtype.bits()
+ << " for type " << rtype;
if (pa && pb) return IntImm(rtype, (pa->value >> pb->value));
if (pb) {
if (pb->value == 0) return a;
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
+ if (pb) CHECK(pb->value >= 0 && pb->value < rtype.bits()) <<
+ "Shift amount must be non-negative and less than " << rtype.bits()
+ << " for type " << rtype;
if (pa && pb) return IntImm(rtype, (pa->value << pb->value));
if (pb) {
if (pb->value == 0) return a;
pass
+def test_shift_bounds():
+ x = te.var('x')
+ for test in [lambda lhs, rhs : lhs << rhs,
+ lambda lhs, rhs : lhs >> rhs]:
+ #negative case
+ for testcase in [(x,-1), (x,32)]:
+ try:
+ test(*testcase)
+ assert False
+ except tvm.TVMError:
+ pass
+
+ #positive case
+ for testcase in [(x,0), (x,16), (x,31)]:
+ test(*testcase)
+
+
def test_divide_by_zero():
for test in [lambda lhs, rhs : tvm.tir.floormod(lhs,rhs),
lambda lhs, rhs : tvm.tir.floordiv(lhs,rhs),
test_all()
test_bitwise()
test_float_bitwise()
+ test_shift_bounds()
test_divide_by_zero()
test_isnan()
test_equality()