[ARITH] Fix lowering of FloorMod (#4236)
authorSergei Grechanik <grechanik.sergey@huawei.com>
Fri, 1 Nov 2019 15:51:43 +0000 (18:51 +0300)
committerTianqi Chen <tqchen@users.noreply.github.com>
Fri, 1 Nov 2019 15:51:43 +0000 (08:51 -0700)
src/pass/lower_intrin.cc
tests/python/unittest/test_codegen_llvm.py

index cc51c66..c2a2fe6 100644 (file)
@@ -77,7 +77,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
     if (op == nullptr) return ret;
     int shift;
     const DataType& dtype = op->type;
-    CHECK(dtype.is_int() || !dtype.is_uint());
+    CHECK(dtype.is_int() || dtype.is_uint());
 
     if (support_bitwise_op_ &&
         is_const_power_of_two_integer(op->b, &shift)) {
@@ -124,7 +124,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
     // Lower floordiv to native truncdiv.
     int shift;
     const DataType& dtype = op->type;
-    CHECK(dtype.is_int() || !dtype.is_uint());
+    CHECK(dtype.is_int() || dtype.is_uint());
 
     if (support_bitwise_op_ &&
         is_const_power_of_two_integer(op->b, &shift)) {
@@ -136,8 +136,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
 
     if (analyzer_->CanProveGreaterEqual(op->b, 0)) {
       // Common pass, positive divisor
-      if (analyzer_->CanProveGreaterEqual(op->a, 0) ||
-          analyzer_->CanProveGreaterEqual(e, 0)) {
+      if (analyzer_->CanProveGreaterEqual(op->a, 0)) {
         return truncmod(op->a, op->b);
       } else {
         DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divident";
index f4401d0..0e595cd 100644 (file)
@@ -406,40 +406,103 @@ def test_alignment():
             assert "align 32" in l
 
 def test_llvm_div():
-    """Check that the semantics of div and mod is the same as in C/C++"""
-    def check_div(start, end, divisor, dtype):
-        T = tvm.compute((end - start,),
-                        lambda i: tvm.div(tvm.expr.Cast(dtype, (start + i)), tvm.const(divisor, dtype)))
-        s = tvm.create_schedule([T.op])
-        f = tvm.build(s, [T], "llvm")
-        a = tvm.nd.empty((end - start,), dtype)
-        f(a)
-        ref = [int(float(i)/divisor) for i in range(start, end)]
-        tvm.testing.assert_allclose(a.asnumpy(), ref)
-
-    def check_mod(start, end, divisor, dtype):
-        tmod = tvm.truncmod
-        T = tvm.compute((end - start,),
-                        lambda i: tmod(tvm.expr.Cast(dtype, (start + i)), tvm.const(divisor, dtype)))
-        s = tvm.create_schedule([T.op])
-        f = tvm.build(s, [T], "llvm")
-        a = tvm.nd.empty((end - start,), dtype)
-        f(a)
-        ref = [int(math.fmod(i, divisor)) for i in range(start, end)]
-        tvm.testing.assert_allclose(a.asnumpy(), ref)
-
-    def check_llvm(start, end, divisor, dtype):
-        check_div(start, end, divisor, dtype)
-        check_mod(start, end, divisor, dtype)
-
-    for d in range(-5, 6):
-        if d != 0:
-            # Note that 11 (and not e.g. 10) is used to avoid issues with the simplifier
-            check_llvm(-11, 11, d, 'int32')
-            check_llvm(-11, 11, d, 'int8')
-            if d > 0:
-                check_llvm(123, 133, d, 'uint8')
-                check_llvm(0, 256, d, 'uint8')
+    """Check that the semantics of div and mod is correct"""
+    def check(start, end, dstart, dend, dtype, floor_div=False):
+        div = tvm.floordiv if floor_div else tvm.truncdiv
+        mod = tvm.floormod if floor_div else tvm.truncmod
+
+        # A are dividends, B are divisors. Note that we add 1 to make include end in the range.
+        A = tvm.placeholder((end - start + 1,), name="A", dtype=dtype)
+        B = tvm.placeholder((dend - dstart + 1,), name="B", dtype=dtype)
+        # We clip values with min and max so that simplifiers know the ranges of values
+        clipa = lambda x: tvm.min(tvm.const(end, dtype), tvm.max(tvm.const(start, dtype), x))
+        clipb = lambda x: tvm.min(tvm.const(dend, dtype), tvm.max(tvm.const(dstart, dtype), x))
+        # If the range is just a single point, use the constant itself
+        if start == end:
+            clipa = lambda x: tvm.const(start, dtype)
+        if dstart == dend:
+            clipb = lambda x: tvm.const(dstart, dtype)
+        # D are division results and M are modulo results
+        [D, M] = tvm.compute((end - start + 1, dend - dstart + 1),
+                             lambda i, j: (div(clipa(A[i]), clipb(B[j])),
+                                          mod(clipa(A[i]), clipb(B[j]))))
+
+        s = tvm.create_schedule([D.op, M.op])
+        f = tvm.build(s, [A, B, D, M], "llvm")
+
+        # Fill input arrays with values
+        A_arr = tvm.nd.empty((end - start + 1,), dtype)
+        B_arr = tvm.nd.empty((dend - dstart + 1,), dtype)
+        A_arr.copyfrom(np.arange(start, end + 1, dtype=dtype))
+        B_np = np.arange(dstart, dend + 1, dtype=dtype)
+        # If the range of the divisor contains 0, replace it with 1 to avoid division by zero
+        if dend >= 0 and dstart <= 0:
+            B_np[-dstart] = 1
+        B_arr.copyfrom(B_np)
+        D_arr = tvm.nd.empty((end - start + 1, dend - dstart + 1), dtype)
+        M_arr = tvm.nd.empty((end - start + 1, dend - dstart + 1), dtype)
+
+        # Run the function and convert the results to numpy
+        f(A_arr, B_arr, D_arr, M_arr)
+        D_arr = D_arr.asnumpy()
+        M_arr = M_arr.asnumpy()
+
+        # This helper just prints additional info on failure
+        def _show_info():
+            print("dtype: {}".format(dtype))
+            print("dividend range: [{}, {}]".format(start, end))
+            print("divisor range: [{}, {}]".format(dstart, dend))
+            lowered = tvm.lower(s, [A, B, D, M], simple_mode=True)
+            print("Lowered code:")
+            print(lowered)
+
+        # Check that the computed values are correct
+        for i in range(start, end + 1):
+            for j in range(dstart, dend + 1):
+                if j == 0:
+                    continue
+
+                if floor_div:
+                    dref = i // j
+                    mref = i % j
+                else:
+                    dref = int(float(i) / j)
+                    mref = int(math.fmod(i, j))
+
+                if D_arr[i - start, j - dstart] != dref:
+                    _show_info()
+                    raise AssertionError("Incorrect division result: {}({}, {}) is {} "
+                                         "but should be {}".format(div.__name__, i, j,
+                                                                   D_arr[i - start, j - dstart],
+                                                                   dref))
+                if M_arr[i - start, j - dstart] != mref:
+                    _show_info()
+                    raise AssertionError("Incorrect modulo result: {}({}, {}) is {} "
+                                         "but should be {}".format(mod.__name__, i, j,
+                                                                   M_arr[i - start, j - dstart],
+                                                                   mref))
+
+    # Try different ranges to cover different cases
+    for start, end in [(-12, -12), (-11, -1), (-11,  0), (0, 0),
+                       ( 12,  12), (  1, 11), (  0, 11), (-11, 11)]:
+        for dstart, dend in [(-11, -1), (-11,  0), (-4, -4), (-2, -2),
+                             (  1, 11), (  0, 11), ( 4,  4), ( 2,  2), (-11, 11)]:
+                if end < start or dend < dstart or (dend == 0 and dstart == 0):
+                    continue
+                check(start, end, dstart, dend, 'int32', floor_div=False)
+                check(start, end, dstart, dend, 'int32', floor_div=True)
+                check(start, end, dstart, dend, 'int8', floor_div=False)
+                check(start, end, dstart, dend, 'int8', floor_div=True)
+                if start >= 0 and dstart >= 0:
+                    check(start, end, dstart, dend, 'uint32', floor_div=False)
+                    check(start, end, dstart, dend, 'uint32', floor_div=True)
+
+    # Additional tests for uint8
+    for dstart, dend in [(0, 11), (1, 11), (2, 2), (4, 4)]:
+        check(123, 133, dstart, dend, 'uint8', floor_div=False)
+        check(123, 133, dstart, dend, 'uint8', floor_div=True)
+        check(0, 255, dstart, dend, 'uint8', floor_div=False)
+        check(0, 255, dstart, dend, 'uint8', floor_div=True)
 
 def test_llvm_fp_math():
     def check_llvm_reciprocal(n):