[Relay/TOPI][OP] Add clip and wrap mode support in take (#2858)
authorHaichen Shen <shenhaichen@gmail.com>
Mon, 1 Apr 2019 22:40:11 +0000 (15:40 -0700)
committerYizhi Liu <liuyizhi@apache.org>
Mon, 1 Apr 2019 22:40:11 +0000 (06:40 +0800)
* Update take

* Add special case for canonical simplify and fix test cases

* Use lower case for wrap and clip

* remove unnecssary lower

* Fix mxnet converter for take

* fix

12 files changed:
3rdparty/HalideIR
include/tvm/relay/attrs/transform.h
python/tvm/relay/frontend/mxnet.py
python/tvm/relay/op/transform.py
src/relay/op/tensor/transform.cc
tests/python/frontend/mxnet/test_forward.py
tests/python/relay/test_op_level3.py
tests/python/unittest/test_arith_simplify.py
topi/include/topi/transform.h
topi/python/topi/transform.py
topi/src/topi.cc
topi/tests/python/test_topi_transform.py

index 86351c4..55ba177 160000 (submodule)
@@ -1 +1 @@
-Subproject commit 86351c40824dfc4cbb7447d70e5e63d9bd76eb90
+Subproject commit 55ba1778fd264c7507953552d8e51212ed11f748
index af49382..9f17205 100644 (file)
@@ -75,10 +75,15 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
 
 struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
   Integer axis;
+  std::string mode;
 
   TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") {
     TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
         .describe("The axis over which to select values.");
+    TVM_ATTR_FIELD(mode).set_default("clip")
+        .describe("Specify how out-of-bound indices will behave."
+                  "clip - clip to the range (default)"
+                  "wrap - wrap around the indices");
   }
 };
 
index 47ad5d7..8e36801 100644 (file)
@@ -444,6 +444,15 @@ def _mx_tile(inputs, attrs):
     return _op.tile(inputs[0], **new_attrs)
 
 
+def _mx_take(inputs, attrs):
+    assert len(inputs) == 2
+    mode = attrs.get_str("mode", "clip")
+    if mode == "raise":
+        raise RuntimeError("take doesn't support raise mode")
+    axis = attrs.get_int("axis", 0)
+    return _op.take(inputs[0], inputs[1].astype("int32"), axis, mode)
+
+
 def _mx_reverse(inputs, attrs):
     assert len(inputs) == 1
     new_attrs = {}
@@ -749,6 +758,7 @@ _convert_map = {
     "_full"         : _mx_full,
     "repeat"        : _mx_repeat,
     "tile"          : _mx_tile,
+    "take"          : _mx_take,
     "reverse"       : _mx_reverse,
     "squeeze"       : _mx_squeeze,
     "broadcast_axis": _mx_broadcast_axis,
index 37aace5..7357304 100644 (file)
@@ -186,7 +186,7 @@ def reshape_like(data, shape_like):
     return _make.reshape_like(data, shape_like)
 
 
-def take(data, indices, axis=None):
+def take(data, indices, axis=None, mode="clip"):
     """Take elements from an array along an axis.
 
     Parameters
@@ -201,12 +201,17 @@ def take(data, indices, axis=None):
         The axis over which to select values. By default,
         the flattened input array is used.
 
+    mode : str, optional
+        Specifies how out-of-bound indices will behave.
+        clip - clip to the range (default)
+        wrap - wrap around the indices
+
     Returns
     -------
     ret : relay.Expr
         The computed result.
     """
-    return _make.take(data, indices, axis)
+    return _make.take(data, indices, axis, mode)
 
 
 def full(fill_value, shape=(), dtype=""):
index a0ea8f2..08b06a2 100644 (file)
@@ -753,24 +753,26 @@ Array<Tensor> TakeCompute(const Attrs& attrs,
   const auto* param = attrs.as<TakeAttrs>();
   CHECK(param != nullptr);
   if (!param->axis.defined()) {
-    return Array<Tensor>{ topi::take(inputs[0], inputs[1]) };
+    return Array<Tensor>{ topi::take(inputs[0], inputs[1], param->mode) };
   } else {
-    return Array<Tensor>{ topi::take(inputs[0], inputs[1], param->axis) };
+    return Array<Tensor>{ topi::take(inputs[0], inputs[1], param->axis, param->mode) };
   }
 }
 
 Expr MakeTake(Expr data,
               Expr indices,
-              Integer axis) {
+              Integer axis,
+              std::string mode) {
   auto attrs = make_node<TakeAttrs>();
   attrs->axis = std::move(axis);
+  attrs->mode = std::move(mode);
   static const Op& op = Op::Get("take");
   return CallNode::make(op, {data, indices}, Attrs(attrs), {});
 }
 
 TVM_REGISTER_API("relay.op._make.take")
 .set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 3>(MakeTake, args, rv);
+    runtime::detail::unpack_call<Expr, 4>(MakeTake, args, rv);
 });
 
 RELAY_REGISTER_OP("take")
