Extend TensorComputeOp to allow scalar inputs (#2606). (#3300)
authorJessica Davies <46725573+jdavies-huawei@users.noreply.github.com>
Sat, 22 Jun 2019 04:22:54 +0000 (00:22 -0400)
committerziheng <ziheng@apache.org>
Sat, 22 Jun 2019 04:22:54 +0000 (21:22 -0700)
include/tvm/operation.h
include/tvm/tensor_intrin.h
python/tvm/api.py
python/tvm/tensor_intrin.py
src/lang/tensor.cc
src/op/tensor_compute_op.cc
src/schedule/schedule_dataflow_rewrite.cc
tests/python/unittest/test_lang_schedule.py

index 15a8c12..38dc39b 100644 (file)
@@ -286,6 +286,8 @@ class TensorComputeOpNode : public BaseComputeOpNode {
   Array<Tensor> inputs;
   /*! \brief region of input tensors */
   Array<Region> input_regions;
+  /*! \brief scalar expression inputs */
+  Array<Expr> scalar_inputs;
   /*! \brief constructor */
   TensorComputeOpNode() {}
   // override functions
@@ -314,6 +316,7 @@ class TensorComputeOpNode : public BaseComputeOpNode {
     v->Visit("intrin", &intrin);
     v->Visit("inputs", &inputs);
     v->Visit("input_regions", &input_regions);
+    v->Visit("scalar_inputs", &scalar_inputs);
   }
   static Operation make(std::string name,
                         std::string tag,
@@ -322,7 +325,8 @@ class TensorComputeOpNode : public BaseComputeOpNode {
                         int schedulable_ndim,
                         TensorIntrin intrin,
                         Array<Tensor> tensors,
-                        Array<Region> regions);
+                        Array<Region> regions,
+                        Array<Expr> scalar_inputs);
 
   static constexpr const char* _type_key = "TensorComputeOp";
   TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, BaseComputeOpNode);
index e61ce66..b5ca6eb 100644 (file)
@@ -67,6 +67,11 @@ class TensorIntrinNode : public Node {
    *  When it is a constant, it means we can only take data in that shape.
    */
   Array<Buffer> buffers;
+  /*! \brief List of scalar variables, used in body. These placeholders
+   *  will be bound to expressions passed in when the TensorIntrin is called
+   * from a TensorComputeOp.
+   */
+  Array<Var> scalar_params;
   /*! \brief The normal statement to execute the intrinsic */
   Stmt body;
   /*!
@@ -87,6 +92,7 @@ class TensorIntrinNode : public Node {
     v->Visit("op", &op);
     v->Visit("inputs", &inputs);
     v->Visit("buffers", &buffers);
+    v->Visit("scalar_params", &scalar_params);
     v->Visit("body", &body);
     v->Visit("reduce_init", &reduce_init);
     v->Visit("reduce_update", &reduce_update);
@@ -96,6 +102,7 @@ class TensorIntrinNode : public Node {
                                    Operation op,
                                    Array<Tensor> inputs,
                                    Array<Buffer> buffers,
+                                   Array<Var> scalar_params,
                                    Stmt body,
                                    Stmt reduce_init,
                                    Stmt reduce_update);
@@ -134,22 +141,29 @@ class TensorIntrinCallNode : public Node {
   Array<Tensor> tensors;
   /*! \brief regions of input tensors */
   Array<Region> regions;
+
+
   /*!
    * \brief IterVar on each reduction axis, if the
    * intrin will use the reduce axis
    */
   Array<IterVar> reduce_axis;
 
+  /*! \brief scalar expression inputs */
+  Array<Expr> scalar_inputs;
+
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("intrin", &intrin);
     v->Visit("tensors", &tensors);
     v->Visit("regions", &regions);
     v->Visit("reduce_axis", &reduce_axis);
+    v->Visit("scalar_inputs", &scalar_inputs);
   }
   static TensorIntrinCall make(TensorIntrin intrin,
                                Array<Tensor> tensors,
                                Array<Region> regions,
-                               Array<IterVar> reduce_axis);
+                               Array<IterVar> reduce_axis,
+                               Array<Expr> scalar_inputs);
 
   static constexpr const char* _type_key = "TensorIntrinCall";
   TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinCallNode, Node);
index 66fa4fa..d88f061 100644 (file)
@@ -319,7 +319,8 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
                                                  out_ndim,
                                                  body.intrin,
                                                  body.tensors,
-                                                 body.regions)
+                                                 body.regions,
+                                                 body.scalar_inputs)
     else:
         if not isinstance(body, (list, tuple)):
             body = [body]
index f97e6b7..2ef7a4b 100644 (file)
@@ -50,20 +50,23 @@ class TensorIntrin(NodeBase):
     decl_tensor_intrin: Construct a TensorIntrin
     """
     def __call__(self, *args, **kwargs):
-        tensors = [x.tensor for x in args]
-        regions = [_get_region(x) for x in args]
+        tensors = [x.tensor for x in args if isinstance(x, _tensor.TensorSlice)]
+        scalar_inputs = [x for x in args if not isinstance(x, _tensor.TensorSlice)]
+        regions = [_get_region(x) for x in args if isinstance(x, _tensor.TensorSlice)]
         reduce_axis = []
         if "reduce_axis" in kwargs:
             reduce_axis = kwargs["reduce_axis"]
             if not isinstance(reduce_axis, (list, tuple)):
                 reduce_axis = [reduce_axis]
             reduce_axis = _api.convert(reduce_axis)
-        return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis)
+        if scalar_inputs:
+            scalar_inputs = _api.convert(scalar_inputs)
+        return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis, scalar_inputs)
 
 def decl_tensor_intrin(op,
                        fcompute,
                        name="tensor_intrin",
-                       binds=None):
+                       binds=None, scalar_params=None):
     """Declare a tensor intrinsic function.
 
     Parameters
@@ -96,6 +99,9 @@ def decl_tensor_intrin(op,
         requirement of the function. By default, a new compact buffer is created
         for each tensor in the argument.
 
+    scalar_params: a list of variables used by op, whose values will be passed
+                   as scalar_inputs when the tensor intrinsic is called.
+
     Returns
     -------
     intrin: TensorIntrin
@@ -122,11 +128,15 @@ def decl_tensor_intrin(op,
                                 offset_factor=cfg.offset_factor))
         binds_list.append(buf)
 
-    body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):])
+    if scalar_params:
+        body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):], scalar_params)
+    else:
+        body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):])
+        scalar_params = []
     if isinstance(body, (_expr.Expr, _stmt.Stmt)):
         body = [body]
     body = [_make.Evaluate(x) if isinstance(x, _expr.Expr) else x for x in body]
     if len(body) < 3:
         body += [None] * (3 - len(body))
     return _api_internal._TensorIntrin(
-        name, op, inputs, binds_list, *body)
+        name, op, inputs, binds_list, scalar_params, *body)
index bab7cf6..d885d71 100644 (file)
@@ -83,6 +83,7 @@ TensorIntrin TensorIntrinNode::make(std::string name,
                                     Operation op,
                                     Array<Tensor> inputs,
                                     Array<Buffer> buffers,
+                                    Array<Var> scalar_params,
                                     Stmt body,
                                     Stmt reduce_init,
                                     Stmt reduce_update) {
@@ -91,6 +92,7 @@ TensorIntrin TensorIntrinNode::make(std::string name,
   n->op = std::move(op);
   n->inputs = std::move(inputs);
   n->buffers = std::move(buffers);
+  n->scalar_params = std::move(scalar_params);
   n->body = std::move(body);
   n->reduce_init = std::move(reduce_init);
   n->reduce_update = std::move(reduce_update);
@@ -110,12 +112,14 @@ TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
 TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
                                             Array<Tensor> tensors,
                                             Array<Region> regions,
-                                            Array<IterVar> reduce_axis) {
+                                            Array<IterVar> reduce_axis,
+                                            Array<Expr> scalar_inputs) {
   auto n = make_node<TensorIntrinCallNode>();
   n->intrin = std::move(intrin);
   n->tensors = std::move(tensors);
   n->regions = std::move(regions);
   n->reduce_axis = std::move(reduce_axis);
+  n->scalar_inputs = std::move(scalar_inputs);
   return TensorIntrinCall(n);
 }
 
index ed768c2..09e8af7 100644 (file)
@@ -58,7 +58,8 @@ Operation TensorComputeOpNode::make(std::string name,
                                     int schedulable_ndim,
                                     TensorIntrin intrin,
                                     Array<Tensor> tensors,
-                                    Array<Region> regions) {
+                                    Array<Region> regions,
+                                    Array<Expr> scalar_inputs) {
   auto n = make_node<TensorComputeOpNode>();
   n->name = std::move(name);
   n->tag = std::move(tag);
@@ -68,6 +69,7 @@ Operation TensorComputeOpNode::make(std::string name,
   n->intrin = std::move(intrin);
   n->inputs = std::move(tensors);
   n->input_regions = std::move(regions);
+  n->scalar_inputs = std::move(scalar_inputs);
   return Operation(n);
 }
 
@@ -184,6 +186,19 @@ Stmt TensorComputeOpNode::BuildProvide(
   std::unordered_map<const Variable*, Expr> vmap;
   ir::ArgBinder binder(&vmap);
 
+  // Map the expressions passed in the call to the TensorIntrin, to the placeholder
+  // variables
+  Array<Expr> user_expr = this->scalar_inputs;
+  Array<Var> scalar_params = this->intrin->scalar_params;
+  Array<Expr> sp_expr;
+  for (auto sp : scalar_params) {
+    Expr esp = sp;
+    sp_expr.push_back(esp);
+  }
+  CHECK_EQ(sp_expr.size(), user_expr.size());
+  // TODO(jdavies-huawei): what name should be used here?
+  binder.BindArray(sp_expr, user_expr, this->name);
+
   size_t tloc = stage->leaf_iter_vars.size();
   ComputeLoopNest n = ComputeLoopNest::make(this, stage, dom_map, debug_keep_trivial_loop);
 
index cbb5ae3..bb42754 100644 (file)
@@ -410,10 +410,15 @@ Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
     new_regions.push_back(region);
   }
 
+  Array<Expr> new_scalar_inputs;
+  for (Expr old_input : tensor_op->scalar_inputs) {
+    new_scalar_inputs.push_back(VarReplacer(vsub2newvar).Mutate(old_input));
+  }
+
   Operation cache_op = TensorComputeOpNode::make(
       tensor_op->name + "." + scope, tensor_op->tag, new_axis,
       tensor_op->reduce_axis, tensor_op->schedulable_ndim,
-      tensor_op->intrin, tensor_op->inputs, new_regions);
+      tensor_op->intrin, tensor_op->inputs, new_regions, new_scalar_inputs);
 
   // axis will be used in generating compute op
   Array<IterVar> compute_axis = tensor_op->axis;
index fb33e2b..6279215 100644 (file)
@@ -209,11 +209,43 @@ def test_tensor_intrin():
     assert(s[z].iter_var_attrs[xi].tensor_intrin == intrin)
     assert(s[z].iter_var_attrs[xi].iter_type == tvm.schedule.IterVar.Tensorized)
 
+def test_tensor_intrin_scalar_params():
+    n = tvm.var("n")
+    x = tvm.placeholder((n,), name='x')
+    v = tvm.var("v")
+    w = tvm.var("w")
+    z = tvm.compute((n,), lambda i: x[i]*v + w, name='z')
+
+    def intrin_func(ins, outs, sp):
+        assert(isinstance(ins[0], tvm.schedule.Buffer))
+        assert(ins[0].shape[0] == n)
+        assert(sp[0] == v)
+        assert(sp[1] == w)
+        return tvm.call_packed("hw_func", ins[0].data, outs[0].data, sp[0], sp[1])
+
+    with tvm.build_config(offset_factor=1):
+      intrin = tvm.decl_tensor_intrin(z.op, intrin_func, scalar_params=[v, w])
+    assert intrin.op == z.op
+    assert intrin.reduce_init is None
+    assert tuple(intrin.inputs) == tuple(z.op.input_tensors)
+    assert(intrin.buffers[0].shape[0] == n)
+    assert tuple(intrin.scalar_params) == tuple((v, w))
+
+    A = tvm.placeholder((10,10), name='A')
+    # Pass scalar inputs to the TensorIntrin, interleaved with tensor inputs
+    C = tvm.compute((10,10), lambda i, j: intrin(i*i, A[i, j], i+j), name="C")
+    s = tvm.create_schedule(C.op)
+    stmt = tvm.lower(s, [A, C], simple_mode=True)
+    assert isinstance(stmt.body.body.body, tvm.stmt.Evaluate)
+    assert len(stmt.body.body.body.value.args) == 5
+    assert str(stmt.body.body.body.value.args[3]) == "(i*i)"
+    assert str(stmt.body.body.body.value.args[4]) == "(i + j)"
 
 if __name__ == "__main__":
     test_singleton()
     test_pragma()
     test_tensor_intrin()
+    test_tensor_intrin_scalar_params()
     test_rfactor()
     test_schedule_create()
     test_reorder()