assert "align 32" in l
def test_llvm_div():
- """Check that the semantics of div and mod is the same as in C/C++"""
- def check_div(start, end, divisor, dtype):
- T = tvm.compute((end - start,),
- lambda i: tvm.div(tvm.expr.Cast(dtype, (start + i)), tvm.const(divisor, dtype)))
- s = tvm.create_schedule([T.op])
- f = tvm.build(s, [T], "llvm")
- a = tvm.nd.empty((end - start,), dtype)
- f(a)
- ref = [int(float(i)/divisor) for i in range(start, end)]
- tvm.testing.assert_allclose(a.asnumpy(), ref)
-
- def check_mod(start, end, divisor, dtype):
- tmod = tvm.truncmod
- T = tvm.compute((end - start,),
- lambda i: tmod(tvm.expr.Cast(dtype, (start + i)), tvm.const(divisor, dtype)))
- s = tvm.create_schedule([T.op])
- f = tvm.build(s, [T], "llvm")
- a = tvm.nd.empty((end - start,), dtype)
- f(a)
- ref = [int(math.fmod(i, divisor)) for i in range(start, end)]
- tvm.testing.assert_allclose(a.asnumpy(), ref)
-
- def check_llvm(start, end, divisor, dtype):
- check_div(start, end, divisor, dtype)
- check_mod(start, end, divisor, dtype)
-
- for d in range(-5, 6):
- if d != 0:
- # Note that 11 (and not e.g. 10) is used to avoid issues with the simplifier
- check_llvm(-11, 11, d, 'int32')
- check_llvm(-11, 11, d, 'int8')
- if d > 0:
- check_llvm(123, 133, d, 'uint8')
- check_llvm(0, 256, d, 'uint8')
+ """Check that the semantics of div and mod is correct"""
+ def check(start, end, dstart, dend, dtype, floor_div=False):
+ div = tvm.floordiv if floor_div else tvm.truncdiv
+ mod = tvm.floormod if floor_div else tvm.truncmod
+
+ # A are dividends, B are divisors. Note that we add 1 to make include end in the range.
+ A = tvm.placeholder((end - start + 1,), name="A", dtype=dtype)
+ B = tvm.placeholder((dend - dstart + 1,), name="B", dtype=dtype)
+ # We clip values with min and max so that simplifiers know the ranges of values
+ clipa = lambda x: tvm.min(tvm.const(end, dtype), tvm.max(tvm.const(start, dtype), x))
+ clipb = lambda x: tvm.min(tvm.const(dend, dtype), tvm.max(tvm.const(dstart, dtype), x))
+ # If the range is just a single point, use the constant itself
+ if start == end:
+ clipa = lambda x: tvm.const(start, dtype)
+ if dstart == dend:
+ clipb = lambda x: tvm.const(dstart, dtype)
+ # D are division results and M are modulo results
+ [D, M] = tvm.compute((end - start + 1, dend - dstart + 1),
+ lambda i, j: (div(clipa(A[i]), clipb(B[j])),
+ mod(clipa(A[i]), clipb(B[j]))))
+
+ s = tvm.create_schedule([D.op, M.op])
+ f = tvm.build(s, [A, B, D, M], "llvm")
+
+ # Fill input arrays with values
+ A_arr = tvm.nd.empty((end - start + 1,), dtype)
+ B_arr = tvm.nd.empty((dend - dstart + 1,), dtype)
+ A_arr.copyfrom(np.arange(start, end + 1, dtype=dtype))
+ B_np = np.arange(dstart, dend + 1, dtype=dtype)
+ # If the range of the divisor contains 0, replace it with 1 to avoid division by zero
+ if dend >= 0 and dstart <= 0:
+ B_np[-dstart] = 1
+ B_arr.copyfrom(B_np)
+ D_arr = tvm.nd.empty((end - start + 1, dend - dstart + 1), dtype)
+ M_arr = tvm.nd.empty((end - start + 1, dend - dstart + 1), dtype)
+
+ # Run the function and convert the results to numpy
+ f(A_arr, B_arr, D_arr, M_arr)
+ D_arr = D_arr.asnumpy()
+ M_arr = M_arr.asnumpy()
+
+ # This helper just prints additional info on failure
+ def _show_info():
+ print("dtype: {}".format(dtype))
+ print("dividend range: [{}, {}]".format(start, end))
+ print("divisor range: [{}, {}]".format(dstart, dend))
+ lowered = tvm.lower(s, [A, B, D, M], simple_mode=True)
+ print("Lowered code:")
+ print(lowered)
+
+ # Check that the computed values are correct
+ for i in range(start, end + 1):
+ for j in range(dstart, dend + 1):
+ if j == 0:
+ continue
+
+ if floor_div:
+ dref = i // j
+ mref = i % j
+ else:
+ dref = int(float(i) / j)
+ mref = int(math.fmod(i, j))
+
+ if D_arr[i - start, j - dstart] != dref:
+ _show_info()
+ raise AssertionError("Incorrect division result: {}({}, {}) is {} "
+ "but should be {}".format(div.__name__, i, j,
+ D_arr[i - start, j - dstart],
+ dref))
+ if M_arr[i - start, j - dstart] != mref:
+ _show_info()
+ raise AssertionError("Incorrect modulo result: {}({}, {}) is {} "
+ "but should be {}".format(mod.__name__, i, j,
+ M_arr[i - start, j - dstart],
+ mref))
+
+ # Try different ranges to cover different cases
+ for start, end in [(-12, -12), (-11, -1), (-11, 0), (0, 0),
+ ( 12, 12), ( 1, 11), ( 0, 11), (-11, 11)]:
+ for dstart, dend in [(-11, -1), (-11, 0), (-4, -4), (-2, -2),
+ ( 1, 11), ( 0, 11), ( 4, 4), ( 2, 2), (-11, 11)]:
+ if end < start or dend < dstart or (dend == 0 and dstart == 0):
+ continue
+ check(start, end, dstart, dend, 'int32', floor_div=False)
+ check(start, end, dstart, dend, 'int32', floor_div=True)
+ check(start, end, dstart, dend, 'int8', floor_div=False)
+ check(start, end, dstart, dend, 'int8', floor_div=True)
+ if start >= 0 and dstart >= 0:
+ check(start, end, dstart, dend, 'uint32', floor_div=False)
+ check(start, end, dstart, dend, 'uint32', floor_div=True)
+
+ # Additional tests for uint8
+ for dstart, dend in [(0, 11), (1, 11), (2, 2), (4, 4)]:
+ check(123, 133, dstart, dend, 'uint8', floor_div=False)
+ check(123, 133, dstart, dend, 'uint8', floor_div=True)
+ check(0, 255, dstart, dend, 'uint8', floor_div=False)
+ check(0, 255, dstart, dend, 'uint8', floor_div=True)
def test_llvm_fp_math():
def check_llvm_reciprocal(n):