index faccfbf..9d0d594 100644 (file)
@@ -464,7 +464,6 @@ def test_forward_embedding():
     verify((2, 2), (4, 5))
     verify((2, 3, 4), (4, 5))
 
-
 def test_forward_smooth_l1():
     data = mx.sym.var('data')
     mx_sym = mx.sym.smooth_l1(data)
@@ -472,6 +471,26 @@ def test_forward_smooth_l1():
     mx_sym = mx.sym.smooth_l1(data, scalar=1.0)
     verify_mxnet_frontend_impl(mx_sym, (3, 4), (3, 4))
 
+def test_forward_take():
+    def verify(shape, indices_src, axis, mode="clip"):
+        x_np = np.random.uniform(size=shape).astype("float32")
+        indices_np = np.array(indices_src, dtype="float32")
+        ref_res = mx.nd.take(mx.nd.array(x_np), mx.nd.array(indices_np), axis, mode)
+        mx_sym = mx.sym.take(mx.sym.var("x"), mx.sym.var("y"), axis, mode)
+        new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape, "y": indices_np.shape})
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(new_sym)(x_np, indices_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+    verify((2,2), [[[1,0],[0,1]]], 0)
+    verify((2,2), [[[1,0],[0,1]]], 1)
+    verify((4,3,5,6), [[2,1,0,0]], -2)
+    verify((3,4), [-1, 5], 0)
+    verify((3,4), [-1, 5], 0, mode="wrap")
+    verify((3,4), [-1, 5], 1)
+    verify((3,4), [-1, 5], 1, mode="wrap")
+
 if __name__ == '__main__':
     test_forward_mlp()
     test_forward_vgg()
@@ -507,3 +526,4 @@ if __name__ == '__main__':
     test_forward_full()
     test_forward_embedding()
     test_forward_smooth_l1()
+    test_forward_take()
index 10ace54..0cfbcc2 100644 (file)
@@ -243,17 +243,17 @@ def test_take_infer_type():
     verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2)
 
 def test_take():
-    def verify_take(src_shape, indices_src, axis=None):
+    def verify_take(src_shape, indices_src, axis=None, mode="clip"):
         src_dtype = "float32"
         indices_dtype = "int32"
         indices_src = np.array(indices_src, dtype=indices_dtype)
         x = relay.var("x", relay.TensorType(src_shape, src_dtype))
         indices = relay.var("indices", relay.TensorType(indices_src.shape, indices_dtype))
-        z = relay.take(x, indices, axis=axis)
+        z = relay.take(x, indices, axis=axis, mode=mode)
 
         func = relay.Function([x, indices], z)
         x_data = np.random.uniform(low=-1, high=1, size=src_shape).astype(src_dtype)
-        ref_res = np.take(x_data, indices=indices_src, axis=axis)
+        ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=mode)
 
         for target, ctx in ctx_list():
             for kind in ["graph", "debug"]:
@@ -269,6 +269,12 @@ def test_take():
     verify_take((2,2), [[[1,0],[0,1]]], 0)
     verify_take((2,2), [[[1,0],[0,1]]], 1)
     verify_take((4,3,5,6), [[2,1,0,0]], -2)
+    verify_take((3,4), [-5, 20])
+    verify_take((3,4), [-5, 20], mode="wrap")
+    verify_take((3,4), [-1, 2], axis=0)
+    verify_take((3,4), [-1, 2], axis=0, mode="wrap")
+    verify_take((3,4), [-1, 2], axis=1)
+    verify_take((3,4), [-1, 2], axis=1, mode="wrap")
 
 
 def test_split_infer_type():
