From 7ac78d96611ab3a069fae9eb5210e433eb240cc5 Mon Sep 17 00:00:00 2001 From: Haibin Lin Date: Mon, 20 Jul 2020 17:49:59 -0700 Subject: [PATCH] [DSL/TE] Scalar support for `te.extern` (#6079) * fix make shape with scalar shapes * add test * add test * remove scalar shape assertion * fix the data type for overflow problems * add extra tests Co-authored-by: Ubuntu --- python/tvm/te/operation.py | 2 +- python/tvm/tir/op.py | 1 - src/tir/transforms/lower_tvm_builtin.cc | 20 +++++++++----- tests/python/unittest/test_te_tensor.py | 49 +++++++++++++++++++++++++++++++++ 4 files changed, 63 insertions(+), 9 deletions(-) diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index b03c6f6..168265f 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -208,7 +208,7 @@ def extern(shape, out_buffers=None, tag="", attrs=None): - """Compute several tensor via extern function. + """Compute several tensors via an extern function. Parameters ---------- diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 1078376..b62d6a3 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -28,7 +28,6 @@ from . import _ffi_api def _pack_buffer(buf): """Build intrinsics that packs the buffer. """ - assert buf.shape shape = Call("handle", "tir.tvm_stack_make_shape", buf.shape) strides = Call("handle", "tir.tvm_stack_make_shape", buf.strides) if buf.strides else 0 pack_args = [buf.data, diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index f071704..c8df122 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -54,7 +54,8 @@ class BuiltinLower : public StmtExprMutator { stack_value_ = Var("stack_value", DataType::Handle()); stack_tcode_ = Var("stack_tcode", DataType::Handle()); stmt = this->VisitStmt(stmt); - if (max_shape_stack_ != 0) { + // create a shape var if any shape is made (including scalar shapes) + if (max_shape_stack_ != -1) { stmt = LetStmt(stack_shape_, StackAlloca("shape", max_shape_stack_), stmt); } if (max_array_stack_ != 0) { @@ -69,7 +70,7 @@ class BuiltinLower : public StmtExprMutator { Stmt VisitStmt(const Stmt& s) final { auto stmt = StmtExprMutator::VisitStmt(s); - CHECK_EQ(run_shape_stack_, 0); + CHECK_EQ(run_shape_stack_, -1); CHECK_EQ(run_array_stack_, 0); if (prep_seq_.size() != 0) { @@ -156,10 +157,15 @@ class BuiltinLower : public StmtExprMutator { } // call shape PrimExpr MakeShape(const CallNode* op) { - size_t stack_begin = run_shape_stack_; + // if args.size() == 0, it represents a scalar shape () + if (run_shape_stack_ == -1) { + run_shape_stack_ = 0; + } + int64_t stack_begin = run_shape_stack_; run_shape_stack_ += op->args.size(); PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); + // no need to perform any store for a scalar shape for (size_t i = 0; i < op->args.size(); ++i) { prep_seq_.emplace_back(Store(stack_shape_, cast(DataType::Int(64), op->args[i]), ConstInt32(stack_begin + i), const_true(1))); @@ -206,7 +212,7 @@ class BuiltinLower : public StmtExprMutator { } // call packed. PrimExpr MakeCallPacked(const CallNode* op) { - size_t restore_shape_stack = run_shape_stack_; + int64_t restore_shape_stack = run_shape_stack_; size_t restore_array_stack = run_array_stack_; size_t arg_stack_begin = run_arg_stack_; run_arg_stack_ += op->args.size(); @@ -245,7 +251,7 @@ class BuiltinLower : public StmtExprMutator { } PrimExpr MakeCallTracePacked(const CallNode* op) { - size_t restore_shape_stack = run_shape_stack_; + int64_t restore_shape_stack = run_shape_stack_; size_t restore_array_stack = run_array_stack_; size_t arg_stack_begin = run_arg_stack_; run_arg_stack_ += op->args.size(); @@ -307,11 +313,11 @@ class BuiltinLower : public StmtExprMutator { Var stack_tcode_; Var stack_value_; // The running statistics - uint64_t run_shape_stack_{0}; + int64_t run_shape_stack_{-1}; uint64_t run_array_stack_{0}; uint64_t run_arg_stack_{0}; // statistics of stacks - uint64_t max_shape_stack_{0}; + int64_t max_shape_stack_{-1}; uint64_t max_array_stack_{0}; uint64_t max_arg_stack_{0}; }; diff --git a/tests/python/unittest/test_te_tensor.py b/tests/python/unittest/test_te_tensor.py index 8d737c9..662eff0 100644 --- a/tests/python/unittest/test_te_tensor.py +++ b/tests/python/unittest/test_te_tensor.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm +import numpy as np from tvm import te from topi.nn.pooling import pool @@ -303,6 +304,52 @@ def test_tensor_pool(): s[P].tensorize(oh, intrin) tvm.lower(s, [A, P]) +def test_tensor_scalar_mixed(): + # test te with tensor and scalar + a = np.array(np.random.uniform(size=(10,)), 'float32') + b = np.array(np.random.uniform(size=(1))[0], 'float32') + c = np.array(np.random.uniform(size=(10,)), 'float32') + + @tvm.register_func("tvm.test_tensor_scalar_scale") + def my_scale(tensor, scalar, out): + out_np = tensor.asnumpy() * scalar.asnumpy() + tvm.nd.array(out_np).copyto(out) + + A = te.placeholder(a.shape, name='A') + B = te.placeholder(b.shape, name='B') + C = te.extern(a.shape, [A, B], + lambda ins, outs: tvm.tir.call_packed( + "tvm.test_tensor_scalar_scale", ins[0], ins[1], outs[0]), name="C") + s = te.create_schedule(C.op) + f = tvm.build(s, [A, B, C], 'llvm') + + ta = tvm.nd.array(a) + tb = tvm.nd.array(b) + tc = tvm.nd.array(c) + f(ta, tb, tc) + tvm.testing.assert_allclose(a * b, tc.asnumpy()) + + +def test_tensor_scalar(): + # test te with scalar shape + a = np.array(np.random.uniform(size=(1))[0], 'float32') + b = np.array(0.0, 'float32') + + @tvm.register_func("tvm.test_tensor_scalar_copy") + def mycopy(x, y): + x.copyto(y) + + A = te.placeholder(a.shape, name='A') + B = te.extern(a.shape, [A], + lambda ins, outs: tvm.tir.call_packed( + "tvm.test_tensor_scalar_copy", ins[0], outs[0]), name="B") + s = te.create_schedule(B.op) + f = tvm.build(s, [A, B], 'llvm') + + ta = tvm.nd.array(a) + tb = tvm.nd.array(b) + f(ta, tb) + tvm.testing.assert_allclose(ta.asnumpy(), tb.asnumpy()) if __name__ == "__main__": test_rank_zero() @@ -321,3 +368,5 @@ if __name__ == "__main__": test_tuple_inputs() test_tuple_with_different_deps() test_tensor_pool() + test_tensor_scalar() + test_tensor_scalar_mixed() -- 2.7.4