From f9d8d063245598973a495c6daa03317c1451c3af Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 16 Aug 2019 20:28:46 -0700 Subject: [PATCH] Fix ArgBinder assert order (#3794) --- src/pass/arg_binder.cc | 2 +- tests/python/unittest/test_build_lower.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc index ff4c77a..8268fc4 100644 --- a/src/pass/arg_binder.cc +++ b/src/pass/arg_binder.cc @@ -239,7 +239,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, AssertStmt::make(arith::ComputeReduce(conds, Expr()), stride_err_msg.str(), Evaluate::make(0)); check = IfThenElse::make(Not::make(is_null), check, Stmt()); - init_nest_.emplace_back(Block::make(check, Evaluate::make(0))); + asserts_.emplace_back(Block::make(check, Evaluate::make(0))); } } else if (buffer->buffer_type == kAutoBroadcast) { Type stype = buffer->DefaultIndexType(); diff --git a/tests/python/unittest/test_build_lower.py b/tests/python/unittest/test_build_lower.py index 8600fc2..082b85f 100644 --- a/tests/python/unittest/test_build_lower.py +++ b/tests/python/unittest/test_build_lower.py @@ -32,5 +32,13 @@ def test_lower_rfactor(): s[BF].compute_at(s[B], s[B].op.reduce_axis[0]) fapi = tvm.lower(s, [A, B]) +def test_dependent_output_shape(): + n, m, x = tvm.var('n'), tvm.var('m'), tvm.var('x') + A = tvm.placeholder((n, m)) + B = tvm.compute((m, n/x), lambda i, j: A[i,j] , name='B') + s = tvm.create_schedule(B.op) + mod = tvm.build(s, [A, B, x]) + if __name__ == "__main__": test_lower_rfactor() + test_dependent_output_shape() -- 2.7.4