[Bugfix] Fixed bug where shifting by out-of-bounds value results in no compute code...
authorpankratz <35379668+dpankratz@users.noreply.github.com>
Mon, 23 Mar 2020 15:47:29 +0000 (09:47 -0600)
committerGitHub <noreply@github.com>
Mon, 23 Mar 2020 15:47:29 +0000 (08:47 -0700)
* Fixed bug where shifting by out-of-bounds RHS values results in LLVM to codegen nothing. Added regression testcase

* Updated testcase to be more precise.

* Fixed testcase

src/tir/ir/op.cc
tests/python/unittest/test_tir_nodes.py

index cf1c24c..4ad244f 100644 (file)
@@ -469,6 +469,9 @@ PrimExpr operator>>(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
+      if (pb) CHECK(pb->value >= 0 && pb->value < rtype.bits()) <<
+                "Shift amount must be non-negative and less than " << rtype.bits()
+                << " for type " << rtype;
       if (pa && pb) return IntImm(rtype, (pa->value >> pb->value));
       if (pb) {
         if (pb->value == 0) return a;
@@ -484,6 +487,9 @@ PrimExpr operator<<(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
+      if (pb) CHECK(pb->value >= 0 && pb->value < rtype.bits()) <<
+                "Shift amount must be non-negative and less than " << rtype.bits()
+                << " for type " << rtype;
       if (pa && pb) return IntImm(rtype, (pa->value << pb->value));
       if (pb) {
         if (pb->value == 0) return a;
index 7e2c8b5..2904953 100644 (file)
@@ -207,6 +207,23 @@ def test_float_bitwise():
         pass
 
 
+def test_shift_bounds():
+    x = te.var('x')
+    for test in [lambda lhs, rhs : lhs << rhs,
+                    lambda lhs, rhs : lhs >> rhs]:
+        #negative case
+        for testcase in [(x,-1), (x,32)]:
+            try:
+                test(*testcase)
+                assert False
+            except tvm.TVMError:
+                pass
+
+        #positive case
+        for testcase in [(x,0), (x,16), (x,31)]:
+            test(*testcase)
+
+
 def test_divide_by_zero():
     for test in [lambda lhs, rhs : tvm.tir.floormod(lhs,rhs),
                     lambda lhs, rhs : tvm.tir.floordiv(lhs,rhs),
@@ -293,6 +310,7 @@ if __name__ == "__main__":
     test_all()
     test_bitwise()
     test_float_bitwise()
+    test_shift_bounds()
     test_divide_by_zero()
     test_isnan()
     test_equality()