[Codegen] Support broadcast op with symbolic shape (#3389)
authorYizhi Liu <liuyizhi@apache.org>
Tue, 2 Jul 2019 07:20:26 +0000 (00:20 -0700)
committerGitHub <noreply@github.com>
Tue, 2 Jul 2019 07:20:26 +0000 (00:20 -0700)
* [Codegen] Support broadcast op with symbolic shape

* fix case where last dim = 1

* use enum; simplify stride calculation; improve doc

* fix lint

* improve py doc

include/tvm/buffer.h
python/tvm/api.py
src/api/api_lang.cc
src/codegen/build_module.cc
src/lang/buffer.cc
src/pass/arg_binder.cc
src/pass/inject_copy_intrin.cc
src/pass/storage_flatten.cc
tests/python/unittest/test_lang_buffer.py
topi/include/topi/detail/extern.h

index ed4ac5e..1233e9b 100644 (file)
@@ -36,10 +36,11 @@ namespace tvm {
 // Internal node container Buffer
 class BufferNode;
 
-/*! \brief memory access kind */
-enum class AccessMask : int {
-  kRead = 1,
-  kWrite = 2
+/*! \brief buffer type */
+enum BufferType : int {
+  kDefault = 1,
+  // Maps buffer[i][j][k] -> buffer[i][0][k] if dimension i's shape equals 1.
+  kAutoBroadcast = 2,
 };
 
 /*!
@@ -129,6 +130,8 @@ class BufferNode : public Node {
    *  elem_offset is guaranteed to be multiple of offset_factor.
    */
   int offset_factor;
+  /*! \brief buffer type */
+  BufferType buffer_type;
   /*! \brief constructor */
   BufferNode() {}
 
@@ -142,6 +145,7 @@ class BufferNode : public Node {
     v->Visit("scope", &scope);
     v->Visit("data_alignment", &data_alignment);
     v->Visit("offset_factor", &offset_factor);
+    v->Visit("buffer_type", &buffer_type);
   }
 
   /*! \return preferred index type for this buffer node */
@@ -159,7 +163,8 @@ class BufferNode : public Node {
                              std::string name,
                              std::string scope,
                              int data_alignment,
-                             int offset_factor);
+                             int offset_factor,
+                             BufferType buffer_type);
 
   static constexpr const char* _type_key = "Buffer";
   TVM_DECLARE_NODE_TYPE_INFO(BufferNode, Node);
index d88f061..e4777b6 100644 (file)
@@ -531,7 +531,8 @@ def decl_buffer(shape,
                 elem_offset=None,
                 scope="",
                 data_alignment=-1,
-                offset_factor=0):
+                offset_factor=0,
+                buffer_type=""):
     """Declare a new symbolic buffer.
 
     Normally buffer is created automatically during lower and build.
@@ -574,11 +575,39 @@ def decl_buffer(shape,
         If 0 is pssed, the alignment will be set to 1.
         if non-zero is passed, we will created a Var for elem_offset if elem_offset is not None.
 
+    buffer_type: str, optional, {"", "auto_broadcast"}
+        auto_broadcast buffer allows one to implement broadcast computation
+        without considering whether dimension size equals to one.
+        TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension i's shape equals 1.
+
     Returns
     -------
     buffer : Buffer
         The created buffer
 
+    Example
+    -------
+    Here's an example of how broadcast buffer can be used to define a symbolic broadcast operation,
+
+    .. code-block:: python
+
+        m0, m1, m2 = tvm.var("m0"), tvm.var("m1"), tvm.var("m2")
+        n0, n1, n2 = tvm.var("n0"), tvm.var("n1"), tvm.var("n2")
+        o0, o1, o2 = tvm.var("o0"), tvm.var("o1"), tvm.var("o2")
+        A = tvm.placeholder((m0, m1, m2), name='A')
+        B = tvm.placeholder((n0, n1, n2), name='B')
+        C = tvm.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C')
+        Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="broadcast")
+        Bb = tvm.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="broadcast")
+        s = tvm.create_schedule(C.op)
+        fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb})
+        ctx = tvm.cpu(0)
+        a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=(2, 1, 3)).astype(B.dtype), ctx)
+        c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), ctx)
+        fadd(a, b, c)
+        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
+
     Note
     ----
     Buffer data structure reflects the DLTensor structure in dlpack.
@@ -602,7 +631,7 @@ def decl_buffer(shape,
         data = var(name, "handle")
     return _api_internal._Buffer(
         data, dtype, shape, strides, elem_offset, name, scope,
-        data_alignment, offset_factor)
+        data_alignment, offset_factor, buffer_type)
 
 def layout(layout_str):
     """Create a layout node from a string.
index 42d60b8..00ac715 100644 (file)
@@ -207,7 +207,13 @@ TVM_REGISTER_API("Range")
   });
 
 TVM_REGISTER_API("_Buffer")
