From 90eee08746d7b96d834331aa910a760451330f7b Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Wed, 24 Jul 2019 11:48:39 -0700 Subject: [PATCH] [TEST] Fix testcase to make them more compatible to zero-rank (#3612) --- tests/python/unittest/test_codegen_llvm.py | 9 +++++---- topi/python/topi/generic_op_impl.py | 2 ++ 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/python/unittest/test_codegen_llvm.py b/tests/python/unittest/test_codegen_llvm.py index 3026c5e..ed6fedc 100644 --- a/tests/python/unittest/test_codegen_llvm.py +++ b/tests/python/unittest/test_codegen_llvm.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm +import topi from tvm.contrib import util, clang import numpy as np import ctypes @@ -349,8 +350,8 @@ def test_rank_zero(): A = tvm.placeholder((n, ), name='A') scale = tvm.placeholder((), name='scale') k = tvm.reduce_axis((0, n), name="k") - C = tvm.compute((), lambda : tvm.sum(A[k] * scale, axis=k), name="C") - D = tvm.compute((), lambda : C + 1) + C = tvm.compute((), lambda : tvm.sum(A[k] * scale(), axis=k), name="C") + D = tvm.compute((), lambda : C() + 1) s = tvm.create_schedule(D.op) # build and invoke the kernel. f = tvm.build(s, [A, scale, D], "llvm") @@ -373,8 +374,8 @@ def test_rank_zero_bound_checkers(): A = tvm.placeholder((n, ), name='A') scale = tvm.placeholder((), name='scale') k = tvm.reduce_axis((0, n), name="k") - C = tvm.compute((), lambda : tvm.sum(A[k] * scale, axis=k), name="C") - D = tvm.compute((), lambda : C + 1) + C = tvm.compute((), lambda : tvm.sum(A[k] * scale(), axis=k), name="C") + D = tvm.compute((), lambda : C() + 1) s = tvm.create_schedule(D.op) # build and invoke the kernel. f = tvm.build(s, [A, scale, D], "llvm") diff --git a/topi/python/topi/generic_op_impl.py b/topi/python/topi/generic_op_impl.py index b4b719f..ce625bc 100644 --- a/topi/python/topi/generic_op_impl.py +++ b/topi/python/topi/generic_op_impl.py @@ -79,6 +79,8 @@ def _make_bop(broadcast_bop, orig_bop): tvm.Expr (otherwise) The result of {op} operation. """ + print(lhs, type(lhs)) + print(rhs, type(rhs)) if not isinstance(lhs, tvm.tensor.Tensor) and not isinstance(rhs, tvm.tensor.Tensor): return orig_bop(lhs, rhs) return broadcast_bop(lhs, rhs) -- 2.7.4