Dynamic Tile Op (#5983)
authorMatthew Brookhart <mbrookhart@octoml.ai>
Tue, 7 Jul 2020 18:50:35 +0000 (11:50 -0700)
committerGitHub <noreply@github.com>
Tue, 7 Jul 2020 18:50:35 +0000 (11:50 -0700)
* first working dynamic tile passes first test

* add dyn tile to dynamic_to_static

* fix cpplintt

* respond to review comments. Thanks @siju-samuel

* make dynamic tile compatible with numpy API

python/tvm/relay/op/dyn/_transform.py
python/tvm/relay/op/transform.py
src/relay/op/dyn/tensor/transform.cc
src/relay/transforms/dynamic_to_static.cc
tests/python/relay/dyn/test_dynamic_op_level3.py
tests/python/relay/test_pass_dynamic_to_static.py
topi/include/topi/transform.h

index 81c6e5e..8279b12 100644 (file)
 """Backend compiler related feature registration"""
 # pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks, too-many-local-variables, too-many-arguments
 from __future__ import absolute_import
+
+from tvm.runtime import convert
 from tvm.te.hybrid import script
 from .. import op as _reg
 
 _reg.register_injective_schedule("dyn.reshape")
+_reg.register_broadcast_schedule("dyn.tile")
 
 @script
 def _reshape_shape_func_input_data(data, newshape, ndim):
@@ -81,3 +84,40 @@ def _reshape_shape_func_input_data(data, newshape, ndim):
 @_reg.register_shape_func("dyn.reshape", True)
 def dynamic_reshape_shape_func(attrs, inputs, out_ndims):
     return [_reshape_shape_func_input_data(*inputs, out_ndims[0])]
+
+
+@script
+def _tile_shape_func(data, reps, ndim, tndim, rndim):
+    out = output_tensor((tndim,), "int64")
+
+    if ndim == rndim:
+        for i in const_range(tndim):
+            out[i] = int64(data.shape[i] * reps[i])
+    elif ndim > rndim:
+        ngap = ndim - rndim
+        for i in const_range(ndim):
+            if i < ngap:
+                out[i] = int64(data.shape[i])
+            else:
+                out[i] = int64(data.shape[i] * reps[i - ngap])
+    else:
+        rgap = rndim - ndim
+        for i in const_range(rndim):
+            if i < rgap:
+                out[i] = int64(reps[i])
+            else:
+                out[i] = int64(reps[i] * data.shape[i - rgap])
+    return out
+
+
+@_reg.register_shape_func("dyn.tile", True)
+def tile_shape_func(attrs, inputs, _):
+    """
+    Shape function for dyn.tile op.
+    """
+    reps = inputs[1]
+    ndim = len(inputs[0].shape)
+    rndim = inputs[1].shape[0].value
+    tndim = ndim if ndim > rndim else rndim
+    return [_tile_shape_func(inputs[0], reps, convert(ndim),
+                             convert(tndim), convert(rndim))]
index 188cd5c..173db64 100644 (file)
@@ -496,7 +496,7 @@ def tile(data, reps):
     data : relay.Expr
         The input data to the operator.
 
-    reps : tuple of int
+    reps : tuple of int or relay.Expr
         The number of times repeating the tensor data.
 
     Returns
@@ -524,7 +524,8 @@ def tile(data, reps):
     data is promoted to be d-dimensional by prepending new axes.
     If data.ndim >=  d, reps is promoted to a.ndim by pre-pending 1's to it.
     """
-
+    if isinstance(reps, Expr):
+        return _dyn_make.tile(data, reps)
     return _make.tile(data, reps)
 
 
index 18eaa67..0b8a156 100644 (file)
@@ -29,6 +29,8 @@
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/runtime/registry.h>
 
+#include <vector>
+
 namespace tvm {
 namespace relay {
 namespace dyn {
@@ -128,6 +130,71 @@ RELAY_REGISTER_OP("dyn.reshape")
     .set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
     .set_attr<TOpPattern>("TOpPattern", kInjective);
 
+// tile operator
+// TVM_REGISTER_NODE_TYPE(TileAttrs);
+
+bool TileRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+             const TypeReporter& reporter) {
+  // `types` contains: [data, reps, result]
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* reps = types[1].as<TensorTypeNode>();
+  if (data == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "tile: expect input type to be TensorType but get " << types[0];
+    return false;
+  }
+  if (reps == nullptr) {
+    CHECK(types[1].as<IncompleteTypeNode>())
+        << "tile: expect input type to be TensorType but get " << types[1];
+    return false;
+  }
+  const IntImmNode* reps_shape = reps->shape[0].as<IntImmNode>();
+  CHECK(reps_shape) << "Parameter reps must have static shape";
+  const size_t ndim = data->shape.size();
+  const size_t rndim = reps_shape->value;
+  size_t tndim = (ndim > rndim) ? ndim : rndim;
+  std::vector<IndexExpr> oshape;
+  oshape.reserve(tndim);
+  for (size_t i = 0; i < tndim; ++i) {
+    oshape.emplace_back(Any());
+  }
+  reporter->Assign(types[2], TensorType(oshape, data->dtype));
+  return true;
+}
+
+Array<te::Tensor> TileCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                              const Type& out_type) {
+  CHECK_EQ(inputs.size(), 2);
+  const auto* out_ttype = out_type.as<TensorTypeNode>();
+  size_t rndim = inputs[1]->shape[0].as<IntImmNode>()->value;
+  return {topi::dyn_tile(inputs[0], out_ttype->shape, rndim)};
+}
+
+Expr MakeTile(Expr data, Expr reps) {
+  auto attrs = make_object<TileAttrs>();
+  static const Op& op = Op::Get("dyn.tile");
+  return Call(op, {data, reps}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.dyn._make.tile").set_body_typed(MakeTile);
+
+RELAY_REGISTER_OP("dyn.tile")
+    .describe(R"code(Repeat the whole array multiple times.
+
+- **data**: The input data to the operator.
+- **reps**: The number of times to repeat the operator.
+
+)code" TVM_ADD_FILELINE)
+    .set_num_inputs(2)
+    .set_attrs_type<TileAttrs>()
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("reps", "Tensor", "The number of times to repeat the input on each axis.")
+    .set_support_level(3)
+    .add_type_rel("DynamicTile", TileRel)
+    .set_attr<FTVMCompute>("FTVMCompute", TileCompute)
+    .set_attr<TOpPattern>("TOpPattern", kInjective);
+
 }  // namespace dyn
 }  // namespace relay
 }  // namespace tvm
index 7b3f195..d09230a 100644 (file)
@@ -32,7 +32,8 @@ namespace relay {
 
 class DynamicToStaticMutator : public MixedModeMutator {
  public:
-  DynamicToStaticMutator() : dyn_reshape_op_(Op::Get("dyn.reshape")) {}
+  DynamicToStaticMutator()
+      : dyn_reshape_op_(Op::Get("dyn.reshape")), dyn_tile_op_(Op::Get("dyn.tile")) {}
 
  private:
   Expr Rewrite_(const CallNode* pre, const Expr& post) override {
@@ -46,6 +47,14 @@ class DynamicToStaticMutator : public MixedModeMutator {
         static const Op& reshape = Op::Get("reshape");
         return Call(reshape, {call_node->args[0]}, Attrs(attrs), {});
       }
+    } else if (call_node->op == dyn_tile_op_) {
+      if (const ConstantNode* reps = call_node->args[1].as<ConstantNode>()) {
+        auto attrs = make_object<TileAttrs>();
+        CHECK_EQ(reps->data->ndim, 1);
+        attrs->reps = ToVector(reps->data);
+        static const Op& op = Op::Get("tile");
+        return Call(op, {call_node->args[0]}, Attrs(attrs), {});
+      }
     }
     return post;
   }
@@ -58,6 +67,7 @@ class DynamicToStaticMutator : public MixedModeMutator {
   }
 
   const Op& dyn_reshape_op_;
+  const Op& dyn_tile_op_;
 };
 
 Expr DynamicToStatic(Function f, IRModule m) {
index 29168b6..2f473c9 100644 (file)
@@ -70,6 +70,21 @@ def test_dyn_shape_reshape():
     verify_reshape((2, 3, 4), (8, 3), (8, 3))
     verify_reshape((4, 7), (2, 7, 2), (2, 7, 2))
 
+def test_dyn_tile():
+    def verify_tile(dshape, reps):
+        x = relay.var("x", relay.TensorType(dshape, "float32"))
+        r = relay.var("reps", relay.TensorType((len(reps), ), "float32"))
+        z = relay.tile(x, r)
+
+        func = relay.Function([x, r], z)
+        x_data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32")
+        ref_res = np.tile(x_data, reps=reps)
+        verify_func(func, [x_data, np.array(reps).astype("float32")], ref_res)
+    verify_tile((2, 3, 4), (3, 2, 1))
+    verify_tile((2, 3, 4), (1, 2))
+    verify_tile((2, 3), (3, 2, 1))
+
 if __name__ == "__main__":
     test_dyn_reshape()
     test_dyn_shape_reshape()
+    test_dyn_tile()
index 052d95c..3415ce0 100644 (file)
@@ -108,8 +108,30 @@ def test_dynamic_to_static_quad_reshape():
     verify_reshape((2, 3, 4), (8, 3))
     verify_reshape((4, 7), (2, 7, 2))
 
+def test_dynamic_to_static_tile():
+    def verify_tile(shape, reps, oshape):
+        x = relay.var("x", relay.TensorType(shape, "float32"))
+        y = relay.var("y", relay.TensorType(reps, "float32"))
+        z = relay.tile(x, relay.shape_of(y))
+        func = run_infer_type(relay.Function([x, y], z))
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+
+        zz = func2.body
+        assert isinstance(zz, relay.Call)
+        assert zz.op == relay.op.get("tile")
+        assert zz.checked_type == relay.ty.TensorType(oshape, "float32")
+
+        x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
+        y_data = np.random.uniform(low=-1, high=1, size=reps).astype("float32")
+        ref_res = np.tile(x_data, reps)
+        verify_func(func2, [x_data, y_data], ref_res)
+
+    verify_tile((2, 3, 4), (2, 1, 5), (4, 3, 20))
+    verify_tile((4, 7), (4, 2), (16, 14))
+
 if __name__=="__main__":
     test_dynamic_to_static_reshape()
     test_dynamic_to_static_double_reshape()
     test_dynamic_to_static_quad_reshape()
+    test_dynamic_to_static_tile()
 
index 0da1d71..b5fc02a 100644 (file)
@@ -1017,6 +1017,43 @@ inline Tensor tile(const Tensor& x, Array<Integer> reps, std::string name = "T_t
 }
 
 /*!
+ * \brief Creates an operation to tile elements of an array
+ *
+ * \param x The input tensor
+ * \param new_shape The shape of the output after tiling
+ * \param rdim The rank of the reps, provided by caller
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the tile operation
+ */
+inline Tensor dyn_tile(const Tensor& x, Array<PrimExpr> new_shape, size_t rdim,
+                       std::string name = "T_tile", std::string tag = kBroadcast) {
+  size_t ndim = x->shape.size();
+  if (is_empty_shape(new_shape)) {
+    return compute(
+        new_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag);
+  } else {
+    return compute(
+        new_shape,
+        [&](const Array<Var>& indices) {
+          Array<PrimExpr> idx;
+          if (ndim >= rdim) {
+            for (size_t i = 0; i < ndim; ++i) {
+              idx.push_back(indexmod(indices[i], x->shape[i]));
+            }
+          } else {
+            for (size_t i = 0; i < ndim; ++i) {
+              idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
+            }
+          }
+          return x(idx);
+        },
+        name, tag);
+  }
+}
+
+/*!
  * \brief Gather values along given axis from given indices.
  *
  * \param data The input data to the operator.