[tvm][codegen] Make buffer auto broadcast independent to the order of input args...
authorZhi <5145158+zhiics@users.noreply.github.com>
Mon, 16 Sep 2019 18:07:40 +0000 (11:07 -0700)
committerHaichen Shen <shenhaichen@gmail.com>
Mon, 16 Sep 2019 18:07:40 +0000 (11:07 -0700)
* [tvm][codegen] Make buffer auto broadcast independent to the order of the input arg

* fix indent

python/tvm/api.py
src/pass/make_api.cc
tests/python/unittest/test_lang_buffer.py

index 6900742..490899e 100644 (file)
@@ -582,7 +582,7 @@ def decl_buffer(shape,
     buffer_type: str, optional, {"", "auto_broadcast"}
         auto_broadcast buffer allows one to implement broadcast computation
         without considering whether dimension size equals to one.
-        TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension i's shape equals 1.
+        TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension j's shape equals 1.
 
     Returns
     -------
@@ -601,8 +601,8 @@ def decl_buffer(shape,
         A = tvm.placeholder((m0, m1, m2), name='A')
         B = tvm.placeholder((n0, n1, n2), name='B')
         C = tvm.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C')
-        Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="broadcast")
-        Bb = tvm.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="broadcast")
+        Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
+        Bb = tvm.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
         s = tvm.create_schedule(C.op)
         fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb})
         ctx = tvm.cpu(0)
index 0109ad1..27c1513 100644 (file)
@@ -102,6 +102,10 @@ LoweredFunc MakeAPI(Stmt body,
     seq_init.emplace_back(
         MakeAssertEQ(v_num_packed_args, num_packed_args, os.str()));
   }
+
+  // Save the input variables and buffers that will be bound later.
+  std::vector<std::pair<Var, Var> > var_defs;
+  std::vector<std::pair<Buffer, Var> > buf_defs;
   for (int i = 0; i < static_cast<int>(api_args.size()); ++i) {
     Var v_arg = f_arg_decl(i);
     if (i < num_packed_args) {
@@ -139,17 +143,31 @@ LoweredFunc MakeAPI(Stmt body,
     }
     // add checks for functions.
     if (api_args[i].as<Variable>()) {
-      binder.Bind(Var(api_args[i].node_), v_arg, v_arg->name_hint, true);
+      var_defs.emplace_back(std::make_pair(Var(api_args[i].node_), v_arg));
     } else {
       // Buffer checks
       CHECK(api_args[i].as<BufferNode>())
           << "api_args can only be Buffer or Var";
-      Buffer buf(api_args[i].node_);
-      binder.BindDLTensor(
-          buf, device_type, device_id, v_arg, v_arg->name_hint);
+      buf_defs.emplace_back(std::make_pair(Buffer(api_args[i].node_), v_arg));
     }
   }
 
+  // Arg definitions are defined before buffer binding to avoid the use before
+  // def errors.
+  //
+  // For example, for auto broadcasting, checks are required to guarantee that
+  // either 0 or the original stride will be correctly used. Checks here have
+  // to use the args that may have no let bining yet. Therefore, hoisting let
+  // binding for args before buffer declaration is needed.
+  for (const auto& arg : var_defs) {
+    binder.Bind(arg.first, arg.second, arg.second->name_hint, true);
+  }
+
+  for (const auto& buf_arg : buf_defs) {
+    binder.BindDLTensor(buf_arg.first, device_type, device_id,
+                        buf_arg.second, buf_arg.second->name_hint);
+  }
+
   NodePtr<LoweredFuncNode> n = make_node<LoweredFuncNode>();
   n->name = name;
   n->args = args;
index bd45eac..8c7e13a 100644 (file)
@@ -29,6 +29,7 @@ def test_buffer():
     assert Ab.dtype == tvm.float32
     assert tuple(Ab.shape) == (m, n)
 
+
 def test_buffer_access_ptr():
     m = tvm.var('m')
     n = tvm.var('n')
@@ -40,6 +41,7 @@ def test_buffer_access_ptr():
     aptr = Ab.access_ptr("w")
     assert aptr.args[4].value == Buffer.WRITE
 
+
 def test_buffer_access_ptr_offset():
     m = tvm.var('m')
     n = tvm.var('n')
@@ -58,6 +60,7 @@ def test_buffer_access_ptr_offset():
     assert tvm.ir_pass.Equal(offset, tvm.call_extern('int32', "test_call", 200 + v))
     assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
 
+
 def test_buffer_access_ptr_extent():
     m = tvm.var('m')
     n = tvm.var('n')
@@ -70,6 +73,7 @@ def test_buffer_access_ptr_extent():
     aptr = Ab.access_ptr("rw", offset=100)
     assert tvm.ir_pass.Equal(aptr.args[3], Ab.strides[0] * m - 100)
 
+
 def test_buffer_vload():
     m = tvm.var('m')
     n = tvm.var('n')
@@ -78,6 +82,7 @@ def test_buffer_vload():
     offset = tvm.ir_pass.Simplify(load.index)
     assert tvm.ir_pass.Equal(offset, n * 2 + 103)
 
+
 def test_buffer_index_merge_mult_mod():
     m = tvm.var('m')
     n = tvm.var('n')
@@ -109,6 +114,7 @@ def test_buffer_index_merge_mult_mod():
     index_direct = A.vload((0, ((k0 % (k1 / s)) / n) * n + ((k0 % (k1 / n)) % n + (k0 % k1))))
     assert_simplified_equal(index_simplified, index_direct)
 
+
 def test_buffer_broadcast():
     m0, m1, m2 = tvm.var("m0"), tvm.var("m1"), tvm.var("m2")
     n0, n1, n2 = tvm.var("n0"), tvm.var("n1"), tvm.var("n2")
@@ -137,6 +143,48 @@ def test_buffer_broadcast():
     check()
 
 
+def test_bbuffer_roadcast_expr():
+    n0, m0, x = tvm.var('n0'), tvm.var('m0'), tvm.var('x')
+    n1, m1 = tvm.var('n1'), tvm.var('m1')
+    o0, o1 = tvm.var('o0'), tvm.var('o1')
+
+    A = tvm.placeholder((m0, n0), name='A')
+    B = tvm.placeholder((m1, n1), name='B')
+    C = tvm.compute((o0, o1/x), lambda i, j: A[i, j] + B[i, j], name='C')
+
+    Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
+    Bb = tvm.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
+    Cc = tvm.decl_buffer(C.shape, C.dtype, name="Cc", buffer_type="auto_broadcast")
+    s = tvm.create_schedule(C.op)
+
+    def check_stride():
+        if not tvm.module.enabled("llvm"):
+            return
+        fadd = tvm.build(s, [A, B, C, o1, x], target='llvm', name='bcast_add',
+                         binds={A:Ab, B:Bb, C:Cc})
+        ctx = tvm.cpu(0)
+        a = tvm.nd.array(np.random.uniform(size=(2, 4)).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=(2, 4)).astype(B.dtype), ctx)
+        c = tvm.nd.array(np.zeros((2, 4), dtype=C.dtype), ctx)
+        fadd(a, b, c, 4, 1)
+        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
+
+    def check_no_stride():
+        if not tvm.module.enabled("llvm"):
+            return
+        fadd = tvm.build(s, [A, B, C, o1, x], target='llvm', name='bcast_add',
+                         binds={A: Ab, B: Bb, C: Cc})
+        ctx = tvm.cpu(0)
+        a = tvm.nd.array(np.random.uniform(size=(1, 4)).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=(2, 4)).astype(B.dtype), ctx)
+        c = tvm.nd.array(np.zeros((2, 4), dtype=C.dtype), ctx)
+        fadd(a, b, c, 4, 1)
+        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
+
+    check_stride()
+    check_no_stride()
+
+
 if __name__ == "__main__":
     test_buffer()
     test_buffer_access_ptr()
@@ -145,3 +193,4 @@ if __name__ == "__main__":
     test_buffer_vload()
     test_buffer_index_merge_mult_mod()
     test_buffer_broadcast()
+    test_buffer_broadcast_expr()