-.set_body_typed(BufferNode::make);
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    CHECK_EQ(args.size(), 10);
+    auto buffer_type = args[9].operator std::string();
+    BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault;
+    *ret = BufferNode::make(args[0], args[1], args[2], args[3], args[4],
+                            args[5], args[6], args[7], args[8], type);
+  });
 
 TVM_REGISTER_API("_BufferAccessPtr")
 .set_body_method(&Buffer::access_ptr);
index 6917200..c162233 100644 (file)
@@ -342,7 +342,7 @@ Buffer BufferWithOffsetAlignment(Array<Expr> shape,
   }
 
   return BufferNode::make(data, dtype, shape, Array<Expr>(), elem_offset, name, "",
-    data_alignment, offset_factor);
+    data_alignment, offset_factor, kDefault);
 }
 
 void GetBinds(const Array<Tensor>& args,
index 3e06151..573ecff 100644 (file)
@@ -49,7 +49,8 @@ Buffer decl_buffer(Array<Expr> shape,
       Expr(),
       name,
       "",
-      0, 0);
+      0, 0,
+      kDefault);
 }
 
 // Split the given expression w.r.t the add operator
@@ -365,7 +366,8 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
                           n->name + "_slice",
                           n->scope,
                           n->data_alignment,
-                          0);
+                          0,
+                          n->buffer_type);
 }
 
 Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr offset) const {
@@ -405,7 +407,8 @@ Buffer BufferNode::make(Var data,
                         std::string name,
                         std::string scope,
                         int data_alignment,
-                        int offset_factor) {
+                        int offset_factor,
+                        BufferType buffer_type) {
   auto n = make_node<BufferNode>();
   n->data = std::move(data);
   n->dtype = dtype;
@@ -428,6 +431,12 @@ Buffer BufferNode::make(Var data,
   n->elem_offset = std::move(elem_offset);
   n->data_alignment = data_alignment;
   n->offset_factor = offset_factor;
+  n->buffer_type = buffer_type;
+  if (n->buffer_type == kAutoBroadcast && n->shape.size() > 0 && n->strides.empty()) {
+    for (size_t i = 0; i < n->shape.size(); ++i) {
+      n->strides.push_back(tvm::var("stride"));
+    }
+  }
   return Buffer(n);
 }
 
index 2822393..d93d088 100644 (file)
@@ -242,6 +242,21 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
       check = IfThenElse::make(Not::make(is_null), check, Stmt());
       init_nest_.emplace_back(Block::make(check, Evaluate::make(0)));
     }
+  } else if (buffer->buffer_type == kAutoBroadcast) {
+    Type stype = buffer->DefaultIndexType();
+    Expr stride = make_const(stype, 1);
+    for (size_t i = buffer->shape.size(); i != 0; --i) {
+      size_t k = i - 1;
+      std::ostringstream field_name;
+      field_name << v_strides->name_hint << '[' << k << ']';
+      Expr value = cast(buffer->shape[k].type(),
+                        Load::make(tvm_shape_type, v_strides,
+                                   IntImm::make(Int(32), k), const_true(1)));
+      value = tvm::if_then_else(is_null, stride, value);
+      value = tvm::if_then_else(buffer->shape[k] == 1, 0, value);
+      Bind_(buffer->strides[k], value, field_name.str(), true);
+      stride = Simplify(stride * buffer->shape[k]);
+    }
   } else {
     std::ostringstream stride_null_err_msg;
     stride_null_err_msg << arg_name << ".strides: expected non-null strides.";
index a906ee3..8df5fe1 100644 (file)
@@ -160,7 +160,7 @@ class CopyIntrinInjector : public IRMutator {
         store_strides[loop_var_size],
         store->buffer_var->name_hint,
         GetStorageScope(store->buffer_var.get()),
-        0, 0);
+        0, 0, kDefault);
     Buffer src = BufferNode::make(
         Var(load->buffer_var.node_),
         load->type,
@@ -169,7 +169,7 @@ class CopyIntrinInjector : public IRMutator {
         src_elem_offset,
         load->buffer_var->name_hint,
         GetStorageScope(load->buffer_var.get()),
-        0, 0);
+        0, 0, kDefault);
     *out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value);
     CHECK(out->defined()) << "flower function did not return correct stmt";
     return true;
