[relay][op] Add shape func to tile (#4441)
authorZhi <5145158+zhiics@users.noreply.github.com>
Thu, 5 Dec 2019 06:16:37 +0000 (22:16 -0800)
committerYao Wang <kevinthesunwy@gmail.com>
Thu, 5 Dec 2019 06:16:37 +0000 (22:16 -0800)
* [relay][op] Add shape func to tile

* retrigger ci

* check dynamic axes

* retrigger ci

python/tvm/relay/op/_transform.py
src/relay/op/tensor/transform.cc
tests/python/relay/test_any.py

index c89ac33..de708fb 100644 (file)
@@ -501,3 +501,35 @@ def reshape_like_shape_func(attrs, inputs, _):
     Shape function for reshape_like op.
     """
     return [_reshape_like_shape_func(inputs[1])]
+
+@script
+def _tile_shape_func(data, reps, ndim, tndim, rndim):
+    out = output_tensor((tndim,), "int64")
+
+    if ndim == rndim:
+        for i in const_range(tndim):
+            out[i] = data[i] * int64(reps[i])
+    elif ndim > rndim:
+        ngap = ndim - rndim
+        for i in const_range(ndim):
+            if i < ngap:
+                out[i] = data[i]
+            else:
+                out[i] = data[i] * int64(reps[i - ngap])
+    else:
+        rgap = rndim - ndim
+        for i in const_range(rndim):
+            if i < rgap:
+                out[i] = int64(reps[i])
+            else:
+                out[i] = int64(reps[i]) * data[i - rgap]
+    return out
+
+@_reg.register_shape_func("tile", False)
+def tile_shape_func(attrs, inputs, _):
+    reps = get_const_tuple(attrs.reps)
+    ndim = inputs[0].shape[0].value
+    rndim = len(reps)
+    tndim = ndim if ndim > rndim else rndim
+    return [_tile_shape_func(inputs[0], convert(reps), convert(ndim),
+                             convert(tndim), convert(rndim))]
index 3a58a4b..23d49cd 100644 (file)
@@ -1393,28 +1393,39 @@ bool TileRel(const Array<Type>& types,
   reps_shape.reserve(tndim);
   if (ndim == rndim) {
     for (size_t i = 0; i < tndim; ++i) {
-        data_shape.emplace_back(data->shape[i]);
-        reps_shape.emplace_back(reps[i]);
+      data_shape.emplace_back(data->shape[i]);
+      reps_shape.emplace_back(reps[i]);
     }
   } else if (ndim > rndim) {
-    for (size_t i = 0; i < ndim; ++i)
-        data_shape.emplace_back(data->shape[i]);
-    for (size_t i = 0; i < (ndim - rndim); ++i)
-        reps_shape.emplace_back(1);
-    for (size_t i = 0; i < rndim; ++i)
-        reps_shape.emplace_back(reps[i]);
+    for (size_t i = 0; i < ndim; ++i) {
+      data_shape.emplace_back(data->shape[i]);
+    }
+    for (size_t i = 0; i < (ndim - rndim); ++i) {
+      reps_shape.emplace_back(1);
+    }
+    for (size_t i = 0; i < rndim; ++i) {
+      reps_shape.emplace_back(reps[i]);
+    }
   } else {
-    for (size_t i = 0; i < rndim; ++i)
-        reps_shape.emplace_back(reps[i]);
-    for (size_t i = 0; i < (rndim - ndim); ++i)
-        data_shape.emplace_back(1);
-    for (size_t i = 0; i < ndim; ++i)
-        data_shape.emplace_back(data->shape[i]);
+    for (size_t i = 0; i < rndim; ++i) {
+      reps_shape.emplace_back(reps[i]);
+    }
+    for (size_t i = 0; i < (rndim - ndim); ++i) {
+      data_shape.emplace_back(1);
+    }
+    for (size_t i = 0; i < ndim; ++i) {
+      data_shape.emplace_back(data->shape[i]);
+    }
   }
   std::vector<IndexExpr> oshape;
   oshape.reserve(tndim);
   for (size_t i = 0; i < tndim; ++i) {
-    oshape.emplace_back(data_shape[i] * reps_shape[i]);
+    // Save Any if it is dynamic shape
+    if (!data_shape[i].as<IntImm>()) {
+      oshape.emplace_back(Any::make());
+    } else {
+      oshape.emplace_back(data_shape[i] * reps_shape[i]);
+    }
   }
   reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
   return true;
index 9e0208f..d7246da 100644 (file)
@@ -193,6 +193,25 @@ def test_any_take():
     verify_any_take(any_dims(2), any_dims(3), None, (4, 5), (2, 3, 4))
     verify_any_take(any_dims(2), any_dims(4), -1, (4, 5), (2, 3, 4, 5))
 
+def verify_any_tile(dshape, reps, np_dshape, np_reps):
+    mod = relay.Module()
+    x = relay.var("x", shape=dshape, dtype="float32")
+    y = relay.tile(x, reps=reps)
+    mod["main"] = relay.Function([x], y)
+    x_data = np.random.uniform(size=np_dshape).astype("float32")
+    ref_res = np.tile(x_data, reps=np_reps)
+
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        res = ex.evaluate()(x_data)
+        tvm.testing.assert_allclose(res.asnumpy(), ref_res, rtol=1e-5)
+
+def test_any_tile():
+    verify_any_tile(any_dims(3), (3, 2, 1), (2, 3, 4), (3, 2, 1))
+    verify_any_tile(any_dims(3), (1, 2), (2, 3, 4), (1, 2))
+    verify_any_tile(any_dims(2), (3, 2, 1), (2, 3), (3, 2, 1))
+    verify_any_tile(any_dims(3), (1,), (2, 3, 4), (1,))
+
 def test_any_shape_of():
     x = relay.var('x', shape=any_dims(2), dtype='float32')
     y = relay.shape_of(x)
@@ -586,6 +605,7 @@ if __name__ == "__main__":
     test_any_concat()
     test_any_reshape()
     test_any_take()
+    test_any_tile()
     test_any_shape_of()
     test_any_reduce()
     test_any_layout_transform()