[TEST] Fix testcase to make them more compatible to zero-rank (#3612)
authorTianqi Chen <tqchen@users.noreply.github.com>
Wed, 24 Jul 2019 18:48:39 +0000 (11:48 -0700)
committerGitHub <noreply@github.com>
Wed, 24 Jul 2019 18:48:39 +0000 (11:48 -0700)
tests/python/unittest/test_codegen_llvm.py
topi/python/topi/generic_op_impl.py

index 3026c5e..ed6fedc 100644 (file)
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
+import topi
 from tvm.contrib import util, clang
 import numpy as np
 import ctypes
@@ -349,8 +350,8 @@ def test_rank_zero():
         A = tvm.placeholder((n, ), name='A')
         scale = tvm.placeholder((), name='scale')
         k = tvm.reduce_axis((0, n), name="k")
-        C = tvm.compute((), lambda : tvm.sum(A[k] * scale, axis=k), name="C")
-        D = tvm.compute((), lambda : C + 1)
+        C = tvm.compute((), lambda : tvm.sum(A[k] * scale(), axis=k), name="C")
+        D = tvm.compute((), lambda : C() + 1)
         s = tvm.create_schedule(D.op)
         # build and invoke the kernel.
         f = tvm.build(s, [A, scale, D], "llvm")
@@ -373,8 +374,8 @@ def test_rank_zero_bound_checkers():
             A = tvm.placeholder((n, ), name='A')
             scale = tvm.placeholder((), name='scale')
             k = tvm.reduce_axis((0, n), name="k")
-            C = tvm.compute((), lambda : tvm.sum(A[k] * scale, axis=k), name="C")
-            D = tvm.compute((), lambda : C + 1)
+            C = tvm.compute((), lambda : tvm.sum(A[k] * scale(), axis=k), name="C")
+            D = tvm.compute((), lambda : C() + 1)
             s = tvm.create_schedule(D.op)
             # build and invoke the kernel.
             f = tvm.build(s, [A, scale, D], "llvm")
index b4b719f..ce625bc 100644 (file)
@@ -79,6 +79,8 @@ def _make_bop(broadcast_bop, orig_bop):
               tvm.Expr (otherwise)
             The result of {op} operation.
         """
+        print(lhs, type(lhs))
+        print(rhs, type(rhs))
         if not isinstance(lhs, tvm.tensor.Tensor) and not isinstance(rhs, tvm.tensor.Tensor):
             return orig_bop(lhs, rhs)
         return broadcast_bop(lhs, rhs)