index 215f6d7..ff6b416 100644 (file)
@@ -220,7 +220,7 @@ class StorageFlattener : public IRMutator {
           Var(key.GetName(), Handle()),
           op->type, shape, strides, Expr(),
           key.GetName(), skey.to_string(),
-          align, 0);
+          align, 0, kDefault);
 
       buf_map_[key] = e;
       Stmt body = this->Mutate(op->body);
index e0bb027..bd45eac 100644 (file)
@@ -16,6 +16,7 @@
 # under the License.
 import tvm
 from tvm.schedule import Buffer
+import numpy as np
 
 def test_buffer():
     m = tvm.var('m')
@@ -108,6 +109,34 @@ def test_buffer_index_merge_mult_mod():
     index_direct = A.vload((0, ((k0 % (k1 / s)) / n) * n + ((k0 % (k1 / n)) % n + (k0 % k1))))
     assert_simplified_equal(index_simplified, index_direct)
 
+def test_buffer_broadcast():
+    m0, m1, m2 = tvm.var("m0"), tvm.var("m1"), tvm.var("m2")
+    n0, n1, n2 = tvm.var("n0"), tvm.var("n1"), tvm.var("n2")
+    o0, o1, o2 = tvm.var("o0"), tvm.var("o1"), tvm.var("o2")
+
+    A = tvm.placeholder((m0, m1, m2), name='A')
+    B = tvm.placeholder((n0, n1, n2), name='B')
+
+    C = tvm.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C')
+
+    Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
+    Bb = tvm.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
+    s = tvm.create_schedule(C.op)
+
+    def check():
+        if not tvm.module.enabled("llvm"):
+            return
+        fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb})
+        ctx = tvm.cpu(0)
+        a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=(2, 1, 1)).astype(B.dtype), ctx)
+        c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), ctx)
+        fadd(a, b, c)
+        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
+
+    check()
+
+
 if __name__ == "__main__":
     test_buffer()
     test_buffer_access_ptr()
@@ -115,3 +144,4 @@ if __name__ == "__main__":
     test_buffer_access_ptr_extent()
     test_buffer_vload()
     test_buffer_index_merge_mult_mod()
+    test_buffer_broadcast()
index ac00e52..667722e 100644 (file)
@@ -49,7 +49,7 @@ inline Buffer DeclExternBuffer(Array<Expr> shape,
   auto data = var(name, Handle());
   auto elem_offset = Expr();
   return BufferNode::make(data, dtype, shape, Array<Expr>(), elem_offset, name, "",
-                          -1, 0);
+                          -1, 0, kDefault);
 }
 
 /*!