index a327650..6ee3bc6 100644 (file)
@@ -39,6 +39,11 @@ def test_simplify_mod():
     stmt = tvm.ir_pass.CanonicalSimplify(body)
     diff = tvm.ir_pass.CanonicalSimplify(stmt.body.body.value.index - (1 + i) % 16)
     assert diff.value == 0
+    # if we can't prove that j is non-negative, we can't prove that (j+16) % 16 is j%16
+    index = tvm.ir_pass.CanonicalSimplify((j + 16) % 16)
+    assert index != j
+    index = tvm.ir_pass.CanonicalSimplify((j + 16) % 16, {j: tvm.Range(0, 6)})
+    assert index == j
     # if we can't prove that j+n*32 is non-negative, we can't prove that (j+n*32) % 16 is j%16
     index = tvm.ir_pass.CanonicalSimplify(
         (j + n * 32) % 16, {j: tvm.Range(0, 6)})
index 464bd6f..bbe1a31 100644 (file)
@@ -604,22 +604,29 @@ inline Array<Tensor> split_sections(const Tensor& x,
 */
 inline Tensor take(const Tensor& a,
                    const Tensor& indices,
+                   std::string mode = "clip",
                    std::string name = "tensor",
                    std::string tag = kInjective) {
   Array<Expr> a_shape = a->shape;
-  Array<Expr> out_shape;
-  for (size_t j = 0; j < indices->shape.size(); ++j) {
-    out_shape.push_back(indices->shape[j]);
+  Array<Expr> out_shape = indices->shape;
+  Expr a_size = 1;
+  for (size_t i = 0; i < a_shape.size(); ++i) {
+    a_size = a_size * a_shape[i];
   }
 
-  return compute(
+  if (mode == "clip") {
+    return compute(
         out_shape, [&](const Array<Var>& out_index) {
-          Array<Expr> indices_position;
-          for (size_t j = 0; j < indices->shape.size(); ++j) {
-            indices_position.push_back(out_index[j]);
-          }
-          return a(UnravelIndex(indices(indices_position), a_shape));
+          auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1);
+          return a(UnravelIndex(idx, a_shape));
+        }, name, tag);
+  } else {  // mode == "wrap"
+    return compute(
+        out_shape, [&](const Array<Var>& out_index) {
+          auto idx = (indices(out_index) % a_size + a_size) % a_size;
+          return a(UnravelIndex(idx, a_shape));
         }, name, tag);
+  }
 }
 
 /*!
@@ -637,12 +644,15 @@ inline Tensor take(const Tensor& a,
 inline Tensor take(const Tensor& a,
                    const Tensor& indices,
                    int axis,
+                   std::string mode = "clip",
                    std::string name = "tensor",
                    std::string tag = kInjective) {
   if (axis < 0) {
     axis += static_cast<int>(a->shape.size());
   }
+  CHECK_GE(axis, 0) << "axis out of bounds";
   CHECK_LT(axis, a->shape.size()) << "axis out of bounds";
+  auto axis_dim = a->shape[axis];
 
   int indices_len = static_cast<int>(indices->shape.size());
   Array<Expr> out_shape;
@@ -655,7 +665,27 @@ inline Tensor take(const Tensor& a,
       out_shape.push_back(a->shape[i]);
     }
   }
-  return compute(
+  if (mode == "clip") {
+    return compute(
+        out_shape, [&](const Array<Var>& out_index) {
+          Array<Expr> indices_position;
+          for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
+            indices_position.push_back(out_index[j]);
+          }
+          Array<Expr> real_indices;
+          for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
+            real_indices.push_back(out_index[j]);
+          }
+          auto idx = tvm::min(tvm::max(0, indices(indices_position)),
+                              axis_dim - 1);
+          real_indices.push_back(idx);
+          for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
+            real_indices.push_back(out_index[j]);
+          }
+          return a(real_indices);
+        }, name, tag);
+  } else {  // mode == "wrap"
+    return compute(
         out_shape, [&](const Array<Var>& out_index) {
           Array<Expr> indices_position;
           for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
@@ -665,12 +695,14 @@ inline Tensor take(const Tensor& a,
           for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
             real_indices.push_back(out_index[j]);
           }
-          real_indices.push_back(indices(indices_position));
+          auto idx = (indices(indices_position) % axis_dim + axis_dim) % axis_dim;
+          real_indices.push_back(idx);
           for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
             real_indices.push_back(out_index[j]);
           }
           return a(real_indices);
         }, name, tag);
+  }
 }
 
 /*!
index 2c109cd..e674b9e 100644 (file)
@@ -228,7 +228,7 @@ def split(ary, indices_or_sections, axis=0):
     return cpp.split(ary, indices_or_sections, axis)
 
 
-def take(a, indices, axis=None):
+def take(a, indices, axis=None, mode="clip"):
     """Take elements from an array along an axis.
 
     Parameters
@@ -243,13 +243,18 @@ def take(a, indices, axis=None):
         The axis over which to select values. By default,
         the flattened input array is used.
 
+    mode : str, optional
+        Specifies how out-of-bound indices will behave.
+        clip - clip to the range (default)
+        wrap - wrap around the indices
+
     Returns
     -------
     ret : tvm.Tensor
     """
     if axis is None:
-        return cpp.take(a, indices)
-    return cpp.take(a, indices, int(axis))
+        return cpp.take(a, indices, mode)
+    return cpp.take(a, indices, int(axis), mode)
 
 
 def gather_nd(a, indices):
index aed2eab..1df73d8 100644 (file)
@@ -297,11 +297,13 @@ TVM_REGISTER_GLOBAL("topi.layout_transform")
 
 TVM_REGISTER_GLOBAL("topi.take")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
-  if (args.size() == 2) {
-    *rv = take(args[0], args[1]);
+  if (args.size() == 3) {
+    std::string mode = args[2];
+    *rv = take(args[0], args[1], mode);
   } else {
     int axis = args[2];
-    *rv = take(args[0], args[1], axis);
+    std::string mode = args[3];
+    *rv = take(args[0], args[1], axis, mode);
   }
   });
 
index 59c1090..b56df9f 100644 (file)
@@ -232,16 +232,16 @@ def verify_flip(in_shape, axis):
     for device in ["llvm", "cuda", "opencl", "sdaccel", "aocl_sw_emu"]:
         check_device(device)
 
-def verify_take(src_shape, indices_src, axis=None):
+def verify_take(src_shape, indices_src, axis=None, mode="clip"):
     src_dtype = "float32"
     indices_dtype = "int32"
     indices_src = np.array(indices_src, dtype=indices_dtype)
     A = tvm.placeholder(shape=src_shape, dtype=src_dtype, name="A")
     indices = tvm.placeholder(shape=indices_src.shape, dtype=indices_dtype, name="indices")
     if axis is None:
-        out_tensor = topi.take(a=A, indices=indices)
+        out_tensor = topi.take(a=A, indices=indices, mode=mode)
     else:
-        out_tensor = topi.take(a=A, indices=indices, axis=axis)
+        out_tensor = topi.take(a=A, indices=indices, axis=axis, mode=mode)
 
     def check_device(device):
         ctx = tvm.context(device, 0)
@@ -259,9 +259,9 @@ def verify_take(src_shape, indices_src, axis=None):
         data_npy = np.arange(shape_size, dtype=src_dtype).reshape((src_shape))
 
         if axis is None:
-            out_npys = np.take(data_npy, indices_src)
+            out_npys = np.take(data_npy, indices_src, mode=mode)
         else:
-            out_npys = np.take(data_npy, indices_src, axis=axis)
+            out_npys = np.take(data_npy, indices_src, axis=axis, mode=mode)
         data_nd = tvm.nd.array(data_npy, ctx)
         indices_nd = tvm.nd.array(indices_src, ctx)
         out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype)
@@ -498,6 +498,12 @@ def test_take():
     verify_take((2,2), [[[1,0],[0,1]]], 0)
     verify_take((2,2), [[[1,0],[0,1]]], 1)
     verify_take((4,3,5,6), [[2,1,0,0]], -2)
+    verify_take((3,4), [-5, 20])
+    verify_take((3,4), [-5, 20], mode="wrap")
+    verify_take((3,4), [-1, 2], axis=0)
+    verify_take((3,4), [-1, 2], axis=0, mode="wrap")
+    verify_take((3,4), [-1, 2], axis=1)
+    verify_take((3,4), [-1, 2], axis=1, mode="wrap")
 
 def test_gather_nd():
     for indices_dtype in ['int32', 'float32']: