[DSL/TE] Scalar support for `te.extern` (#6079)
authorHaibin Lin <linhaibin.eric@gmail.com>
Tue, 21 Jul 2020 00:49:59 +0000 (17:49 -0700)
committerGitHub <noreply@github.com>
Tue, 21 Jul 2020 00:49:59 +0000 (17:49 -0700)
* fix make shape with scalar shapes

* add test

* add test

* remove scalar shape assertion

* fix the data type for overflow problems

* add extra tests

Co-authored-by: Ubuntu <ubuntu@ip-172-31-42-138.ec2.internal>
python/tvm/te/operation.py
python/tvm/tir/op.py
src/tir/transforms/lower_tvm_builtin.cc
tests/python/unittest/test_te_tensor.py

index b03c6f6..168265f 100644 (file)
@@ -208,7 +208,7 @@ def extern(shape,
            out_buffers=None,
            tag="",
            attrs=None):
-    """Compute several tensor via extern function.
+    """Compute several tensors via an extern function.
 
     Parameters
     ----------
index 1078376..b62d6a3 100644 (file)
@@ -28,7 +28,6 @@ from . import _ffi_api
 def _pack_buffer(buf):
     """Build intrinsics that packs the buffer.
     """
-    assert buf.shape
     shape = Call("handle", "tir.tvm_stack_make_shape", buf.shape)
     strides = Call("handle", "tir.tvm_stack_make_shape", buf.strides) if buf.strides else 0
     pack_args = [buf.data,
index f071704..c8df122 100644 (file)
@@ -54,7 +54,8 @@ class BuiltinLower : public StmtExprMutator {
     stack_value_ = Var("stack_value", DataType::Handle());
     stack_tcode_ = Var("stack_tcode", DataType::Handle());
     stmt = this->VisitStmt(stmt);
-    if (max_shape_stack_ != 0) {
+    // create a shape var if any shape is made (including scalar shapes)
+    if (max_shape_stack_ != -1) {
       stmt = LetStmt(stack_shape_, StackAlloca("shape", max_shape_stack_), stmt);
     }
     if (max_array_stack_ != 0) {
@@ -69,7 +70,7 @@ class BuiltinLower : public StmtExprMutator {
 
   Stmt VisitStmt(const Stmt& s) final {
     auto stmt = StmtExprMutator::VisitStmt(s);
-    CHECK_EQ(run_shape_stack_, 0);
+    CHECK_EQ(run_shape_stack_, -1);
     CHECK_EQ(run_array_stack_, 0);
 
     if (prep_seq_.size() != 0) {
@@ -156,10 +157,15 @@ class BuiltinLower : public StmtExprMutator {
   }
   // call shape
   PrimExpr MakeShape(const CallNode* op) {
-    size_t stack_begin = run_shape_stack_;
+    // if args.size() == 0, it represents a scalar shape ()
+    if (run_shape_stack_ == -1) {
+      run_shape_stack_ = 0;
+    }
+    int64_t stack_begin = run_shape_stack_;
     run_shape_stack_ += op->args.size();
     PrimExpr expr = StmtExprMutator::VisitExpr_(op);
     op = expr.as<CallNode>();
+    // no need to perform any store for a scalar shape
     for (size_t i = 0; i < op->args.size(); ++i) {
       prep_seq_.emplace_back(Store(stack_shape_, cast(DataType::Int(64), op->args[i]),
                                    ConstInt32(stack_begin + i), const_true(1)));
@@ -206,7 +212,7 @@ class BuiltinLower : public StmtExprMutator {
   }
   // call packed.
   PrimExpr MakeCallPacked(const CallNode* op) {
-    size_t restore_shape_stack = run_shape_stack_;
+    int64_t restore_shape_stack = run_shape_stack_;
     size_t restore_array_stack = run_array_stack_;
     size_t arg_stack_begin = run_arg_stack_;
     run_arg_stack_ += op->args.size();
@@ -245,7 +251,7 @@ class BuiltinLower : public StmtExprMutator {
   }
 
   PrimExpr MakeCallTracePacked(const CallNode* op) {
-    size_t restore_shape_stack = run_shape_stack_;
+    int64_t restore_shape_stack = run_shape_stack_;
     size_t restore_array_stack = run_array_stack_;
     size_t arg_stack_begin = run_arg_stack_;
     run_arg_stack_ += op->args.size();
@@ -307,11 +313,11 @@ class BuiltinLower : public StmtExprMutator {
   Var stack_tcode_;
   Var stack_value_;
   // The running statistics
-  uint64_t run_shape_stack_{0};
+  int64_t run_shape_stack_{-1};
   uint64_t run_array_stack_{0};
   uint64_t run_arg_stack_{0};
   // statistics of stacks
-  uint64_t max_shape_stack_{0};
+  int64_t max_shape_stack_{-1};
   uint64_t max_array_stack_{0};
   uint64_t max_arg_stack_{0};
 };
index 8d737c9..662eff0 100644 (file)
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
+import numpy as np
 from tvm import te
 from topi.nn.pooling import pool
 
@@ -303,6 +304,52 @@ def test_tensor_pool():
     s[P].tensorize(oh, intrin)
     tvm.lower(s, [A, P])
 
+def test_tensor_scalar_mixed():
+    # test te with tensor and scalar
+    a = np.array(np.random.uniform(size=(10,)), 'float32')
+    b = np.array(np.random.uniform(size=(1))[0], 'float32')
+    c = np.array(np.random.uniform(size=(10,)), 'float32')
+
+    @tvm.register_func("tvm.test_tensor_scalar_scale")
+    def my_scale(tensor, scalar, out):
+        out_np = tensor.asnumpy() * scalar.asnumpy()
+        tvm.nd.array(out_np).copyto(out)
+
+    A = te.placeholder(a.shape, name='A')
+    B = te.placeholder(b.shape, name='B')
+    C = te.extern(a.shape, [A, B],
+                  lambda ins, outs: tvm.tir.call_packed(
+                  "tvm.test_tensor_scalar_scale", ins[0], ins[1], outs[0]), name="C")
+    s = te.create_schedule(C.op)
+    f = tvm.build(s, [A, B, C], 'llvm')
+
+    ta = tvm.nd.array(a)
+    tb = tvm.nd.array(b)
+    tc = tvm.nd.array(c)
+    f(ta, tb, tc)
+    tvm.testing.assert_allclose(a * b, tc.asnumpy())
+
+
+def test_tensor_scalar():
+    # test te with scalar shape
+    a = np.array(np.random.uniform(size=(1))[0], 'float32')
+    b = np.array(0.0, 'float32')
+
+    @tvm.register_func("tvm.test_tensor_scalar_copy")
+    def mycopy(x, y):
+        x.copyto(y)
+
+    A = te.placeholder(a.shape, name='A')
+    B = te.extern(a.shape, [A],
+                  lambda ins, outs: tvm.tir.call_packed(
+                  "tvm.test_tensor_scalar_copy", ins[0], outs[0]), name="B")
+    s = te.create_schedule(B.op)
+    f = tvm.build(s, [A, B], 'llvm')
+
+    ta = tvm.nd.array(a)
+    tb = tvm.nd.array(b)
+    f(ta, tb)
+    tvm.testing.assert_allclose(ta.asnumpy(), tb.asnumpy())
 
 if __name__ == "__main__":
     test_rank_zero()
@@ -321,3 +368,5 @@ if __name__ == "__main__":
     test_tuple_inputs()
     test_tuple_with_different_deps()
     test_tensor_pool()
+    test_tensor_scalar()
+    test_tensor_scalar_mixed()